# Source code for srvar.theme

"""Visual grammar and theme management for srvar plots.

This module provides a consistent visual style across all srvar plotting functions,
with semantic colour names, typography settings, and layout constants.

Usage
-----
>>> from srvar.theme import srvar_style, get_color
>>> with srvar_style():
...     fig, ax = plt.subplots()
...     ax.plot(x, y, color=get_color("forecast"))

>>> from srvar.theme import apply_srvar_style
>>> apply_srvar_style()  # Apply globally
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any


[docs] @dataclass(frozen=True, slots=True) class Palette: """Colour palette for srvar plots. Attributes ---------- observed : str Colour for observed data series. shadow : str Colour for shadow/latent series. forecast : str Colour for forecast visualisations. volatility : str Colour for stochastic volatility plots. inclusion : str Colour for SSVS inclusion probability plots. coverage : str Colour for coverage plot lines. pit : str Colour for PIT histogram bars. crps : str Colour for CRPS line plots. band_fill : str Base colour for uncertainty bands (alpha applied separately). reference : str Colour for reference/nominal lines. grid : str Colour for grid lines. spine : str Colour for axis spines. text : str Colour for text elements. """ # Primary semantic colours observed: str = "#2E86AB" # Blue - observed data shadow: str = "#A23B72" # Magenta - shadow/latent forecast: str = "#F18F01" # Orange - forecasts volatility: str = "#C73E1D" # Red - volatility inclusion: str = "#6B4226" # Brown - SSVS inclusion coverage: str = "#2E86AB" # Blue - coverage plots pit: str = "#2E86AB" # Blue - PIT histogram crps: str = "#1B998B" # Teal - CRPS plots # Uncertainty bands band_fill: str = "#2E86AB" # Reference elements reference: str = "#888888" # Grey - reference/nominal lines # Grid and axes grid: str = "#E0E0E0" spine: str = "#333333" text: str = "#1A1A1A" @property def sequential(self) -> list[str]: """Sequential palette for multiple series.""" return [ self.observed, self.shadow, self.forecast, self.volatility, self.inclusion, self.crps, ]
[docs] @dataclass(frozen=True, slots=True) class Typography: """Font settings for srvar plots. Attributes ---------- family : str Font family. title_size : float Font size for plot titles. label_size : float Font size for axis labels. tick_size : float Font size for tick labels. legend_size : float Font size for legend text. annotation_size : float Font size for annotations. title_weight : str Font weight for titles. label_weight : str Font weight for labels. """ family: str = "sans-serif" title_size: float = 11.0 label_size: float = 9.0 tick_size: float = 8.0 legend_size: float = 8.0 annotation_size: float = 7.5 title_weight: str = "semibold" label_weight: str = "normal"
[docs] @dataclass(frozen=True, slots=True) class Layout: """Layout constants for srvar plots. Attributes ---------- figure_single : tuple[float, float] Default figure size for single plots. figure_wide : tuple[float, float] Figure size for wide plots. figure_square : tuple[float, float] Figure size for square plots. figure_panel : tuple[float, float] Figure size for multi-panel layouts. line_data : float Line width for data series. line_median : float Line width for median/summary lines. line_reference : float Line width for reference lines. line_grid : float Line width for grid lines. marker_size : float Default marker size. band_alpha : float Transparency for uncertainty bands. fill_alpha : float Transparency for filled regions. bar_alpha : float Transparency for bar plots. legend_frameon : bool Whether to show legend frame. tight_layout_pad : float Padding for tight_layout. dpi_display : int DPI for display. dpi_save : int DPI for saved figures. """ # Figure sizes (width, height) in inches figure_single: tuple[float, float] = (7.0, 3.5) figure_wide: tuple[float, float] = (10.0, 4.0) figure_square: tuple[float, float] = (5.0, 5.0) figure_panel: tuple[float, float] = (10.0, 8.0) # Line widths line_data: float = 1.5 line_median: float = 2.0 line_reference: float = 1.0 line_grid: float = 0.5 # Markers marker_size: float = 4.0 # Transparency band_alpha: float = 0.2 fill_alpha: float = 0.7 bar_alpha: float = 0.75 # Spacing legend_frameon: bool = False tight_layout_pad: float = 0.5 # Resolution dpi_display: int = 150 dpi_save: int = 300
@dataclass(frozen=True, slots=True)
class Theme:
    """Complete theme specification for srvar plots.

    A theme combines palette, typography, and layout settings into a single
    configuration that can be applied to all plots.

    Attributes
    ----------
    palette : Palette
        Colour palette settings.
    typography : Typography
        Font and text settings.
    layout : Layout
        Figure size and layout settings.
    name : str
        Theme name for identification.

    Examples
    --------
    >>> theme = Theme()  # Default theme
    >>> with srvar_style(theme):
    ...     fig, ax = plt.subplots()

    >>> custom_theme = Theme(
    ...     palette=Palette(observed="#000000"),
    ...     typography=Typography(title_size=12.0),
    ... )
    """

    palette: Palette = field(default_factory=Palette)
    typography: Typography = field(default_factory=Typography)
    layout: Layout = field(default_factory=Layout)
    name: str = "default"

    def to_rcparams(self) -> dict[str, Any]:
        """Convert the theme to a matplotlib rcParams dictionary.

        Returns
        -------
        dict[str, Any]
            Mapping suitable for ``plt.rcParams.update()`` or
            ``plt.rc_context()``. The ``"axes.prop_cycle"`` entry is included
            only when the ``cycler`` package can be imported.
        """
        palette = self.palette
        typography = self.typography
        layout = self.layout

        rc: dict[str, Any] = {
            # Figure
            "figure.figsize": layout.figure_single,
            "figure.dpi": layout.dpi_display,
            "figure.facecolor": "white",
            "figure.edgecolor": "white",
            # Axes
            "axes.facecolor": "white",
            "axes.edgecolor": palette.spine,
            "axes.linewidth": 0.8,
            "axes.grid": True,
            "axes.axisbelow": True,
            "axes.titlesize": typography.title_size,
            "axes.titleweight": typography.title_weight,
            "axes.labelsize": typography.label_size,
            "axes.labelweight": typography.label_weight,
            # Grid
            "grid.color": palette.grid,
            "grid.linewidth": layout.line_grid,
            "grid.alpha": 0.7,
            # Ticks
            "xtick.labelsize": typography.tick_size,
            "ytick.labelsize": typography.tick_size,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "xtick.major.width": 0.8,
            "ytick.major.width": 0.8,
            # Text colour: Palette.text is documented as the colour for text
            # elements, so apply it to general text, axis labels and tick
            # labels. (xtick/ytick.labelcolor require matplotlib >= 3.4.)
            "text.color": palette.text,
            "axes.labelcolor": palette.text,
            "xtick.labelcolor": palette.text,
            "ytick.labelcolor": palette.text,
            # Legend
            "legend.fontsize": typography.legend_size,
            "legend.frameon": layout.legend_frameon,
            "legend.loc": "best",
            # Lines
            "lines.linewidth": layout.line_data,
            "lines.markersize": layout.marker_size,
            # Font
            "font.family": typography.family,
            "font.size": typography.label_size,
            # Savefig
            "savefig.dpi": layout.dpi_save,
            "savefig.bbox": "tight",
            "savefig.pad_inches": 0.1,
            "savefig.facecolor": "white",
            "savefig.edgecolor": "white",
        }

        prop_cycle = self._prop_cycle()
        if prop_cycle is not None:
            rc["axes.prop_cycle"] = prop_cycle
        return rc

    def _prop_cycle(self) -> Any:
        """Build a colour property cycler from ``palette.sequential``.

        Returns None when the ``cycler`` package cannot be imported.
        """
        try:
            from cycler import cycler
        except ImportError:
            # cycler is a matplotlib dependency, but degrade gracefully.
            return None
        return cycler(color=self.palette.sequential)
# =============================================================================
# Preset Themes
# =============================================================================


def _colorblind_palette() -> Palette:
    """Return a colour-blind-safe palette built from IBM Design Language colours.

    NOTE(review): ``volatility`` and ``crps`` both use "#785EF0", and both
    are members of ``Palette.sequential``. Under this theme the 4th and 6th
    cycled series therefore share a colour. Confirm this is intended.
    """
    ibm_blue = "#648FFF"
    ibm_purple = "#785EF0"
    ibm_magenta = "#DC267F"
    ibm_orange = "#FE6100"
    ibm_gold = "#FFB000"
    return Palette(
        observed=ibm_blue,
        shadow=ibm_magenta,
        forecast=ibm_orange,
        volatility=ibm_purple,
        inclusion=ibm_gold,
        coverage=ibm_blue,
        pit=ibm_blue,
        crps=ibm_purple,
        band_fill=ibm_blue,
        reference="#888888",
        grid="#E0E0E0",
        spine="#333333",
        text="#1A1A1A",
    )


def _print_typography() -> Typography:
    """Return typography tuned for print/publication: serif, larger, bold titles."""
    sizes = dict(
        title_size=12.0,
        label_size=10.0,
        tick_size=9.0,
        legend_size=9.0,
        annotation_size=8.0,
    )
    return Typography(
        family="serif",
        title_weight="bold",
        label_weight="normal",
        **sizes,
    )


# Stock theme used whenever a caller passes ``theme=None``.
DEFAULT_THEME = Theme()

# Swaps in the colour-blind-safe palette; typography/layout stay default.
COLORBLIND_THEME = Theme(palette=_colorblind_palette(), name="colorblind")

# Swaps in print typography (serif, larger fonts); palette/layout stay default.
PRINT_THEME = Theme(typography=_print_typography(), name="print")

# =============================================================================
# Context Manager and Application Functions
# =============================================================================
@contextmanager
def srvar_style(theme: Theme | None = None) -> Generator[Theme, None, None]:
    """Context manager to apply srvar visual style.

    The theme's rcParams are applied on entry. The previous rcParams are
    restored on exit, including when the body raises.

    Parameters
    ----------
    theme : Theme | None
        Theme to apply. If None, uses DEFAULT_THEME.

    Yields
    ------
    Theme
        The applied theme.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> with srvar_style():
    ...     fig, ax = plt.subplots()
    ...     ax.plot(x, y)

    >>> with srvar_style(COLORBLIND_THEME):
    ...     fig, ax = plt.subplots()
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError("matplotlib is required; install with 'srvar-toolkit[plot]'") from e

    if theme is None:
        theme = DEFAULT_THEME

    # matplotlib's rc_context snapshots rcParams on entry and restores them on
    # exit, even on exceptions. Unlike a manual rcParams.copy()/update()
    # round-trip, it deliberately does not reset the "backend" entry. That
    # avoids a backend switch on exit that could affect figures created
    # inside the block.
    with plt.rc_context(theme.to_rcparams()):
        yield theme
