Source code for srvar.plotting

from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import numpy as np

from .evaluation_common import (
    extract_series_draws,
    interval_hit_vector,
    mean_of_finite,
    pit_vector,
    score_vector_from_draws,
)
from .metrics import crps_draws
from .results import FitResult, ForecastResult
from .theme import (
    DEFAULT_THEME,
    Theme,
    get_alpha,
    get_color,
    get_figsize,
    get_linewidth,
    srvar_style,
)


def _require_matplotlib() -> Any:
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError("matplotlib is required; install with 'srvar-toolkit[plot]'") from e
    return plt


def _time_index_to_array(index: Any) -> np.ndarray:
    try:
        if hasattr(index, "to_numpy"):
            return np.asarray(index.to_numpy())
        return np.asarray(index)
    except Exception:
        return np.arange(len(index), dtype=int)


def _is_datetime_like(x: np.ndarray) -> bool:
    return bool(np.issubdtype(np.asarray(x).dtype, np.datetime64))


def _coverage_curve(
    forecasts: list[ForecastResult],
    y_true: np.ndarray,
    *,
    horizon_indices: list[int],
    intervals: list[float],
    var_index: int | None,
    use_latent: bool,
) -> dict[float, np.ndarray]:
    yt = np.asarray(y_true, dtype=float)
    cov_by_int: dict[float, np.ndarray] = {}
    for c in intervals:
        cov = np.full(len(horizon_indices), np.nan, dtype=float)
        for hh, h in enumerate(horizon_indices):
            if var_index is None:
                hits_all = [
                    interval_hit_vector(
                        yt[:, h, j],
                        extract_series_draws(
                            forecasts,
                            horizon_index=h,
                            var_index=j,
                            use_latent=use_latent,
                        ),
                        interval=c,
                    )
                    for j in range(int(yt.shape[2]))
                ]
                cov[hh] = mean_of_finite(np.concatenate(hits_all))
            else:
                cov[hh] = mean_of_finite(
                    interval_hit_vector(
                        yt[:, h, var_index],
                        extract_series_draws(
                            forecasts,
                            horizon_index=h,
                            var_index=var_index,
                            use_latent=use_latent,
                        ),
                        interval=c,
                    )
                )
        cov_by_int[c] = cov
    return cov_by_int


def _pit_values(
    forecasts: list[ForecastResult],
    y_true: np.ndarray,
    *,
    horizon_index: int,
    var_index: int,
    use_latent: bool,
) -> np.ndarray:
    yt = np.asarray(y_true, dtype=float)
    values = pit_vector(
        yt[:, horizon_index, var_index],
        extract_series_draws(
            forecasts,
            horizon_index=horizon_index,
            var_index=var_index,
            use_latent=use_latent,
        ),
    )
    return values[np.isfinite(values)]


def _crps_curve(
    forecasts: list[ForecastResult],
    y_true: np.ndarray,
    *,
    horizon_indices: list[int],
    var_index: int | None,
    use_latent: bool,
) -> np.ndarray:
    yt = np.asarray(y_true, dtype=float)
    out = np.full(len(horizon_indices), np.nan, dtype=float)
    for hh, h in enumerate(horizon_indices):
        if var_index is None:
            vals_all = [
                score_vector_from_draws(
                    yt[:, h, j],
                    extract_series_draws(
                        forecasts,
                        horizon_index=h,
                        var_index=j,
                        use_latent=use_latent,
                    ),
                    crps_draws,
                )
                for j in range(int(yt.shape[2]))
            ]
            out[hh] = mean_of_finite(np.concatenate(vals_all))
        else:
            out[hh] = mean_of_finite(
                score_vector_from_draws(
                    yt[:, h, var_index],
                    extract_series_draws(
                        forecasts,
                        horizon_index=h,
                        var_index=var_index,
                        use_latent=use_latent,
                    ),
                    crps_draws,
                )
            )
    return out


