# Source code for merton.cli.commands.doctor
"""``merton doctor`` — print a runtime diagnostic.
Reports:
- Python version and free-threaded (no-GIL) status.
- Installed array backends and which one ``auto`` would resolve to.
- GPU / CUDA / Apple-Silicon / JAX device availability.
- Versions of the major dependencies.
- Hints for missing optional extras.
"""
from __future__ import annotations
import importlib
import importlib.metadata
import platform
import sys
from importlib.util import find_spec
import typer
from rich.console import Console
from rich.table import Table
from ... import __version__
from ..._backend import _registry
def _is_free_threaded() -> bool:
    """Report whether this interpreter is currently running without the GIL.

    The answer comes from ``sys._is_gil_enabled()``, which only exists on
    Python 3.13+. When the attribute is absent the interpreter necessarily
    has a GIL, so the result is ``False``; a probe that raises is likewise
    treated as "GIL present".
    """
    gil_probe = getattr(sys, "_is_gil_enabled", None)
    if gil_probe is not None:
        try:
            gil_disabled = not gil_probe()
        except Exception:  # pragma: no cover
            return False
        return gil_disabled
    # No probe available: pre-3.13 interpreters always have the GIL.
    return False
def _safe_version(pkg: str) -> str:
    """Look up the installed version of distribution *pkg*.

    Returns the version string reported by ``importlib.metadata``, or the
    placeholder ``"—"`` when no such distribution is installed.
    """
    try:
        resolved = importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        resolved = "—"
    return resolved
def _cuda_devices() -> list[str]:
    """Return the names of the CUDA devices visible through CuPy.

    This is a best-effort probe: if CuPy is not installed, or anything goes
    wrong while importing it or querying the CUDA runtime, the result is an
    empty list rather than an exception.
    """
    if find_spec("cupy") is None:
        return []
    try:
        runtime = importlib.import_module("cupy").cuda.runtime
        device_count = runtime.getDeviceCount()
        names = [
            # The runtime reports the device name as bytes.
            runtime.getDeviceProperties(index)["name"].decode()
            for index in range(device_count)
        ]
    except Exception:  # pragma: no cover
        names = []
    return names
def _jax_devices() -> list[str]:
    """Return ``str()`` of each device reported by ``jax.devices()``.

    Yields an empty list when JAX is not installed, or when importing it or
    enumerating its devices raises.
    """
    if find_spec("jax") is None:
        return []
    try:
        reported = importlib.import_module("jax").devices()
        labels = list(map(str, reported))
    except Exception:  # pragma: no cover
        labels = []
    return labels
def _mlx_available() -> bool:
    """Return True when MLX is installed and the host is Apple Silicon.

    "Apple Silicon" is detected the same way the install hint in ``run`` does
    it: macOS (``platform.system() == "Darwin"``) on an ``arm64`` machine.
    Checking the OS alone would report MLX as usable on Intel Macs, which
    contradicts the "only useful on Apple Silicon" note this helper gates.
    """
    # Cheap platform checks first; only consult the import system if they pass.
    if platform.system() != "Darwin" or platform.machine() != "arm64":
        return False
    return find_spec("mlx") is not None
# [docs] (Sphinx viewcode anchor carried over from the rendered page; not part of the module)
def _print_runtime_info(free_threaded: bool) -> None:
    """Print the interpreter/platform summary as a borderless two-column table."""
    info = Table(show_header=False, box=None)
    info.add_column("k", style="bold cyan")
    info.add_column("v")
    info.add_row("Python", f"{platform.python_version()} ({platform.python_implementation()})")
    info.add_row("Free-threaded (no-GIL)", "yes" if free_threaded else "no")
    info.add_row("Platform", f"{platform.system()} {platform.release()} ({platform.machine()})")
    console.print(info)


