Scaling out: Spark and Dask
merton.batch_fit is the right call up to ~100k firms × ~10 years of daily
data on a single workstation. Past that, you’ll want a cluster. This page
shows two patterns that bolt cleanly onto the package without adding either
Spark or Dask as a hard dependency.
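One way to keep both engines optional is to probe for them at call time instead of importing them at module import. The sketch below is a minimal illustration and not part of merton; `available_engines` and `pick_engine` are hypothetical helper names.

```python
import importlib.util
from typing import List, Optional


def available_engines() -> List[str]:
    """Return the scale-out engines importable here, with 'local' always last."""
    # Maps a friendly engine name to the top-level module that provides it.
    candidates = {"dask": "dask", "spark": "pyspark"}
    found = [
        name
        for name, module in candidates.items()
        if importlib.util.find_spec(module) is not None
    ]
    return found + ["local"]


def pick_engine(preferred: Optional[str] = None) -> str:
    """Pick the preferred engine if installed, else the first available one."""
    engines = available_engines()
    if preferred is None:
        return engines[0]
    if preferred not in engines:
        raise RuntimeError(
            f"engine {preferred!r} is not installed; available: {engines}"
        )
    return preferred
```

Because the probe only checks whether the module can be found, neither Dask nor PySpark is imported until a worker function actually needs it.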
Dask
Dask plays well with our pandas-native interface. Every Firm row is fit
independently, so a map_partitions call over a Dask DataFrame is enough:
    import pandas as pd
    import dask.dataframe as dd
    from merton import Firm, MertonModel


    def fit_row(row: pd.Series) -> pd.Series:
        # Build one Firm per row; rf, horizon and ticker fall back to defaults
        # when the column is absent.
        firm = Firm(
            equity=row["equity"],
            debt_short=row["debt_short"],
            debt_long=row["debt_long"],
            equity_vol=row["equity_vol"],
            rf=row.get("rf", 0.04),
            horizon=row.get("horizon", 1.0),
            ticker=row.get("ticker"),
        )
        res = MertonModel(method="vassalou_xing").fit(firm)
        return pd.Series({"dd": res.dd, "pd": res.pd, "asset_vol": res.asset_vol})


    panel = pd.read_parquet("s3://my-bucket/balance_sheets_2024.parquet")
    ddf = dd.from_pandas(panel, npartitions=64)

    # Fit each partition independently and append the three result columns.
    # meta describes the output schema: the input columns plus dd, pd, asset_vol.
    out = ddf.map_partitions(
        lambda part: part.join(part.apply(fit_row, axis=1)),
        meta={**panel.dtypes.to_dict(), "dd": float, "pd": float, "asset_vol": float},
    )
    out.to_parquet("s3://my-bucket/dd_pd_2024.parquet", compute=True)
Tips:
- Use npartitions ≈ cores * 4 for CPU-bound calibration loops; Numba releases the GIL inside the hot path, so threads scale well on Dask worker processes.
- Stash a pre-warmed Numba cache on each worker by calling merton.warm_cache() once in the worker setup_hook.
- Pin MERTON_BACKEND=numba in the worker env so each task uses the JIT kernels without paying for a runtime size-heuristic dispatch.
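The partition-count and backend tips can be captured in two small helpers. This is a minimal sketch: `suggest_npartitions` and `worker_env` are hypothetical names, and how the environment mapping reaches the workers depends on how your cluster is deployed.

```python
import os
from typing import Dict, Optional


def suggest_npartitions(cores: Optional[int] = None, per_core: int = 4) -> int:
    """Apply the npartitions ≈ cores * 4 rule of thumb."""
    if cores is None:
        # os.cpu_count() can return None on some platforms; fall back to 1.
        cores = os.cpu_count() or 1
    return max(1, cores * per_core)


def worker_env(backend: str = "numba") -> Dict[str, str]:
    """Environment overrides to pass to each worker process."""
    return {"MERTON_BACKEND": backend}
```

For a 16-core cluster, `suggest_npartitions(16)` returns 64.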
Apache Spark (PySpark)
Spark’s mapInPandas lets you push batches of pandas DataFrames through merton
in the executors’ Python workers, the same way you’d run a pandas UDF:
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, DoubleType, StringType
    from merton import Firm, MertonModel

    spark = SparkSession.builder.getOrCreate()

    # Output schema: one row per input firm.
    schema = StructType([
        StructField("ticker", StringType()),
        StructField("dd", DoubleType()),
        StructField("pd", DoubleType()),
        StructField("asset_vol", DoubleType()),
    ])


    def fit_chunk(iter_pdf):
        # mapInPandas passes an iterator of pandas DataFrames (one per Arrow
        # batch) and expects an iterator of DataFrames matching `schema`.
        model = MertonModel(method="vassalou_xing")
        for pdf in iter_pdf:
            rows = []
            for r in pdf.itertuples(index=False):
                firm = Firm(
                    equity=r.equity, debt_short=r.debt_short, debt_long=r.debt_long,
                    equity_vol=r.equity_vol, rf=r.rf, horizon=r.horizon, ticker=r.ticker,
                )
                res = model.fit(firm)
                rows.append((r.ticker, res.dd, res.pd, res.asset_vol))
            yield pd.DataFrame(rows, columns=["ticker", "dd", "pd", "asset_vol"])


    spark_df = spark.read.parquet("s3://my-bucket/panel.parquet")
    out = spark_df.mapInPandas(fit_chunk, schema=schema)
    out.write.mode("overwrite").parquet("s3://my-bucket/dd_pd.parquet")
If your panel is partitioned by firm_id × date, set
spark.sql.execution.arrow.maxRecordsPerBatch to align with the natural
panel chunk size (1k–10k rows works well: large enough to amortise the
per-batch serialisation overhead, small enough to keep memory bounded).
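A small helper can clamp the natural chunk size into that 1k–10k window and produce the setting. This is a sketch: `arrow_batch_conf` and `apply_conf` are hypothetical names, and the only Spark call used is `spark.conf.set`.

```python
from typing import Dict


def arrow_batch_conf(rows_per_chunk: int, lo: int = 1_000, hi: int = 10_000) -> Dict[str, str]:
    """Clamp the natural panel chunk size into [lo, hi] and return the Spark setting."""
    batch = min(max(rows_per_chunk, lo), hi)
    return {"spark.sql.execution.arrow.maxRecordsPerBatch": str(batch)}


def apply_conf(spark, conf: Dict[str, str]) -> None:
    """Apply the settings to an existing SparkSession at runtime."""
    for key, value in conf.items():
        spark.conf.set(key, value)
```

For example, if one firm contributes ten years of daily rows and you assume 252 trading days per year, `arrow_batch_conf(2520)` keeps each firm's history in a single batch.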
Putting it together with climate stress
You can stack a climate overlay inside the worker function — the
ClimateOverlay instance is picklable so Spark/Dask broadcast it cleanly:
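    from merton import Firm, MertonModel
    from merton.extensions import ClimateOverlay
    from merton.scenarios.predefined.ngfs import delayed_transition

    # Built once on the driver and captured by the worker function's closure.
    scenario = delayed_transition()


    def stress_row(row):
        firm = Firm(...)  # build from the row's fields, as in fit_row
        overlay = ClimateOverlay(MertonModel(), scenario=scenario, sector=row["sector"])
        return overlay.fit(firm).pd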
For very large climate-stress runs (millions of firm-scenario combinations), pre-compute the per-sector writedown table once on the driver and broadcast it; the executor only has to do the structural calibration.
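The precompute-then-look-up pattern can be sketched in plain Python. Everything here is illustrative: `build_writedown_table`, `attach_writedowns` and `toy_writedown` are hypothetical names, the writedown values are made up, and how ClimateOverlay consumes a writedown internally is not shown on this page.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def build_writedown_table(
    sectors: Iterable[str],
    writedown_for: Callable[[str], float],
) -> Dict[str, float]:
    """Driver side: evaluate the scenario once per distinct sector."""
    return {sector: writedown_for(sector) for sector in sorted(set(sectors))}


def attach_writedowns(
    rows: Iterable[Tuple[str, str]],
    table: Dict[str, float],
) -> List[Tuple[str, str, float]]:
    """Worker side: a plain dictionary lookup per (ticker, sector) row."""
    return [(ticker, sector, table[sector]) for ticker, sector in rows]


# Toy stand-in for the expensive scenario evaluation; the values are made up.
calls: List[str] = []


def toy_writedown(sector: str) -> float:
    calls.append(sector)  # record each evaluation so the count can be checked
    return {"utilities": 0.20, "software": 0.00}[sector]


rows = [("AAA", "utilities"), ("BBB", "software"), ("CCC", "utilities")]
table = build_writedown_table((sector for _, sector in rows), toy_writedown)
stressed = attach_writedowns(rows, table)
# On Spark the table would be shipped with spark.sparkContext.broadcast(table)
# and read on the executor through the broadcast's .value attribute.
```

The scenario function runs once per distinct sector rather than once per firm, which is where the saving comes from when the firm-scenario grid runs into the millions.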