def apply_srvar_style(theme: Theme | None = None) -> None:
    """Apply the srvar style to matplotlib globally.

    This permanently updates matplotlib's global rcParams, so every later
    figure in the process is affected. Prefer ``srvar_style()`` for scoped
    use in shared environments.

    Parameters
    ----------
    theme : Theme | None
        Theme to apply. If None, uses DEFAULT_THEME.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> from srvar.theme import apply_srvar_style
    >>> apply_srvar_style()  # All subsequent plots use srvar style
    """
    try:
        from matplotlib import pyplot
    except ImportError as err:
        raise ImportError("matplotlib is required; install with 'srvar-toolkit[plot]'") from err

    chosen = DEFAULT_THEME if theme is None else theme
    pyplot.rcParams.update(chosen.to_rcparams())
def reset_style() -> None:
    """Reset matplotlib's rcParams via ``matplotlib.pyplot.rcdefaults()``.

    Intended as the undo step after ``apply_srvar_style()``. When matplotlib
    is not installed this is a silent no-op.

    NOTE(review): ``rcdefaults()`` restores matplotlib's built-in defaults,
    so settings from a user's matplotlibrc are not reinstated. Confirm that
    is the intended behaviour.
    """
    try:
        from matplotlib import pyplot

        pyplot.rcdefaults()
    except ImportError:
        # No matplotlib means nothing was styled, so there is nothing to undo.
        return