def _backend_notes(name: str, *, is_installed: bool, is_default: bool) -> str:
    """Build the "Notes" cell for one backend row (empty string when there is nothing to say)."""
    parts: list[str] = []
    if is_default:
        parts.append("[bold green]default[/]")
    if name == "cupy":
        gpus = _cuda_devices()
        if gpus:
            parts.append(f"GPUs: {', '.join(gpus)}")
    elif name == "jax":
        jax_devs = _jax_devices()
        if jax_devs:
            parts.append(f"devices: {', '.join(jax_devs)}")
    elif name == "mlx" and is_installed and not _mlx_available():
        parts.append("only useful on Apple Silicon")
    # Joining an empty list already yields "", so no fallback is needed.
    return " · ".join(parts)


def _print_backends(avail: set[str]) -> None:
    """Print the "Array backends" table: one row per known backend, marking installed/default."""
    table = Table(title="Array backends", title_justify="left")
    table.add_column("Backend")
    table.add_column("Installed")
    table.add_column("Notes")
    default = _registry.default_backend()
    for name in ("numpy", "numba", "cupy", "jax", "mlx"):
        is_installed = name in avail
        table.add_row(
            name,
            "[green]✓[/]" if is_installed else "[dim]—[/]",
            _backend_notes(name, is_installed=is_installed, is_default=name == default),
        )
    console.print(table)


def _print_dependencies() -> None:
    """Print the "Key dependencies" table with each package's installed version (or a dash)."""
    table = Table(title="Key dependencies", title_justify="left")
    table.add_column("Package")
    table.add_column("Version")
    for pkg in (
        "numpy",
        "scipy",
        "pandas",
        "pyarrow",
        "numba",
        "joblib",
        "structlog",
        "pydantic",
        "pydantic-settings",
        "typer",
        "rich",
        "jax",
        "jaxlib",
        "cupy-cuda12x",
        "mlx",
        "xlwings",
    ):
        table.add_row(pkg, _safe_version(pkg))
    console.print(table)


def _collect_hints(avail: set[str], free_threaded: bool) -> list[str]:
    """Return install/upgrade suggestions based on missing backends and the current platform."""
    system = platform.system()
    hints: list[str] = []
    # The "\[" in the raw strings keeps Rich from parsing "[jax]" etc. as a markup tag.
    if "jax" not in avail:
        hints.append(r"Install [bold]merton\[jax][/] for autodiff Greeks and GPU/TPU acceleration.")
    if "cupy" not in avail and system != "Darwin":
        hints.append(r"Install [bold]merton\[gpu][/] (CUDA 12) for GPU-accelerated panels.")
    if system == "Darwin" and platform.machine() == "arm64" and "mlx" not in avail:
        hints.append(
            r"Install [bold]merton\[mlx][/] for Metal-accelerated kernels on Apple Silicon."
        )
    # Only suggest a free-threaded build where one can exist (3.13+).
    if not free_threaded and sys.version_info >= (3, 13):
        hints.append("Try a [bold]free-threaded[/] Python (cp313t / cp314t) for parallel panels.")
    return hints


def run() -> None:
    """Print runtime diagnostics.

    Output, in order: a version banner, the Python/platform summary, the
    array-backend table, the key-dependency versions, and (only when there is
    at least one) a list of suggestions.
    """
    # NOTE(review): `console` is not defined in the visible part of this module;
    # confirm a module-level rich Console instance exists.
    # Probe once so the summary row and the hint agree.
    free_threaded = _is_free_threaded()

    console.rule(f"[bold]merton {__version__}")
    _print_runtime_info(free_threaded)

    avail = set(_registry.available())
    _print_backends(avail)
    _print_dependencies()

    hints = _collect_hints(avail, free_threaded)
    if hints:
        console.print("\n[bold]Suggestions[/]")
        for hint in hints:
            console.print(f" • {hint}")
def _entrypoint(ctx: typer.Context) -> None: # pragma: no cover - typer adapter
    """Typer-facing adapter that runs the diagnostic.

    ``ctx`` is accepted but not used; the function simply delegates to
    :func:`run`. Presumably registered as a Typer command elsewhere — the
    registration is not visible in this module.
    """
    run()