58. The Income Fluctuation Problem II: Optimistic Policy Iteration#

58.1. Overview#

In The Income Fluctuation Problem I: Discretization and VFI we studied the income fluctuation problem and solved it using value function iteration (VFI).

In this lecture we’ll solve the same problem using optimistic policy iteration (OPI), which is very general, typically faster than VFI and only slightly more complex.

OPI combines elements of both value function iteration and policy iteration.

A detailed discussion of the algorithm can be found in DP1.

Here our aim is to implement OPI and test whether or not it yields significant speed improvements over standard VFI for the income fluctuation problem.

In addition to Anaconda, this lecture will need the following libraries:

!pip install quantecon jax


We will use the following imports:

import quantecon as qe
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from typing import NamedTuple
from time import time

58.2. Model and Primitives#

The model and parameters are the same as in The Income Fluctuation Problem I: Discretization and VFI.

We repeat the key elements here for convenience.

The household’s problem is to maximize

\[ \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) \]

subject to

\[ a_{t+1} + c_t \leq R a_t + y_t \]

where \(u(c) = c^{1-\gamma}/(1-\gamma)\).

Here’s the model structure:

class Model(NamedTuple):
    β: float              # Discount factor
    R: float              # Gross interest rate
    γ: float              # CRRA parameter
    a_grid: jnp.ndarray   # Asset grid
    y_grid: jnp.ndarray   # Income grid
    Q: jnp.ndarray        # Markov matrix for income


def create_consumption_model(
        R=1.01,                    # Gross interest rate
        β=0.98,                    # Discount factor
        γ=2,                       # CRRA parameter
        a_min=0.01,                # Min assets
        a_max=5.0,                 # Max assets
        a_size=150,                # Grid size
        ρ=0.9, ν=0.1, y_size=100   # Income parameters
    ):
    """
    Creates an instance of the consumption-savings model.

    """
    a_grid = jnp.linspace(a_min, a_max, a_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P)
    return Model(β, R, γ, a_grid, y_grid, Q)
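As a quick sanity check, here is a minimal sketch (assuming the default parameter values above) that builds an instance and inspects the primitives:

model = create_consumption_model()
print(model.a_grid.shape)                        # (150,) with the default a_size
print(model.y_grid.shape)                        # (100,) with the default y_size
print(model.Q.shape)                             # (100, 100)
print(jnp.allclose(model.Q.sum(axis=1), 1.0))    # rows of a Markov matrix sum to one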

58.3. Operators and Policies#

We repeat some functions from The Income Fluctuation Problem I: Discretization and VFI.

Here is the right hand side of the Bellman equation:

@jax.jit
def B(v, model):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing

        B(a, y, a′) = u(Ra + y - a′) + β Σ_y′ v(a′, y′) Q(y, y′)

    for all (a, y, a′).
    """

    # Unpack
    β, R, γ, a_grid, y_grid, Q = model
    a_size, y_size = len(a_grid), len(y_grid)

    # Compute current rewards r(a, y, ap) as array r[i, j, ip]
    a  = jnp.reshape(a_grid, (a_size, 1, 1))    # a[i]   ->  a[i, j, ip]
    y  = jnp.reshape(y_grid, (1, y_size, 1))    # y[j]   ->  y[i, j, ip]
    ap = jnp.reshape(a_grid, (1, 1, a_size))    # ap[ip] -> ap[i, j, ip]
    c = R * a + y - ap

    # Calculate continuation rewards at all combinations of (a, y, ap)
    v = jnp.reshape(v, (1, 1, a_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
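As a small illustration (a sketch assuming the default grid sizes), B maps a value array of shape (a_size, y_size) into a 3D array indexed by \((a, y, a')\):

model = create_consumption_model()
v = jnp.zeros((len(model.a_grid), len(model.y_grid)))
print(B(v, model).shape)    # (a_size, y_size, a_size) = (150, 100, 150)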

Here’s the Bellman operator:

@jax.jit
def T(v, model):
    "The Bellman operator."
    return jnp.max(B(v, model), axis=2)

Here’s the function that computes a \(v\)-greedy policy:

@jax.jit
def get_greedy(v, model):
    "Computes a v-greedy policy, returned as a set of indices."
    return jnp.argmax(B(v, model), axis=2)

Now we define the policy operator \(T_\sigma\), which is the Bellman operator with policy \(\sigma\) fixed.

For a given policy \(\sigma\), the policy operator is defined by

\[ (T_\sigma v)(a, y) = u(Ra + y - \sigma(a, y)) + \beta \sum_{y'} v(\sigma(a, y), y') Q(y, y') \]

Here's an implementation that acts on a single index pair \((i, j)\):

def T_σ(v, σ, model, i, j):
    """
    The σ-policy operator for indices (i, j) -> (a, y).
    """
    β, R, γ, a_grid, y_grid, Q = model

    # Get values at current state
    a, y = a_grid[i], y_grid[j]
    # Get policy choice
    ap = a_grid[σ[i, j]]

    # Compute current reward
    c = R * a + y - ap
    r = jnp.where(c > 0, c**(1-γ)/(1-γ), -jnp.inf)

    # Compute expected value
    EV = jnp.sum(v[σ[i, j], :] * Q[j, :])

    return r + β * EV

Apply vmap to vectorize:

T_σ_1    = jax.vmap(T_σ,   in_axes=(None, None, None, None, 0))
T_σ_vmap = jax.vmap(T_σ_1, in_axes=(None, None, None, 0,    None))

@jax.jit
def T_σ_vec(v, σ, model):
    """Vectorized version of T_σ."""
    a_size, y_size = len(model.a_grid), len(model.y_grid)
    a_indices = jnp.arange(a_size)
    y_indices = jnp.arange(y_size)
    return T_σ_vmap(v, σ, model, a_indices, y_indices)
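Since \(T v = T_\sigma v\) whenever \(\sigma\) is \(v\)-greedy, here is a quick consistency check (a sketch; the equality should hold up to floating point error):

model = create_consumption_model()
v = jnp.zeros((len(model.a_grid), len(model.y_grid)))
σ = get_greedy(v, model)
# One step of T_σ under a v-greedy policy coincides with one step of T
print(jnp.allclose(T_σ_vec(v, σ, model), T(v, model)))    # expect True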

Now we need a function to apply the policy operator \(m\) times:

@jax.jit
def iterate_policy_operator(σ, v, m, model):
    """
    Apply the policy operator T_σ exactly m times to v.
    """
    def update(i, v):
        return T_σ_vec(v, σ, model)

    v = jax.lax.fori_loop(0, m, update, v)
    return v
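Because \(T_\sigma\) is a contraction of modulus \(\beta\), repeated application converges to the lifetime value of the policy \(\sigma\). Here is a rough illustration (a sketch, with the number of iterations chosen arbitrarily):

model = create_consumption_model()
v = jnp.zeros((len(model.a_grid), len(model.y_grid)))
σ = get_greedy(v, model)
v1 = iterate_policy_operator(σ, v, 50, model)
v2 = iterate_policy_operator(σ, v1, 50, model)
# By the contraction property, the second block of 50 iterations moves v
# by at most β^50 ≈ 0.36 times as much as the first block did
print(jnp.max(jnp.abs(v1 - v)), jnp.max(jnp.abs(v2 - v1)))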

58.4. Value Function Iteration#

For comparison, here’s VFI from The Income Fluctuation Problem I: Discretization and VFI:

@jax.jit
def value_function_iteration(model, tol=1e-5, max_iter=10_000):
    """
    Implements VFI using successive approximation.
    """
    def body_fun(k_v_err):
        k, v, error = k_v_err
        v_new = T(v, model)
        error = jnp.max(jnp.abs(v_new - v))
        return k + 1, v_new, error

    def cond_fun(k_v_err):
        k, v, error = k_v_err
        return jnp.logical_and(error > tol, k < max_iter)

    v_init = jnp.zeros((len(model.a_grid), len(model.y_grid)))
    k, v_star, error = jax.lax.while_loop(cond_fun, body_fun,
                                          (1, v_init, tol + 1))
    return v_star, get_greedy(v_star, model)

58.5. Optimistic Policy Iteration#

Now we implement OPI.

The algorithm alternates between

  1. Computing a greedy policy from the current value function and

  2. Applying the corresponding policy operator \(m\) times to update the value function.

@jax.jit
def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
    """
    Implements optimistic policy iteration with step size m.

    Parameters:
    -----------
    model : Model
        The consumption-savings model
    m : int
        Number of policy operator iterations per step
    tol : float
        Tolerance for convergence
    max_iter : int
        Maximum number of iterations
    """
    v_init = jnp.zeros((len(model.a_grid), len(model.y_grid)))

    def condition_function(inputs):
        i, v, error = inputs
        return jnp.logical_and(error > tol, i < max_iter)

    def update(inputs):
        i, v, error = inputs
        last_v = v
        σ = get_greedy(v, model)
        v = iterate_policy_operator(σ, v, m, model)
        error = jnp.max(jnp.abs(v - last_v))
        i += 1
        return i, v, error

    num_iter, v, error = jax.lax.while_loop(condition_function,
                                            update,
                                            (0, v_init, tol + 1))

    return v, get_greedy(v, model)

58.6. Timing Comparison#

Let’s create a model and compare the performance of VFI and OPI.

model = create_consumption_model()

First, let’s time VFI:

print("Starting VFI.")
start = time()
v_star_vfi, σ_star_vfi = value_function_iteration(model)
v_star_vfi.block_until_ready()
vfi_time_with_compile = time() - start
print(f"VFI completed in {vfi_time_with_compile:.2f} seconds.")
Starting VFI.
VFI completed in 0.75 seconds.

Run it again to eliminate compile time:

start = time()
v_star_vfi, σ_star_vfi = value_function_iteration(model)
v_star_vfi.block_until_ready()
vfi_time = time() - start
print(f"VFI completed in {vfi_time:.2f} seconds.")
VFI completed in 0.14 seconds.

Now let’s time OPI with different values of m:

print("Starting OPI with m=10.")
start = time()
v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10)
v_star_opi.block_until_ready()
opi_time_with_compile = time() - start
print(f"OPI completed in {opi_time_with_compile:.2f} seconds.")
Starting OPI with m=10.
OPI completed in 0.47 seconds.

Run it again:

start = time()
v_star_opi, σ_star_opi = optimistic_policy_iteration(model, m=10)
v_star_opi.block_until_ready()
opi_time = time() - start
print(f"OPI completed in {opi_time:.2f} seconds.")
OPI completed in 0.04 seconds.

Check whether the two routines compute the same policy:

print(f"Policies match: {jnp.allclose(σ_star_vfi, σ_star_opi)}")
Policies match: False
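The strict equality test can fail even when both routines are working correctly: each stops at a finite tolerance, so the greedy policies can disagree at a few grid points, usually where two asset choices give nearly identical values. A more informative comparison (a sketch) looks at the gap between the value functions and counts the disagreements:

print(jnp.max(jnp.abs(v_star_vfi - v_star_opi)))    # gap between value functions
print(jnp.sum(σ_star_vfi != σ_star_opi))            # grid points where policies differ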

Here’s the speedup:

print(f"Speedup factor: {vfi_time / opi_time:.2f}")
Speedup factor: 3.57

Let’s try different values of m to see how it affects performance:

m_vals = [1, 5, 10, 25, 50, 100, 200, 400]
opi_times = []

for m in m_vals:
    start = time()
    v_star, σ_star = optimistic_policy_iteration(model, m=m)
    v_star.block_until_ready()
    elapsed = time() - start
    opi_times.append(elapsed)
    print(f"OPI with m={m:3d} completed in {elapsed:.2f} seconds.")
OPI with m=  1 completed in 0.10 seconds.
OPI with m=  5 completed in 0.05 seconds.
OPI with m= 10 completed in 0.04 seconds.
OPI with m= 25 completed in 0.04 seconds.
OPI with m= 50 completed in 0.04 seconds.
OPI with m=100 completed in 0.04 seconds.
OPI with m=200 completed in 0.06 seconds.
OPI with m=400 completed in 0.11 seconds.

Plot the results:

fig, ax = plt.subplots()
ax.plot(m_vals, opi_times, 'o-', label='OPI')
ax.axhline(vfi_time, linestyle='--', color='red', label='VFI')
ax.set_xlabel('m (policy steps per iteration)')
ax.set_ylabel('time (seconds)')
ax.legend()
ax.set_title('OPI execution time vs step size m')
plt.show()

Here’s a summary of the results

  • When \(m=1\), OPI is slight slower than VFI, even though they should be mathematically equivalent, due to small inefficiencies associated with extra function calls.

  • OPI outperforms VFI for a very large range of \(m\) values.

  • For very large \(m\), OPI performance begins to degrade as we spend too much time iterating the policy operator.

58.7. Exercises#

Exercise 58.1

The speed gains achieved by OPI are quite robust to parameter changes.

Confirm this by experimenting with different parameter values for the income process (\(\rho\) and \(\nu\)).

Measure how they affect the relative performance of VFI vs OPI.

Try:

  • \(\rho \in \{0.8, 0.9, 0.95\}\)

  • \(\nu \in \{0.05, 0.1, 0.2\}\)

For each combination, compute the speedup factor (VFI time / OPI time) and report your findings.
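Here is one possible layout for the experiment (a sketch rather than a full solution; the nested loop and the choice \(m=10\) are our own assumptions):

for ρ in (0.8, 0.9, 0.95):
    for ν in (0.05, 0.1, 0.2):
        model = create_consumption_model(ρ=ρ, ν=ν)
        # Warm up both routines so that compile time is excluded
        value_function_iteration(model)[0].block_until_ready()
        optimistic_policy_iteration(model, m=10)[0].block_until_ready()
        # Time VFI
        start = time()
        value_function_iteration(model)[0].block_until_ready()
        vfi_t = time() - start
        # Time OPI
        start = time()
        optimistic_policy_iteration(model, m=10)[0].block_until_ready()
        opi_t = time() - start
        print(f"ρ={ρ}, ν={ν}: speedup = {vfi_t / opi_t:.2f}")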