# =============================================================================
# Convenience Accessors
# =============================================================================


def get_color(name: str, theme: Theme | None = None) -> str:
    """Get a semantic colour by name.

    Parameters
    ----------
    name : str
        Palette colour name (e.g., 'observed', 'shadow', 'forecast', 'volatility').
    theme : Theme | None
        Theme to use. If None, uses DEFAULT_THEME.

    Returns
    -------
    str
        Hex colour code.

    Raises
    ------
    AttributeError
        If ``name`` is not a colour of the palette. This covers unknown names,
        non-colour attributes such as the ``sequential`` list, and
        private/dunder names. These are rejected rather than being turned
        into a bogus colour string.

    Examples
    --------
    >>> get_color("forecast")
    '#F18F01'
    >>> get_color("shadow", COLORBLIND_THEME)
    '#DC267F'
    """
    if theme is None:
        theme = DEFAULT_THEME
    palette = theme.palette

    # getattr first so a non-str ``name`` still raises TypeError as before.
    value = getattr(palette, name, None)
    # Dunder attributes such as ``__module__`` are strings too, so private
    # names are excluded explicitly. Non-string attributes (e.g. the
    # ``sequential`` list) are not colours.
    if name.startswith("_") or not isinstance(value, str):
        from dataclasses import fields

        valid = ", ".join(f.name for f in fields(palette))
        raise AttributeError(f"unknown colour name {name!r}; expected one of: {valid}")
    return value