[docs] def plot_shadow_rate( fit: FitResult, *, var: str, bands: tuple[float, float] = (0.1, 0.9), ax: Any | None = None, overlays: dict[str, Iterable[float]] | None = None, show_observed: bool = True, theme: Theme | None = None, ) -> tuple[Any, Any]: """Plot observed vs. latent shadow-rate series. Parameters ---------- fit: Output from :func:`srvar.api.fit`. var: Variable name to plot. bands: Quantile band (low, high) for uncertainty visualization when latent draws are available. ax: Optional Matplotlib axis to draw into. overlays: Optional additional named series to overlay (e.g., benchmark rates). show_observed: Whether to plot the observed (censored) series. theme: Optional theme for styling. If None, uses DEFAULT_THEME. Returns ------- (fig, ax) Matplotlib figure and axis. """ if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() idx = fit.dataset.variables.index(var) x = _time_index_to_array(fit.dataset.time_index) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("wide", theme)) else: fig = ax.figure if show_observed: ax.plot( x, fit.dataset.values[:, idx], color=get_color("observed", theme), lw=get_linewidth("data", theme), label="Observed", ) if fit.latent_draws is not None: y = fit.latent_draws[:, :, idx] qlo, qhi = bands lo = np.quantile(y, q=qlo, axis=0) med = np.quantile(y, q=0.5, axis=0) hi = np.quantile(y, q=qhi, axis=0) ax.fill_between( x, lo, hi, color=get_color("shadow", theme), alpha=get_alpha("band", theme), lw=0 ) ax.plot( x, med, color=get_color("shadow", theme), lw=get_linewidth("median", theme), label="Shadow (median)", ) elif fit.latent_dataset is not None: ax.plot( x, fit.latent_dataset.values[:, idx], color=get_color("shadow", theme), lw=get_linewidth("median", theme), label="Shadow", ) if overlays is not None: for name, series in overlays.items(): ax.plot( x, np.asarray(list(series), dtype=float), lw=get_linewidth("data", theme), ls="--", label=name, ) ax.set_title(f"{var}: observed vs shadow") ax.legend(loc="best") if _is_datetime_like(np.asarray(x)): fig.autofmt_xdate() fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_forecast_coverage( forecasts: list[ForecastResult], y_true: np.ndarray, *, intervals: list[float] | None = None, horizons: list[int] | None = None, var: str | None = None, ax: Any | None = None, use_latent: bool = False, theme: Theme | None = None, ) -> tuple[Any, Any]: if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() if len(forecasts) < 1: raise ValueError("forecasts must be non-empty") yt = np.asarray(y_true, dtype=float) if yt.ndim != 3: raise ValueError("y_true must have shape (K, H, N)") hmax = int(np.asarray(forecasts[0].draws).shape[1]) n = int(np.asarray(forecasts[0].draws).shape[2]) k = int(yt.shape[0]) if yt.shape[1] != hmax or yt.shape[2] != n: raise ValueError("y_true shape must match forecasts: (K, H, N)") if len(forecasts) != k: raise ValueError("len(forecasts) must equal y_true.shape[0]") if horizons is None: h_list = list(range(1, hmax + 1)) else: if (not isinstance(horizons, list)) or (len(horizons) < 1): raise ValueError("horizons must be a non-empty list[int]") h_list = [int(h) for h in horizons] if any(h < 1 or h > hmax for h in h_list): raise ValueError("horizons contains values out of range") if len(set(h_list)) != len(h_list): raise ValueError("horizons must not contain duplicates") h_list = sorted(h_list) h_idx = [h - 1 for h in h_list] if var is None: var_idx = None title = "Forecast coverage (avg across variables)" else: if var not in forecasts[0].variables: raise ValueError("var not in forecasts[0].variables") var_idx = int(forecasts[0].variables.index(var)) title = f"Forecast coverage: {var}" if intervals is None: intervals = [0.5, 0.8, 0.9] intervals_f = [float(c) for c in intervals] for c in intervals_f: if not np.isfinite(c) or not (0.0 < c < 1.0): raise ValueError("intervals must be in (0, 1)") cov_by_int = _coverage_curve( forecasts, yt, horizon_indices=h_idx, intervals=intervals_f, var_index=var_idx, use_latent=use_latent, ) x = np.asarray(h_list, dtype=int) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("wide", theme)) else: fig = ax.figure for _j, c in enumerate(sorted(cov_by_int.keys())): ax.plot( x, cov_by_int[c], lw=get_linewidth("median", theme), label=f"{int(round(100 * c))}%" ) ax.set_ylim(0.0, 1.0) ax.set_xlabel("Horizon") ax.set_ylabel("Empirical coverage") ax.set_title(title) ax.legend(loc="best") fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_pit_histogram( forecasts: list[ForecastResult], y_true: np.ndarray, *, var: str, horizon: int, bins: int = 10, ax: Any | None = None, use_latent: bool = False, theme: Theme | None = None, ) -> tuple[Any, Any]: if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() if len(forecasts) < 1: raise ValueError("forecasts must be non-empty") yt = np.asarray(y_true, dtype=float) if yt.ndim != 3: raise ValueError("y_true must have shape (K, H, N)") if var not in forecasts[0].variables: raise ValueError("var not in forecasts[0].variables") var_idx = int(forecasts[0].variables.index(var)) hmax = int(np.asarray(forecasts[0].draws).shape[1]) if horizon < 1 or horizon > hmax: raise ValueError("horizon out of range") h_idx = int(horizon - 1) k = int(yt.shape[0]) if len(forecasts) != k: raise ValueError("len(forecasts) must equal y_true.shape[0]") if yt.shape[1] != hmax: raise ValueError("y_true horizon dimension must match forecasts") u = _pit_values( forecasts, yt, horizon_index=h_idx, var_index=var_idx, use_latent=use_latent, ) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("single", theme)) else: fig = ax.figure ax.hist( u, bins=int(bins), range=(0.0, 1.0), density=True, color=get_color("pit", theme), alpha=get_alpha("bar", theme), ) ax.axhline( 1.0, color=get_color("reference", theme), lw=get_linewidth("reference", theme), ls="--", ) ax.set_xlim(0.0, 1.0) ax.set_xlabel("PIT") ax.set_ylabel("Density") ax.set_title(f"PIT histogram: {var}, h={horizon}") fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_crps_by_horizon( forecasts: list[ForecastResult], y_true: np.ndarray, *, horizons: list[int] | None = None, var: str | None = None, ax: Any | None = None, use_latent: bool = False, theme: Theme | None = None, ) -> tuple[Any, Any]: if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() if len(forecasts) < 1: raise ValueError("forecasts must be non-empty") yt = np.asarray(y_true, dtype=float) if yt.ndim != 3: raise ValueError("y_true must have shape (K, H, N)") hmax = int(np.asarray(forecasts[0].draws).shape[1]) n = int(np.asarray(forecasts[0].draws).shape[2]) k = int(yt.shape[0]) if len(forecasts) != k or yt.shape[1] != hmax or yt.shape[2] != n: raise ValueError("y_true shape must match forecasts") if var is None: var_idx = None title = "CRPS by horizon (avg across variables)" else: if var not in forecasts[0].variables: raise ValueError("var not in forecasts[0].variables") var_idx = int(forecasts[0].variables.index(var)) title = f"CRPS by horizon: {var}" if horizons is None: h_list = list(range(1, hmax + 1)) else: if (not isinstance(horizons, list)) or (len(horizons) < 1): raise ValueError("horizons must be a non-empty list[int]") h_list = [int(h) for h in horizons] if any(h < 1 or h > hmax for h in h_list): raise ValueError("horizons contains values out of range") if len(set(h_list)) != len(h_list): raise ValueError("horizons must not contain duplicates") h_list = sorted(h_list) h_idx = [h - 1 for h in h_list] crps_mean = _crps_curve( forecasts, yt, horizon_indices=h_idx, var_index=var_idx, use_latent=use_latent, ) x = np.asarray(h_list, dtype=int) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("wide", theme)) else: fig = ax.figure ax.plot(x, crps_mean, lw=get_linewidth("median", theme), color=get_color("crps", theme)) ax.set_xlabel("Horizon") ax.set_ylabel("CRPS") ax.set_title(title) fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_forecast_fanchart( fc: ForecastResult, *, var: str, bands: tuple[float, float] = (0.1, 0.9), ax: Any | None = None, use_latent: bool = False, theme: Theme | None = None, ) -> tuple[Any, Any]: """Plot a forecast fan chart from predictive simulations. Parameters ---------- fc: Output from :func:`srvar.api.forecast`. var: Variable name to plot. bands: Quantile band (low, high) for the fan. ax: Optional Matplotlib axis. use_latent: If True and latent draws exist (ELB model), use latent draws instead of floored observed draws. theme: Optional theme for styling. If None, uses DEFAULT_THEME. Returns ------- (fig, ax) Matplotlib figure and axis. """ if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() idx = fc.variables.index(var) sims = fc.latent_draws if (use_latent and fc.latent_draws is not None) else fc.draws x = np.arange(1, sims.shape[1] + 1, dtype=int) y = sims[:, :, idx] qlo, qhi = bands lo = np.quantile(y, q=qlo, axis=0) med = np.quantile(y, q=0.5, axis=0) hi = np.quantile(y, q=qhi, axis=0) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("single", theme)) else: fig = ax.figure ax.fill_between( x, lo, hi, color=get_color("forecast", theme), alpha=get_alpha("band", theme), lw=0 ) ax.plot( x, med, color=get_color("forecast", theme), lw=get_linewidth("median", theme), label="Median", ) ax.set_title(f"Forecast fan chart: {var}") ax.set_xlabel("Horizon") ax.legend(loc="best") fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_volatility( fit: FitResult, *, var: str, bands: tuple[float, float] = (0.1, 0.9), ax: Any | None = None, theme: Theme | None = None, ) -> tuple[Any, Any]: """Plot stochastic volatility (posterior std dev) for a given series. Parameters ---------- fit: Output from :func:`srvar.api.fit` with volatility enabled. var: Variable name. bands: Quantile band (low, high) used to summarize posterior uncertainty. ax: Optional Matplotlib axis. theme: Optional theme for styling. If None, uses DEFAULT_THEME. Returns ------- (fig, ax) Matplotlib figure and axis. """ if fit.h_draws is None: raise ValueError("fit.h_draws is required (fit must be run with volatility enabled)") if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() idx = fit.dataset.variables.index(var) p = fit.model.p x = _time_index_to_array(fit.dataset.time_index)[p:] h = fit.h_draws[:, :, idx] sd = np.exp(0.5 * h) qlo, qhi = bands lo = np.quantile(sd, q=qlo, axis=0) med = np.quantile(sd, q=0.5, axis=0) hi = np.quantile(sd, q=qhi, axis=0) if ax is None: fig, ax = plt.subplots(figsize=get_figsize("wide", theme)) else: fig = ax.figure ax.fill_between( x, lo, hi, color=get_color("volatility", theme), alpha=get_alpha("band", theme), lw=0 ) ax.plot( x, med, color=get_color("volatility", theme), lw=get_linewidth("median", theme), label="Median", ) ax.set_title(f"Stochastic volatility (std dev): {var}") ax.legend(loc="best") if _is_datetime_like(np.asarray(x)): fig.autofmt_xdate() fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_ssvs_inclusion( fit: FitResult, *, ax: Any | None = None, theme: Theme | None = None, ) -> tuple[Any, Any]: """Plot posterior inclusion probabilities for SSVS. Parameters ---------- fit: Output from :func:`srvar.api.fit` with ``prior.family='ssvs'``. ax: Optional Matplotlib axis. theme: Optional theme for styling. If None, uses DEFAULT_THEME. Returns ------- (fig, ax) Matplotlib figure and axis. """ if fit.gamma_draws is None: raise ValueError("fit.gamma_draws is required (fit must be run with prior.family='ssvs')") if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() probs = fit.gamma_draws.mean(axis=0) k = probs.shape[0] if fit.model.include_intercept: names = ["intercept"] else: names = [] n = fit.dataset.N for lag in range(1, fit.model.p + 1): for j in range(n): names.append(f"lag{lag}:{fit.dataset.variables[j]}") if len(names) != k: names = [f"x{i}" for i in range(k)] if ax is None: fig, ax = plt.subplots(figsize=(10, max(3, 0.25 * k))) else: fig = ax.figure y = np.arange(k, dtype=int) ax.barh(y, probs, color=get_color("inclusion", theme), alpha=get_alpha("bar", theme)) ax.set_yticks(y) ax.set_yticklabels(names) ax.set_xlim(0.0, 1.0) ax.invert_yaxis() ax.set_title("SSVS inclusion probabilities") fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax
[docs] def plot_trace( draws: np.ndarray, *, ax: Any | None = None, label: str | None = None, theme: Theme | None = None, ) -> tuple[Any, Any]: """Plot a simple MCMC trace for a 1D parameter draw sequence. Parameters ---------- draws: 1D array of MCMC draws. ax: Optional Matplotlib axis. label: Optional title for the plot. theme: Optional theme for styling. If None, uses DEFAULT_THEME. Returns ------- (fig, ax) Matplotlib figure and axis. """ if theme is None: theme = DEFAULT_THEME with srvar_style(theme): plt = _require_matplotlib() x = np.asarray(draws, dtype=float) if x.ndim != 1: raise ValueError("draws must be 1D") if ax is None: fig, ax = plt.subplots(figsize=get_figsize("single", theme)) else: fig = ax.figure ax.plot(np.arange(x.size, dtype=int), x, lw=get_linewidth("data", theme)) if label is not None: ax.set_title(label) fig.tight_layout(pad=theme.layout.tight_layout_pad) return fig, ax