# Source code for merton.batch.panel

"""``batch_fit`` — calibrate Merton over a panel of firms in parallel."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd

from ..core.firm import Firm
from ..core.model import MertonModel
from ..core.panel import FirmPanel
from .dispatch import parallel_map

if TYPE_CHECKING:
    pass


def _panelify(df: Any, mapping: dict[str, str] | None) -> tuple[FirmPanel, str]:
    """Coerce a supported tabular input into a :class:`FirmPanel`.

    Parameters
    ----------
    df
        A :class:`FirmPanel` (returned unchanged), a pandas DataFrame, a
        polars DataFrame/LazyFrame, or a pyarrow Table.
    mapping
        Column-rename map forwarded to ``FirmPanel.from_pandas`` and
        ``FirmPanel.from_polars``.
        NOTE(review): it is not forwarded to ``FirmPanel.from_arrow`` —
        confirm whether arrow input is meant to honour it.

    Returns
    -------
    tuple[FirmPanel, str]
        The panel plus a tag (``"panel"``, ``"pandas"``, ``"polars"`` or
        ``"arrow"``) recording the input family, so results can be handed
        back in the same kind of container.

    Raises
    ------
    TypeError
        If ``df`` is none of the supported types. polars and pyarrow are
        optional: when a package is missing its types are simply not
        recognised, and the caller still gets this ``TypeError`` rather
        than an ``ImportError``.
    """
    if isinstance(df, FirmPanel):
        return df, "panel"
    if isinstance(df, pd.DataFrame):
        return FirmPanel.from_pandas(df, mapping=mapping), "pandas"

    try:
        import polars as pl  # type: ignore[import-not-found]
    except ImportError:
        pass
    else:
        if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
            return FirmPanel.from_polars(df, mapping=mapping), "polars"

    # pyarrow is optional too; an unguarded import here used to turn
    # "unsupported input type" into a ModuleNotFoundError.
    try:
        import pyarrow as pa
    except ImportError:
        pass
    else:
        if isinstance(df, pa.Table):
            return FirmPanel.from_arrow(df), "arrow"

    raise TypeError(f"batch_fit cannot accept {type(df).__name__}")


def _format_output(rows: list[dict[str, Any]], kind: str) -> Any:
    df = pd.DataFrame(rows)
    if kind == "pandas":
        return df
    if kind == "polars":  # pragma: no cover - optional
        import polars as pl  # type: ignore[import-not-found]

        return pl.from_pandas(df)
    if kind == "arrow":  # pragma: no cover - optional
        import pyarrow as pa

        return pa.Table.from_pandas(df, preserve_index=False)
    return df  # plain FirmPanel input → pandas DataFrame output


def batch_fit(
    df: Any,
    *,
    method: str = "vassalou_xing",
    mapping: dict[str, str] | None = None,
    n_jobs: int = -1,
    dispatch: str = "joblib",
    chunk_size: int | None = None,
    progress: bool = False,
    on_error: str = "warn",
    horizon: float | None = None,
    **fit_kwargs: Any,
) -> Any:
    """Fit a Merton model to every firm in a panel.

    Parameters
    ----------
    df
        Either a :class:`FirmPanel`, pandas DataFrame, polars DataFrame,
        or pyarrow Table containing one row per (firm, snapshot).
        Required columns: ``equity``, ``debt_short``, ``debt_long``.
        Optional: ``equity_vol``, ``rf``, ``dividend_yield``, ``horizon``,
        ``ticker``.
    method
        Calibration method (any name in
        :func:`merton.calibration.available_methods`).
    mapping
        Optional column-rename map applied before validation.
    n_jobs
        Number of workers. ``-1`` = all logical cores.
    dispatch
        ``"joblib"`` (default), ``"sequential"``, ``"dask"``, or ``"ray"``.
    chunk_size
        Currently accepted but not forwarded to the dispatcher; the backend
        chooses its own chunking.
    progress
        Render a Rich progress bar.
    on_error
        What to do when fitting a firm raises:

        * ``"warn"`` (default) emits a warning and keeps a row with NaN
          results and ``converged=False``.
        * ``"raise"`` propagates the exception.
        * ``"skip"`` drops that firm's row.

        Fits that complete without raising are always kept, including
        those that report ``converged=False``.
    horizon
        Override the horizon for every firm (handy when the input panel
        doesn't carry one).
    **fit_kwargs
        Forwarded to :class:`MertonModel` (e.g. ``tol``, ``max_iter``,
        ``physical_measure``, ``sharpe_ratio``, ``n_bootstrap``).

    Returns
    -------
    Any
        One row per firm. The result is a pandas DataFrame for pandas or
        FirmPanel input, a polars DataFrame for polars input, and a pyarrow
        Table for arrow input.

    Raises
    ------
    ValueError
        If ``on_error`` is not ``"warn"``, ``"raise"`` or ``"skip"``.
    TypeError
        If ``df`` is not a supported input type.
    """
    # Validate up front: a typo here previously fell through silently and
    # suppressed both the warning and the skip behaviour.
    if on_error not in ("warn", "raise", "skip"):
        raise ValueError(
            "batch_fit: on_error must be 'warn', 'raise' or 'skip', "
            f"got {on_error!r}"
        )

    panel, out_kind = _panelify(df, mapping)
    if len(panel) == 0:
        return _format_output([], out_kind)

    # Each firm gets a fresh MertonModel built from the same settings.
    model_template = {"method": method, **fit_kwargs}

    def _fit_one(firm: Firm) -> dict[str, Any] | None:
        """Fit one firm; on failure return a NaN row, or ``None`` when skipping."""
        if horizon is not None:
            firm = firm.replace(horizon=horizon)
        try:
            res = MertonModel(**model_template).fit(firm)
        except Exception as err:
            if on_error == "raise":
                raise
            if on_error == "skip":
                # Sentinel filtered out below. It separates "fit raised"
                # from "fit succeeded but did not converge", which the old
                # `converged is False` filter conflated.
                return None
            warnings.warn(
                f"batch_fit: {type(err).__name__}: {err} (firm={firm.ticker})",
                stacklevel=3,
            )
            return {
                "ticker": firm.ticker,
                "dd": np.nan,
                "pd": np.nan,
                "asset_value": np.nan,
                "asset_vol": np.nan,
                "method": method,
                "converged": False,
            }
        return res.to_dict()

    results = parallel_map(
        _fit_one,
        list(panel),
        dispatch=dispatch,
        n_jobs=n_jobs,
        progress=progress,
        description=f"merton.batch_fit({method})",
    )
    rows = [row for row in results if row is not None]
    return _format_output(rows, out_kind)
# Public API of this module; the underscore-prefixed helpers stay internal.
__all__ = ["batch_fit"]