def get_figsize(name: str = "single", theme: Theme | None = None) -> tuple[float, float]:
    """Look up a named figure size from a theme's layout.

    Parameters
    ----------
    name : str
        Size name: 'single', 'wide', 'square', or 'panel'. It is resolved as
        the layout attribute ``figure_<name>``.
    theme : Theme | None
        Theme to use. If None, uses DEFAULT_THEME.

    Returns
    -------
    tuple[float, float]
        Figure size (width, height) in inches.

    Examples
    --------
    >>> get_figsize("wide")
    (10.0, 4.0)
    """
    layout = (DEFAULT_THEME if theme is None else theme).layout
    attribute = f"figure_{name}"
    size = getattr(layout, attribute)
    return tuple(size)
def get_alpha(name: str = "band", theme: Theme | None = None) -> float:
    """Look up a named transparency level from a theme's layout.

    Parameters
    ----------
    name : str
        Alpha name: 'band', 'fill', or 'bar'. It is resolved as the layout
        attribute ``<name>_alpha``.
    theme : Theme | None
        Theme to use. If None, uses DEFAULT_THEME.

    Returns
    -------
    float
        Alpha value between 0 and 1.
    """
    layout = (DEFAULT_THEME if theme is None else theme).layout
    raw = getattr(layout, f"{name}_alpha")
    return float(raw)
def get_linewidth(name: str = "data", theme: Theme | None = None) -> float:
    """Look up a named line width from a theme's layout.

    Parameters
    ----------
    name : str
        Line width name: 'data', 'median', 'reference', or 'grid'. It is
        resolved as the layout attribute ``line_<name>``.
    theme : Theme | None
        Theme to use. If None, uses DEFAULT_THEME.

    Returns
    -------
    float
        Line width in points.
    """
    layout = (DEFAULT_THEME if theme is None else theme).layout
    raw = getattr(layout, f"line_{name}")
    return float(raw)