# Source code for merton.greeks.autodiff

"""Autodiff Greeks via JAX.

Each Greek is computed by composing :func:`jax.grad` with the BSM equity
or PD expression, then :func:`jax.vmap` to vectorise across input arrays
and :func:`jax.jit` to compile the result. This is useful for:

- Validating the closed-form Greeks in :mod:`merton.greeks.equity` and
  :mod:`merton.greeks.pd_sensitivity` (they should agree to ~1e-7).
- Differentiating *novel* BSM extensions where deriving Greeks by hand is
  error-prone.

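The module is only importable when ``merton[jax]`` is installed; importing
it without JAX raises a clean ``ImportError``. Importing it also turns on
JAX's process-wide ``jax_enable_x64`` flag so that the Greeks are computed
in float64.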

Examples
--------
>>> from merton.greeks import autodiff  # doctest: +SKIP
>>> import numpy as np  # doctest: +SKIP
>>> A = np.array([100.0, 150.0, 80.0])  # doctest: +SKIP
>>> autodiff.equity_delta_ad(A, 0.25, 60.0, 0.04, 1.0)  # doctest: +SKIP
"""

from __future__ import annotations

try:
    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import norm as _jnorm
except ImportError as _err:  # pragma: no cover - optional dep
    raise ImportError(
        'merton.greeks.autodiff requires the JAX extra: `pip install "merton[jax]"`.'
    ) from _err

# Enable float64 ("x64") mode. JAX computes in float32 by default, and
# float32 round-off (~1e-7 relative) is the same order as the ~1e-7
# agreement with the closed-form Greeks that this module is meant to
# check, so double precision is required. JAX expects this flag to be set
# at startup, before arrays are created or functions are traced, which is
# why it is flipped once here at import time.
# NOTE: jax.config is global state, so this also switches x64 on for all
# other JAX code running in the same process.
jax.config.update("jax_enable_x64", True)


def _equity_value(A, sigma, D, r, T, q):  # type: ignore[no-untyped-def]
    sqrtT = jnp.sqrt(T)
    d1 = (jnp.log(A / D) + (r - q + 0.5 * sigma * sigma) * T) / (sigma * sqrtT)
    d2 = d1 - sigma * sqrtT
    return A * jnp.exp(-q * T) * _jnorm.cdf(d1) - D * jnp.exp(-r * T) * _jnorm.cdf(d2)


def _dd(A, sigma, D, r, T, q):  # type: ignore[no-untyped-def]
    sqrtT = jnp.sqrt(T)
    return (jnp.log(A / D) + (r - q - 0.5 * sigma * sigma) * T) / (sigma * sqrtT)


def _pd(A, sigma, D, r, T, q):  # type: ignore[no-untyped-def]
    """Probability of default, ``PD = N(-DD)``, with DD from :func:`_dd`."""
    distance = _dd(A, sigma, D, r, T, q)
    return _jnorm.cdf(-distance)


def _grad_then_vmap_then_jit(fn, *, argnums: int, in_axes):  # type: ignore[no-untyped-def]
    return jax.jit(jax.vmap(jax.grad(fn, argnums=argnums), in_axes=in_axes))


# Greek table: (name, function, argnum_of_x_to_diff_against, default in_axes)
# in_axes carries 0 for asset_value/asset_vol/debt (broadcastable arrays) and
# None for rf/T/q (scalars).
_VMAP_AXES_FULL = (0, 0, 0, None, None, None)
_VMAP_AXES_AVOL = (0, 0, 0, None, None, None)


# Jitted, row-vectorised Greeks. Each one differentiates a scalar model
# function of (A, sigma, D, r, T, q) with respect to a single positional
# argument, then maps the result over the per-row arrays.

# Equity value E: ∂E/∂A, ∂E/∂σ_A, ∂E/∂r and ∂E/∂T. The last one is the raw
# maturity derivative, before any sign flip to calendar-time theta.
_equity_delta_fn = _grad_then_vmap_then_jit(_equity_value, argnums=0, in_axes=_VMAP_AXES_FULL)
_equity_vega_fn = _grad_then_vmap_then_jit(_equity_value, argnums=1, in_axes=_VMAP_AXES_FULL)
_equity_rho_fn = _grad_then_vmap_then_jit(_equity_value, argnums=3, in_axes=_VMAP_AXES_FULL)
_equity_theta_fn = _grad_then_vmap_then_jit(_equity_value, argnums=4, in_axes=_VMAP_AXES_FULL)

# Gamma ∂²E/∂A². The per-row delta is itself a scalar function of the same
# six arguments, so it goes through the same grad -> vmap -> jit pipeline.
_equity_gamma_fn = _grad_then_vmap_then_jit(
    jax.grad(_equity_value, argnums=0),
    argnums=0,
    in_axes=_VMAP_AXES_FULL,
)

# Probability of default PD = N(-DD): ∂PD/∂σ_A, ∂PD/∂D and ∂PD/∂r.
_pd_dvol_fn = _grad_then_vmap_then_jit(_pd, argnums=1, in_axes=_VMAP_AXES_FULL)
_pd_dleverage_fn = _grad_then_vmap_then_jit(_pd, argnums=2, in_axes=_VMAP_AXES_FULL)
_pd_drate_fn = _grad_then_vmap_then_jit(_pd, argnums=3, in_axes=_VMAP_AXES_FULL)


def _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield):  # type: ignore[no-untyped-def]
    """Promote inputs to 1-D JAX arrays of the same length."""
    arrays = [
        jnp.atleast_1d(jnp.asarray(a, dtype=jnp.float64)) for a in (asset_value, asset_vol, debt)
    ]
    n = max(a.shape[0] for a in arrays)
    arrays = [jnp.broadcast_to(a, (n,)) for a in arrays]
    return (*arrays, jnp.float64(rf), jnp.float64(T), jnp.float64(dividend_yield))


def equity_delta_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Equity delta ``∂E/∂A`` evaluated with JAX autodiff.

    ``E`` is the Merton equity value, a BSM call on the assets struck at
    ``debt``.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one delta per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _equity_delta_fn(*batched_args)
def equity_vega_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Equity vega ``∂E/∂σ_A`` evaluated with JAX autodiff.

    ``E`` is the Merton equity value, a BSM call on the assets struck at
    ``debt``.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one vega per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _equity_vega_fn(*batched_args)
def equity_gamma_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Equity gamma ``∂²E/∂A²`` evaluated with JAX autodiff.

    Computed as the gradient of the delta. ``E`` is the Merton equity value,
    a BSM call on the assets struck at ``debt``.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one gamma per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _equity_gamma_fn(*batched_args)
def equity_theta_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Equity theta ``∂E/∂t`` (i.e. ``-∂E/∂T``) evaluated with JAX autodiff.

    The sign follows the option-pricing convention used by
    :func:`merton.greeks.equity_theta`, so the autodiff and closed-form
    columns line up sign-wise.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one theta per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    # The jitted function returns ∂E/∂T; flip the sign to get calendar-time theta.
    d_value_d_maturity = _equity_theta_fn(*batched_args)
    return -d_value_d_maturity
def equity_rho_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Equity rho ``∂E/∂r`` evaluated with JAX autodiff.

    ``E`` is the Merton equity value, a BSM call on the assets struck at
    ``debt``.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one rho per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _equity_rho_fn(*batched_args)
def pd_leverage_sensitivity_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Sensitivity of default probability to debt, ``∂PD/∂D``, via JAX autodiff.

    ``PD = N(-DD)``, where ``DD`` is the distance to default.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one sensitivity per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _pd_dleverage_fn(*batched_args)
def pd_vol_sensitivity_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Sensitivity of default probability to asset volatility, ``∂PD/∂σ_A``.

    Evaluated with JAX autodiff. ``PD = N(-DD)``, where ``DD`` is the
    distance to default.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one sensitivity per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _pd_dvol_fn(*batched_args)
def pd_rate_sensitivity_ad(asset_value, asset_vol, debt, rf, T, *, dividend_yield=0.0):  # type: ignore[no-untyped-def]
    """Sensitivity of default probability to the risk-free rate, ``∂PD/∂r``.

    Evaluated with JAX autodiff. ``PD = N(-DD)``, where ``DD`` is the
    distance to default.

    Parameters
    ----------
    asset_value, asset_vol, debt
        Scalars or 1-D array-likes, broadcast to a common length ``n``.
    rf, T
        Risk-free rate and maturity, shared by every row.
    dividend_yield
        Asset payout yield ``q``, shared by every row (default ``0.0``).

    Returns
    -------
    jax.Array
        Shape ``(n,)``: one sensitivity per row.
    """
    batched_args = _broadcast_inputs(asset_value, asset_vol, debt, rf, T, dividend_yield)
    return _pd_drate_fn(*batched_args)
__all__ = [
    # Equity Greeks
    "equity_delta_ad",
    "equity_gamma_ad",
    "equity_rho_ad",
    "equity_theta_ad",
    "equity_vega_ad",
    # Default-probability sensitivities
    "pd_leverage_sensitivity_ad",
    "pd_rate_sensitivity_ad",
    "pd_vol_sensitivity_ad",
]