Source code for mici.interop

"""Utilities for interfacing with external probabilistic programming libraries."""

from __future__ import annotations

import importlib
import os
from typing import TYPE_CHECKING

import numpy as np

import mici

if TYPE_CHECKING:
    from typing import Literal, Optional, Union

    import arviz
    import pymc
    import stan
    from numpy.typing import ArrayLike

    from mici.types import GradientFunction, ScalarFunction, TraceFunction


def convert_to_inference_data(
    traces: dict[str, list[ArrayLike]],
    stats: dict[str, list[ArrayLike]],
    energy_key: Optional[str] = "energy",
    lp_key: Optional[str] = "lp",
) -> arviz.InferenceData:
    """Convert Mici :code:`sample_chains` output to :py:class:`arviz.InferenceData`.

    Args:
        traces: Traces output from Mici
            :py:meth:`mici.samplers.MarkovChainMonteCarloMethod.sample_chains` call.
            A dictionary of variables traced over sampled chains with the dictionary
            keys the variable names and the values a list of arrays, one array per
            sampled chain, with the first array dimension corresponding to the draw
            index and any remaining dimensions the variable dimensions.
        stats: Statistics output from Mici :code:`sample_chains` call. A dictionary
            of chain statistics traced over sampled chains with the dictionary keys
            the statistics names and the values a list of arrays, one array per
            sampled chain, with the array dimension corresponding to the draw index.
        energy_key: The key of an entry in the :code:`traces` dictionary
            corresponding to the value of the Hamiltonian energy for the accepted
            proposal (up to an additive constant). If present the corresponding
            values will be added to the :code:`sample_stats` group of the returned
            :code:`InferenceData` object.
        lp_key: The key of an entry in the :code:`traces` dictionary corresponding
            to the value of the joint log posterior density for the model (up to an
            additive constant). If present the corresponding values will be added to
            the :code:`sample_stats` group of the returned :code:`InferenceData`
            object.

    Returns:
        ArviZ inference data object with traced chain data stored in the
        :code:`posterior` group and additional chain statistics in the
        :code:`sample_stats` group.
    """
    import arviz

    stats = stats.copy()
    # Rename Mici statistics keys to the names ArviZ expects in sample_stats.
    stats["n_steps"] = stats.pop("n_step")
    stats["acceptance_rate"] = stats.pop("accept_stat")
    if energy_key is not None and energy_key in traces:
        stats["energy"] = traces[energy_key]
    if lp_key is not None and lp_key in traces:
        stats["lp"] = traces[lp_key]
    return arviz.InferenceData(
        posterior=arviz.dict_to_dataset(traces, library=mici),
        sample_stats=arviz.dict_to_dataset(stats, library=mici),
    )

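The following is a minimal usage sketch, not part of the module: a toy
two-dimensional Gaussian target is sampled directly with Mici and the resulting
:code:`traces` and :code:`stats` converted to an :py:class:`arviz.InferenceData`
object. The target, trace function and sampler settings are illustrative
assumptions, not prescribed by :code:`convert_to_inference_data`::

    import numpy as np

    import mici
    from mici.interop import convert_to_inference_data

    # Standard bivariate normal target: negative log density and its gradient.
    system = mici.systems.EuclideanMetricSystem(
        neg_log_dens=lambda pos: 0.5 * np.sum(pos**2),
        grad_neg_log_dens=lambda pos: pos,
    )
    integrator = mici.integrators.LeapfrogIntegrator(system)
    sampler = mici.samplers.DynamicMultinomialHMC(
        system, integrator, np.random.default_rng(1234)
    )
    # Two chains; stats will include the accept_stat and n_step entries that
    # convert_to_inference_data renames for ArviZ.
    _, traces, stats = sampler.sample_chains(
        n_warm_up_iter=100,
        n_main_iter=200,
        init_states=[np.zeros(2), np.ones(2)],
        adapters=[mici.adapters.DualAveragingStepSizeAdapter()],
        trace_funcs=[lambda state: {"pos": state.pos.copy()}],
    )
    inference_data = convert_to_inference_data(traces, stats)
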
def construct_pymc_model_functions(
    model: pymc.Model,
) -> tuple[ScalarFunction, GradientFunction, TraceFunction]:
    """Construct functions for sampling from PyMC model using Mici.

    Args:
        model: PyMC model to construct functions for.

    Returns:
        Tuple :code:`(neg_log_dens, grad_neg_log_dens, trace_func)` with
        :code:`neg_log_dens` a function for evaluating the negative logarithm of the
        unnormalized posterior density associated with the model,
        :code:`grad_neg_log_dens` a function for evaluating the gradient of
        :code:`neg_log_dens` with respect to the position array argument and
        :code:`trace_func` a function which extracts model parameter values from the
        chain state for tracing during sampling.
    """
    import pymc

    initial_point = model.initial_point()
    raveled_initial_point = pymc.blocking.DictToArrayBijection.map(initial_point)
    val_and_grad_log_dens = model.logp_dlogp_function()
    val_and_grad_log_dens.set_extra_values({})

    def grad_neg_log_dens(pos):
        # Return both gradient and value so Mici can reuse the joint evaluation.
        val, grad = val_and_grad_log_dens(pos)
        return -grad, -val

    def neg_log_dens(pos):
        val, _ = val_and_grad_log_dens(pos)
        return -val

    def trace_func(state):
        raveled_vars = pymc.blocking.RaveledVars(
            state.pos,
            raveled_initial_point.point_map_info,
        )
        var_dict = pymc.blocking.DictToArrayBijection.rmap(raveled_vars)
        trace_dict = {}
        for rv in model.unobserved_RVs:
            if rv.name in var_dict:
                trace_dict[rv.name] = var_dict[rv.name]
            else:
                # Variable was sampled on a transformed (unconstrained) space so
                # map the traced values back to the original space.
                transform = model.rvs_to_transforms[rv]
                trace_dict[rv.name] = transform.backward(
                    var_dict[f"{rv.name}_{transform.name}__"],
                    *rv.owner.inputs,
                ).eval()
        trace_dict["lp"] = -neg_log_dens(state.pos)
        return trace_dict

    return neg_log_dens, grad_neg_log_dens, trace_func

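As a brief sketch of how the constructed functions plug in to a Mici system (the
toy model and data values here are illustrative assumptions)::

    import numpy as np
    import pymc

    import mici
    from mici.interop import construct_pymc_model_functions

    with pymc.Model() as model:
        mu = pymc.Normal("mu", 0.0, 1.0)
        pymc.Normal("y", mu, 1.0, observed=np.array([0.5, -0.3, 0.1]))

    neg_log_dens, grad_neg_log_dens, trace_func = construct_pymc_model_functions(
        model
    )
    system = mici.systems.EuclideanMetricSystem(
        neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens
    )
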
def sample_pymc_model(
    draws: int = 1000,
    *,
    tune: int = 1000,
    chains: Optional[int] = None,
    cores: Optional[int] = None,
    random_seed: Optional[int] = None,
    progressbar: bool = True,
    init: Literal[
        "auto", "adapt_diag", "jitter+adapt_diag", "adapt_full", "jitter+adapt_full"
    ] = "auto",
    jitter_max_retries: int = 10,
    return_inferencedata: bool = False,
    model: Optional[pymc.Model] = None,
    target_accept: float = 0.8,
    max_treedepth: int = 10,
) -> Union[arviz.InferenceData, dict[str, ArrayLike]]:
    """Generate approximate samples from posterior defined by a PyMC model.

    Uses dynamic multinomial HMC algorithm in Mici with adaptive warm-up phase.

    This function replicates the interface of the :py:func:`pymc.sample` function to
    allow using as a (partial) drop-in replacement.

    Args:
        draws: The number of samples to draw.
        tune: Number of iterations to tune, defaults to 1000. Samplers adjust the
            step sizes, scalings or similar during tuning. Tuning samples will be
            drawn in addition to the number specified in the :code:`draws` argument,
            and will be discarded.
        chains: The number of chains to sample. Running independent chains is
            important for some convergence statistics and can also reveal multiple
            modes in the posterior. If :code:`None`, then set to either :code:`cores`
            or 2, whichever is larger.
        cores: The number of chains to run in parallel. If :code:`None`, set to the
            number of CPU cores in the system, but at most 4.
        random_seed: Seed for NumPy random number generator used for generating
            random variables while sampling chains. If :code:`None` then generator
            will be seeded with entropy from operating system.
        progressbar: Whether or not to display a progress bar.
        init: Initialization method to use. One of:

            * :code:`"adapt_diag"`: Start with an identity mass matrix and then
              adapt a diagonal based on the variance of the tuning samples. All
              chains use the test value (usually the prior mean) as starting point.
            * :code:`"jitter+adapt_diag"`: Same as :code:`"adapt_diag"`, but add
              uniform jitter in [-1, 1] to the starting point in each chain. Also
              chosen if :code:`init="auto"`.
            * :code:`"adapt_full"`: Adapt a dense mass matrix using the sample
              covariances.
            * :code:`"jitter+adapt_full"`: Same as :code:`"adapt_full"`, but add
              uniform jitter in [-1, 1] to the starting point in each chain.

        jitter_max_retries: Maximum number of repeated attempts (per chain) at
            creating an initial point with uniform jitter that yields a finite
            probability. This applies to the :code:`"jitter+adapt_diag"` and
            :code:`"jitter+adapt_full"` :py:obj:`init` methods.
        return_inferencedata: Whether to return the traces as an
            :py:class:`arviz.InferenceData` (:code:`True`) object or a
            :py:class:`dict` (:code:`False`).
        model: PyMC model defining posterior distribution to sample from. May be
            :code:`None` if function is called from within model context manager.
        target_accept: Target value for the acceptance statistic being controlled
            during adaptive warm-up.
        max_treedepth: Maximum depth to expand trajectory binary tree to in
            integrator transition. The maximum number of integrator steps corresponds
            to :code:`2**max_treedepth`.

    Returns:
        A dictionary or :py:class:`arviz.InferenceData` object containing the
        sampled chain output. Dictionary output (when
        :code:`return_inferencedata=False`) has string keys corresponding to the
        name of each traced variable in the model, with the values being the
        corresponding values of the variables traced across the chains as NumPy
        arrays, with the first dimension the chain index (of size equal to
        :code:`chains`), the second dimension the draw index (of size equal to
        :code:`draws`) and any remaining dimensions corresponding to the dimensions
        of the traced variable. If :code:`return_inferencedata=True` an
        :py:class:`arviz.InferenceData` object is instead returned with traced chain
        data stored in the :code:`posterior` group and additional chain statistics
        in the :code:`sample_stats` group.
    """
    import pymc

    if return_inferencedata and importlib.util.find_spec("arviz") is None:
        msg = "Cannot return InferenceData as ArviZ is not installed"
        raise ValueError(msg)
    model = pymc.modelcontext(model)
    # assume 2 threads per CPU core
    cores = min(4, os.cpu_count() // 2) if cores is None else cores
    chains = max(2, cores) if chains is None else chains
    init = "jitter+adapt_diag" if init == "auto" else init
    if init in ("jitter+adapt_diag", "jitter+adapt_full", "adapt_diag", "adapt_full"):
        use_dense_metric = "adapt_full" in init
        jitter_init = "jitter" in init
    else:
        msg = (
            'init must be "auto", "jitter+adapt_diag", "jitter+adapt_full", '
            '"adapt_diag" or "adapt_full"'
        )
        raise ValueError(msg)
    neg_log_dens, grad_neg_log_dens, trace_func = construct_pymc_model_functions(model)
    system = mici.systems.EuclideanMetricSystem(
        neg_log_dens=neg_log_dens,
        grad_neg_log_dens=grad_neg_log_dens,
    )
    integrator = mici.integrators.LeapfrogIntegrator(system)
    rng = np.random.default_rng(random_seed)
    sampler = mici.samplers.DynamicMultinomialHMC(
        system,
        integrator,
        rng,
        max_tree_depth=max_treedepth,
    )
    step_size_adapter = mici.adapters.DualAveragingStepSizeAdapter(target_accept)
    metric_adapter = (
        mici.adapters.OnlineCovarianceMetricAdapter()
        if use_dense_metric
        else mici.adapters.OnlineVarianceMetricAdapter()
    )
    initial_point = model.initial_point()
    raveled_initial_point = pymc.blocking.DictToArrayBijection.map(initial_point)
    if jitter_init:
        mean = raveled_initial_point.data.copy()
        init_states = []
        for _c in range(chains):
            # Retry jittered initialization until density is finite (up to
            # jitter_max_retries attempts per chain).
            for _t in range(jitter_max_retries):
                pos = mean + rng.uniform(-1, 1, mean.shape)
                if np.isfinite(neg_log_dens(pos)):
                    break
            init_states.append(pos)
    else:
        init_states = [raveled_initial_point.data.copy() for _c in range(chains)]
    _, traces, stats = sampler.sample_chains(
        n_warm_up_iter=tune,
        n_main_iter=draws,
        init_states=init_states,
        adapters=[step_size_adapter, metric_adapter],
        trace_funcs=[trace_func],
        n_process=cores,
        display_progress=progressbar,
        monitor_stats=["accept_stat", "n_step", "diverging"],
    )
    if return_inferencedata:
        return convert_to_inference_data(traces, stats)
    return {k: np.stack(v) for k, v in traces.items()}

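A usage sketch with the same toy model as above (model, data and seed are
illustrative assumptions); when called inside the model context manager the
:code:`model` argument can be omitted::

    import numpy as np
    import pymc

    from mici.interop import sample_pymc_model

    with pymc.Model():
        mu = pymc.Normal("mu", 0.0, 1.0)
        pymc.Normal("y", mu, 1.0, observed=np.array([0.5, -0.3, 0.1]))
        traces = sample_pymc_model(draws=500, tune=500, chains=2, random_seed=8724)

    # Traced values are stacked with chain and draw as the leading dimensions.
    assert traces["mu"].shape == (2, 500)
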
def get_stan_model_unconstrained_param_dim(model: stan.Model) -> int:
    """Get total dimension of unconstrained parameters in Stan model.

    Args:
        model: Stan model to get dimension for.

    Returns:
        Non-negative integer specifying unconstrained parameter dimension.
    """
    # model.dims covers transformed parameters and generated quantities as well as
    # the parameters themselves, so the total constrained size is an upper bound on
    # the unconstrained dimension. Probe log_prob with successively smaller trial
    # vectors, dropping trailing blocks, until evaluation succeeds.
    param_size_list = [np.prod(dim, dtype=np.int64) for dim in model.dims]
    n_dim = sum(param_size_list)
    while True:
        try:
            model.log_prob([0] * n_dim)
            return n_dim
        except RuntimeError:
            param_size_list.pop()
            n_dim = sum(param_size_list)

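For example, for a model with a generated quantity the helper recovers the
dimension of the parameter block alone. This is an illustrative sketch (it
requires a working PyStan installation and compiles the model when
:code:`stan.build` is called)::

    import stan

    from mici.interop import get_stan_model_unconstrained_param_dim

    model_code = """
    parameters { vector[2] mu; }
    model { mu ~ normal(0, 1); }
    generated quantities { real mu_sum = sum(mu); }
    """
    model = stan.build(model_code, data={})
    # model.dims covers both mu and mu_sum (total size 3) but log_prob accepts
    # only the 2 unconstrained parameter values, so the helper returns 2.
    assert get_stan_model_unconstrained_param_dim(model) == 2
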
def construct_stan_model_functions(
    model: stan.Model,
) -> tuple[ScalarFunction, GradientFunction, TraceFunction]:
    """Construct functions for sampling from Stan model using Mici.

    Args:
        model: Stan model to construct functions for.

    Returns:
        Tuple :code:`(neg_log_dens, grad_neg_log_dens, trace_func)` with
        :code:`neg_log_dens` a function for evaluating the negative logarithm of the
        unnormalized posterior density associated with the model,
        :code:`grad_neg_log_dens` a function for evaluating the gradient of
        :code:`neg_log_dens` with respect to the position array argument and
        :code:`trace_func` a function which extracts model parameter values from the
        chain state for tracing during sampling.
    """

    def neg_log_dens(u):
        return -model.log_prob(list(u))

    def grad_neg_log_dens(u):
        return -np.array(model.grad_log_prob(list(u)))

    param_size_list = [np.prod(dim, dtype=np.int64) for dim in model.dims]

    def trace_func(state):
        # Map unconstrained position to constrained parameter values and split the
        # flat array into per-parameter blocks with their original shapes.
        param_array = np.array(model.constrain_pars(list(state.pos)))
        trace_dict = {
            name: val.reshape(shape)
            for name, val, shape in zip(
                model.param_names,
                np.split(param_array, np.cumsum(param_size_list)[:-1]),
                model.dims,
            )
        }
        trace_dict["lp"] = -neg_log_dens(state.pos)
        return trace_dict

    return neg_log_dens, grad_neg_log_dens, trace_func

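A sketch of wiring the constructed functions into a Mici system (the
one-parameter model is an illustrative assumption)::

    import stan

    import mici
    from mici.interop import construct_stan_model_functions

    model = stan.build(
        "parameters { real mu; } model { mu ~ normal(0, 1); }", data={}
    )
    neg_log_dens, grad_neg_log_dens, trace_func = construct_stan_model_functions(
        model
    )
    system = mici.systems.EuclideanMetricSystem(
        neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens
    )
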
def sample_stan_model(
    model_code: str,
    data: dict,
    *,
    num_samples: int = 1000,
    num_warmup: int = 1000,
    num_chains: int = 4,
    save_warmup: bool = False,
    metric: Literal["unit_e", "diag_e", "dense_e"] = "diag_e",
    stepsize: float = 1.0,
    adapt_engaged: bool = True,
    delta: float = 0.8,
    gamma: float = 0.05,
    kappa: float = 0.75,
    t0: int = 10,
    init_buffer: int = 75,
    term_buffer: int = 50,
    window: int = 25,
    max_depth: int = 10,
    seed: Optional[int] = None,
    return_inferencedata: bool = False,
) -> Union[arviz.InferenceData, dict[str, ArrayLike]]:
    """Generate approximate samples from posterior defined by a Stan model.

    Uses dynamic multinomial HMC algorithm in Mici with adaptive warm-up phase.

    This function follows a similar argument naming scheme to the PyStan
    :py:meth:`stan.model.Model.sample` method (which itself follows CmdStan) to
    allow using as a (partial) drop-in replacement.

    Args:
        model_code: Stan program code describing a Stan model.
        data: A Python dictionary or mapping providing the data for the model.
            Variable names are the keys and the values are their associated values.
        num_samples: A non-negative integer specifying the number of non-warm-up
            iterations per chain.
        num_warmup: A non-negative integer specifying the number of warm-up
            iterations per chain.
        num_chains: A positive integer specifying the number of Markov chains.
        save_warmup: Whether to save warm-up chain data (:code:`True`) or not
            (:code:`False`).
        metric: String specifying metric type. One of :code:`"unit_e"`,
            :code:`"diag_e"` or :code:`"dense_e"`, indicating respectively to use a
            fixed identity metric matrix representation, to use a diagonal metric
            matrix representation adapted based on estimates of the marginal
            posterior variances, or to use a dense metric matrix representation
            adapted based on estimates of the posterior covariance matrix.
        stepsize: Initial integrator step size.
        adapt_engaged: Whether adaptation is engaged (:code:`True`) or not
            (:code:`False`).
        delta: Adaptation target acceptance statistic.
        gamma: Adaptation regularization scale.
        kappa: Adaptation relaxation exponent.
        t0: Adaptation iteration offset.
        init_buffer: Width of initial fast adaptation interval.
        term_buffer: Width of final fast adaptation interval.
        window: Initial width of slow adaptation interval.
        max_depth: Maximum depth of binary trajectory tree.
        seed: Seed for NumPy random number generator used for generating random
            variables while sampling chains. If :code:`None` then generator will be
            seeded with entropy from operating system.
        return_inferencedata: Whether to return the traces as an
            :py:class:`arviz.InferenceData` (:code:`True`) object or a
            :py:class:`dict` (:code:`False`).

    Returns:
        A dictionary or :py:class:`arviz.InferenceData` object containing the
        sampled chain output. Dictionary output (when
        :code:`return_inferencedata=False`) has string keys corresponding to the
        name of each traced variable in the model, with the values being the
        corresponding values of the variables traced across the chains as NumPy
        arrays, following the PyStan convention of the variable dimensions coming
        first and the last dimension being the flattened draw index across all
        chains (of size equal to :code:`num_chains * num_samples`). If
        :code:`return_inferencedata=True` an :py:class:`arviz.InferenceData` object
        is instead returned with traced chain data stored in the :code:`posterior`
        group and additional chain statistics in the :code:`sample_stats` group.
    """
    import stan

    if return_inferencedata and importlib.util.find_spec("arviz") is None:
        msg = "Cannot return InferenceData as ArviZ is not installed"
        raise ValueError(msg)
    model = stan.build(model_code, data=data)
    neg_log_dens, grad_neg_log_dens, trace_func = construct_stan_model_functions(model)
    system = mici.systems.EuclideanMetricSystem(
        neg_log_dens=neg_log_dens,
        grad_neg_log_dens=grad_neg_log_dens,
    )
    integrator = mici.integrators.LeapfrogIntegrator(system, step_size=stepsize)
    rng = np.random.default_rng(seed)
    sampler = mici.samplers.DynamicMultinomialHMC(
        system,
        integrator,
        rng,
        max_tree_depth=max_depth,
    )
    if adapt_engaged:
        step_size_adapter = mici.adapters.DualAveragingStepSizeAdapter(
            adapt_stat_target=delta,
            iter_offset=t0,
            iter_decay_coeff=kappa,
            log_step_size_reg_coefficient=gamma,
        )
        adapters = [step_size_adapter]
        if metric == "diag_e":
            adapters.append(mici.adapters.OnlineVarianceMetricAdapter())
        elif metric == "dense_e":
            adapters.append(mici.adapters.OnlineCovarianceMetricAdapter())
        if len(adapters) > 1:
            # Metric adaptation uses a Stan-like windowed warm-up with fast initial
            # and final intervals and growing slow windows in between.
            stager = mici.stagers.WindowedWarmUpStager(
                n_init_fast_stage_iter=init_buffer,
                n_final_fast_stage_iter=term_buffer,
                n_init_slow_window_iter=window,
            )
        else:
            stager = mici.stagers.WarmUpStager()
    else:
        adapters = None
        stager = None
    dim_u = get_stan_model_unconstrained_param_dim(model)
    # Matching Stan's default, initialize each chain uniformly in [-2, 2] on the
    # unconstrained parameter space.
    init_states = rng.uniform(-2, 2, size=(num_chains, dim_u))
    _, traces, stats = sampler.sample_chains(
        n_warm_up_iter=num_warmup,
        n_main_iter=num_samples,
        init_states=init_states,
        adapters=adapters,
        stager=stager,
        trace_funcs=[trace_func],
        monitor_stats=["accept_stat", "n_step", "diverging"],
        trace_warm_up=save_warmup,
    )
    if return_inferencedata:
        return convert_to_inference_data(traces, stats)
    # Concatenate draws across chains and move the draw index to the last axis to
    # match the PyStan output convention.
    return {k: np.concatenate(v).swapaxes(0, -1) for k, v in traces.items()}
