Source code for srvar.ssvs
from __future__ import annotations
import numpy as np
from .linalg import solve_psd
[docs]
def v0_diag_from_gamma(
    *,
    gamma: np.ndarray,
    spike_var: float,
    slab_var: float,
    intercept_slab_var: float | None = None,
) -> np.ndarray:
    """Compute the diagonal of V0 implied by spike-and-slab indicators.

    Each entry is ``slab_var`` where the corresponding indicator is true and
    ``spike_var`` where it is false. If ``intercept_slab_var`` is given, entry
    0 is overwritten with it regardless of ``gamma[0]``.

    Parameters
    ----------
    gamma:
        Boolean inclusion indicators of shape ``(K,)``.
    spike_var:
        Variance for excluded predictors.
    slab_var:
        Variance for included predictors.
    intercept_slab_var:
        Optional override for the intercept variance (index 0).

    Returns
    -------
    np.ndarray
        Newly allocated float vector of length ``K`` representing the
        diagonal of ``V0``.

    Raises
    ------
    ValueError
        If ``gamma`` is not one-dimensional, if any supplied variance is not
        strictly positive (NaN included), or if ``intercept_slab_var`` is
        supplied while ``gamma`` is empty.
    """
    g = np.asarray(gamma, dtype=bool)
    if g.ndim != 1:
        raise ValueError("gamma must be a 1D array")

    # Written as ``not (x > 0)`` rather than ``x <= 0``: every comparison with
    # NaN is False, so the latter would silently let a NaN variance through.
    if not (spike_var > 0 and slab_var > 0):
        raise ValueError("spike_var and slab_var must be > 0")

    # np.where allocates a fresh array, so the in-place write below never
    # touches the caller's ``gamma``.
    v = np.where(g, float(slab_var), float(spike_var)).astype(float, copy=False)

    if intercept_slab_var is not None:
        if not (intercept_slab_var > 0):
            raise ValueError("intercept_slab_var must be > 0")
        if v.size < 1:
            raise ValueError("gamma must be non-empty when intercept_slab_var is provided")
        v[0] = float(intercept_slab_var)

    return v
[docs]
def sample_gamma_rows(
    *,
    beta: np.ndarray,
    sigma: np.ndarray,
    gamma: np.ndarray,
    spike_var: float,
    slab_var: float,
    inclusion_prob: float,
    fixed_mask: np.ndarray | None = None,
    rng: np.random.Generator,
) -> np.ndarray:
    """Sample SSVS inclusion indicators for coefficient rows.

    This updates ``gamma`` given the current coefficient draw ``beta`` and covariance
    ``sigma`` under a spike-and-slab prior. For every row that is not fixed, the
    inclusion log-odds are::

        logit(p) - (N / 2) * log(slab_var / spike_var)
                 - 0.5 * (1 / slab_var - 1 / spike_var) * q

    with ``q = beta[r] @ inv_sigma @ beta[r]``, and the indicator is drawn as a
    Bernoulli with the corresponding probability.

    Parameters
    ----------
    beta:
        VAR coefficient matrix of shape ``(K, N)``.
    sigma:
        Innovation covariance matrix of shape ``(N, N)``.
    gamma:
        Current inclusion indicators of shape ``(K,)``.
    spike_var, slab_var:
        Spike-and-slab prior variances.
    inclusion_prob:
        Prior inclusion probability.
    fixed_mask:
        Optional boolean mask of shape ``(K,)``. Rows flagged here are never
        resampled: they keep whatever value they currently have in ``gamma``.
        Callers that need those predictors to stay included must therefore
        pass them in as ``True`` in ``gamma``.
    rng:
        NumPy RNG. One uniform draw is consumed per non-fixed row, in row order.

    Returns
    -------
    np.ndarray
        Updated boolean indicators of shape ``(K,)``. This is a new array; the
        input ``gamma`` is not modified.

    Raises
    ------
    ValueError
        If any array has the wrong shape, if ``spike_var`` or ``slab_var`` is
        not strictly positive (NaN included), or if ``inclusion_prob`` is not
        strictly between 0 and 1.
    """
    b = np.asarray(beta, dtype=float)
    s = np.asarray(sigma, dtype=float)
    g = np.asarray(gamma, dtype=bool)

    if b.ndim != 2:
        raise ValueError("beta must be 2D (K, N)")
    k, n = b.shape
    if s.shape != (n, n):
        raise ValueError("sigma must have shape (N, N)")
    if g.shape != (k,):
        raise ValueError("gamma must have shape (K,)")
    # NaN-safe form, matching the inclusion_prob check below. With ``x <= 0`` a
    # NaN variance would pass, turn every probability into NaN, and silently
    # set all free indicators to False.
    if not (spike_var > 0 and slab_var > 0):
        raise ValueError("spike_var and slab_var must be > 0")
    if not (0.0 < inclusion_prob < 1.0):
        raise ValueError("inclusion_prob must be in (0, 1)")

    fixed: np.ndarray
    if fixed_mask is None:
        fixed = np.zeros(k, dtype=bool)
    else:
        fixed = np.asarray(fixed_mask, dtype=bool)
        if fixed.shape != (k,):
            raise ValueError("fixed_mask must have shape (K,)")

    # NOTE(review): assumes solve_psd(A, B) returns the solution of A X = B, so
    # that passing the identity yields inv(sigma) - confirm against .linalg.
    inv_sigma = solve_psd(s, np.eye(n, dtype=float))

    # Row-independent pieces of the log-odds, hoisted out of the loop.
    log_prior_odds = float(np.log(inclusion_prob) - np.log(1.0 - inclusion_prob))
    log_det_ratio = float((n / 2.0) * np.log(slab_var / spike_var))
    quad_coef = float(0.5 * (1.0 / slab_var - 1.0 / spike_var))

    out = g.copy()
    for r in range(k):
        if fixed[r]:
            # Leave the indicator at its current value; no RNG draw is used.
            continue

        br = b[r, :]
        q = float(br @ inv_sigma @ br)
        logit = log_prior_odds - log_det_ratio - quad_coef * q

        # Sigmoid evaluated on the branch whose exp() argument is <= 0, so a
        # large |logit| cannot overflow.
        if logit >= 0:
            p1 = 1.0 / (1.0 + np.exp(-logit))
        else:
            e = np.exp(logit)
            p1 = float(e / (1.0 + e))

        out[r] = bool(rng.uniform() < p1)

    return out