Source code for mici.adapters

"""Methods for adaptively setting algorithmic parameters of transitions."""

from __future__ import annotations

from abc import ABC, abstractmethod
from math import exp, log
from typing import TYPE_CHECKING

import numpy as np

from mici.errors import AdaptationError, IntegratorError
from mici.matrices import DensePositiveDefiniteMatrix, PositiveDiagonalMatrix

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable

    from numpy.random import Generator
    from numpy.typing import ArrayLike

    from mici.integrators import Integrator
    from mici.states import ChainState
    from mici.systems import System
    from mici.transitions import Transition
    from mici.types import (
        AdaptationStatisticFunction,
        AdapterState,
        ReducerFunction,
        TransitionStatistics,
    )


class Adapter(ABC):
    """Abstract base class for schemes which adapt parameters of a transition.

    An adaptation scheme maintains a collection of adaptation variables, together
    referred to as the adapter state, which is updated after every chain transition
    using the sampled chain state and / or statistics of the transition (for example
    an acceptance probability statistic). Once a chain of one or more adaptive
    transitions has completed, the final adapter state may be used to make a final
    update to the transition parameters.
    """

    @abstractmethod
    def initialize(
        self,
        chain_state: ChainState,
        transition: Transition,
    ) -> AdapterState:
        """Create the adapter state before any adaptive transitions are run.

        Args:
            chain_state: Chain state the adaptive transitions will start from. It
                may be read to compute the initial adapter state but should not be
                mutated by the method.
            transition: Markov transition being adapted. The method may update
                attributes of the transition, or of its child objects, in-place.

        Returns:
            The initial adapter state.
        """

    @abstractmethod
    def update(
        self,
        adapt_state: AdapterState,
        chain_state: ChainState,
        trans_stats: TransitionStatistics,
        transition: Transition,
    ) -> None:
        """Update the adapter state after a sample from the adapted transition.

        Args:
            adapt_state: Current adapter state, whose entries are updated in-place.
            chain_state: Chain state obtained by sampling from the transition being
                adapted. It may be read to compute the adapter state updates but
                should not be mutated by the method.
            trans_stats: Dictionary of statistics associated with the transition
                being adapted. It may be read to compute the adapter state updates
                but should not be mutated by the method.
            transition: Markov transition being adapted. The method may update
                attributes of the transition, or of its child objects, in-place.
        """

    @abstractmethod
    def finalize(
        self,
        adapt_states: AdapterState | Iterable[AdapterState],
        chain_states: ChainState | Iterable[ChainState],
        transition: Transition,
        rngs: Generator | Iterable[Generator],
    ) -> None:
        """Set the transition parameters from the final adapter state or states.

        When several adapter states are available, for instance from a set of
        independent adaptive chains, the adaptation information from all of the
        chains may optionally be combined to set the transition parameter(s).

        Args:
            adapt_states: Final adapter state, or a list of per-chain adapter
                states. Arrays / buffers held in the adapter state entries may be
                recycled to reduce memory usage; when that happens the
                corresponding entries are removed from the adapter state
                dictionary / dictionaries.
            chain_states: Final state of the chain (or states of the chains) in the
                current sampling stage. May be updated in-place if the transition
                parameters altered by the adapter require any state components to
                be updated.
            transition: Markov transition being adapted. Attributes of the
                transition, or of its child objects, are updated in-place.
            rngs: Random number generator for the chain, or a list of per-chain
                random number generators. Used to resample any components of the
                states which need updating as a result of the adaptation.
        """

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        """Whether the adapter is 'fast' or 'slow'.

        An adapter needing only local information to adapt the transition
        parameters should be classified as fast. One needing more global
        information, and therefore more chain iterations, should be classified as
        slow, that is :code:`is_fast == False`.
        """
def arithmetic_mean_log_step_size_reducer(log_step_sizes: Collection[float]) -> float:
    """Return the arithmetic mean of step sizes given their logarithms.

    Args:
        log_step_sizes: Logarithms of per-chain estimated step sizes.

    Returns:
        Arithmetic mean of estimated step sizes.
    """
    # Map back from log space first, then average on the original scale.
    step_sizes = [exp(log_step_size) for log_step_size in log_step_sizes]
    return sum(step_sizes) / len(log_step_sizes)
def geometric_mean_log_step_size_reducer(log_step_sizes: Collection[float]) -> float:
    """Return the geometric mean of step sizes given their logarithms.

    Args:
        log_step_sizes: Logarithms of per-chain estimated step sizes.

    Returns:
        Geometric mean of estimated step sizes.
    """
    # The geometric mean is the exponential of the mean of the logarithms.
    mean_log_step_size = sum(log_step_sizes) / len(log_step_sizes)
    return exp(mean_log_step_size)
def min_log_step_size_reducer(log_step_sizes: Collection[float]) -> float:
    """Return the smallest step size given the logarithms of the step sizes.

    Args:
        log_step_sizes: Logarithms of per-chain estimated step sizes.

    Returns:
        Minimum of estimated step sizes.
    """
    # exp is monotonically increasing so the minimum can be taken in log space.
    smallest_log_step_size = min(log_step_sizes)
    return exp(smallest_log_step_size)
def default_adapt_stat_func(stats: TransitionStatistics) -> float:
    """Extract the default statistic used for step-size adaptation.

    Args:
        stats: Dictionary of transition statistics.

    Returns:
        Acceptance statistic, read from the :code:`"accept_stat"` entry.
    """
    accept_stat = stats["accept_stat"]
    return accept_stat
class DualAveragingStepSizeAdapter(Adapter):
    """Adapter tuning the integrator step size by dual averaging.

    Implements the dual averaging step size adaptation algorithm described in
    Hoffman and Gelman (2014), which is a modified version of the stochastic
    optimisation scheme of Nesterov (2009). By default the adaptation controls
    the :code:`accept_stat` statistic of an integration transition to be close to
    a target value, but the statistic adapted on can be altered by changing the
    :code:`adapt_stat_func`.

    References:

      1. Hoffman, M.D. and Gelman, A. (2014). The No-U-turn sampler:
         adaptively setting path lengths in Hamiltonian Monte Carlo.
         Journal of Machine Learning Research, 15(1), pp.1593-1623.
      2. Nesterov, Y. (2009). Primal-dual subgradient methods for convex
         problems. Mathematical programming 120(1), pp.221-259.
    """

    is_fast = True

    def __init__(
        self,
        adapt_stat_target: float = 0.8,
        adapt_stat_func: AdaptationStatisticFunction | None = None,
        log_step_size_reg_target: float | None = None,
        log_step_size_reg_coefficient: float = 0.05,
        iter_decay_coeff: float = 0.75,
        iter_offset: int = 10,
        max_init_step_size_iters: int = 100,
        log_step_size_reducer: ReducerFunction | None = None,
    ) -> None:
        """
        Args:
            adapt_stat_target: Target value for the transition statistic being
                controlled during adaptation.
            adapt_stat_func: Function which, given a dictionary of transition
                statistics, outputs the value of the statistic to control during
                adaptation. By default a function which simply selects the
                :code:`accept_stat` value in the statistics dictionary is used.
            log_step_size_reg_target: Value to regularize the controlled output
                (logarithm of the integrator step size) towards. If :code:`None`
                set to :code:`log(10 * init_step_size)` where
                :code:`init_step_size` is the initial 'reasonable' step size found
                by a coarse search as recommended in Hoffman and Gelman (2014).
                This gives the dual averaging algorithm a tendency towards testing
                step sizes larger than the initial value, with integrating with a
                larger step size typically having a lower computational cost.
            log_step_size_reg_coefficient: Coefficient controlling the amount of
                regularisation of the controlled output (logarithm of the
                integrator step size) towards :code:`log_step_size_reg_target`.
                Defaults to 0.05 as recommended in Hoffman and Gelman (2014).
            iter_decay_coeff: Coefficient controlling the exponent of the decay in
                the schedule weighting stochastic updates to the smoothed log step
                size estimate. Should be in the interval (0.5, 1] to ensure
                asymptotic convergence of the adaptation. A value of 1 gives equal
                weight to the whole history of updates while smaller values
                increasingly highly weight recent updates, giving a tendency to
                'forget' early updates. Defaults to 0.75 as recommended in Hoffman
                and Gelman (2014).
            iter_offset: Offset used for the iteration based weighting of the
                adaptation statistic error estimate. Should be set to a
                non-negative value. A value > 0 has the effect of stabilising
                early iterations. Defaults to the value of 10 as recommended in
                Hoffman and Gelman (2014).
            max_init_step_size_iters: Maximum number of iterations to use in the
                initial search for a reasonable step size, with an
                :py:exc:`mici.errors.AdaptationError` exception raised if a
                suitable step size is not found within this many iterations.
            log_step_size_reducer: Reduction to apply to the final per-chain step
                size estimates to produce the overall integrator step size for the
                main chain stages. The specified function should accept a sequence
                of logarithms of estimated step sizes and output a non-negative
                step size to use. If :code:`None`, the default, a function which
                computes the arithmetic mean of the per-chain step sizes is used.
        """
        # Resolve the optional callables to their defaults before storing them.
        if adapt_stat_func is None:
            adapt_stat_func = default_adapt_stat_func
        if log_step_size_reducer is None:
            log_step_size_reducer = arithmetic_mean_log_step_size_reducer
        self.adapt_stat_target = adapt_stat_target
        self.adapt_stat_func = adapt_stat_func
        self.log_step_size_reg_target = log_step_size_reg_target
        self.log_step_size_reg_coefficient = log_step_size_reg_coefficient
        self.iter_decay_coeff = iter_decay_coeff
        self.iter_offset = iter_offset
        self.max_init_step_size_iters = max_init_step_size_iters
        self.log_step_size_reducer = log_step_size_reducer

    def initialize(
        self,
        chain_state: ChainState,
        transition: Transition,
    ) -> AdapterState:
        """Run the initial step size search and build the initial adapter state.

        The integrator step size is set as a side effect of the coarse search.
        The search is run even when an explicit regularisation target was given.
        """
        init_step_size = self._find_and_set_init_step_size(
            chain_state,
            transition.system,
            transition.integrator,
        )
        reg_target = self.log_step_size_reg_target
        if reg_target is None:
            # Default target biases the adaptation towards larger step sizes.
            reg_target = log(10 * init_step_size)
        return {
            "iter": 0,
            "smoothed_log_step_size": 0.0,
            "adapt_stat_error": 0.0,
            "log_step_size_reg_target": reg_target,
        }

    def _find_and_set_init_step_size(
        self,
        state: ChainState,
        system: System,
        integrator: Integrator,
    ) -> float:
        """Find initial step size by coarse search using single step statistics.

        Adaptation of Algorithm 4 in Hoffman and Gelman (2014), with two changes
        compared to the algorithm given there:

          1. Whether the step size is too big or too small is decided by the
             absolute value of the change in Hamiltonian over a step being larger
             or smaller than log(2), rather than by the equivalent Metropolis
             accept probability being larger or smaller than 0.5. A negative
             change in the Hamiltonian of magnitude more than log(2) gives an
             accept probability of 1 for the forward move, but the corresponding
             reversed move then has an accept probability less than 0.5. A change
             of magnitude more than log(2), irrespective of sign, therefore
             indicates that the minimum acceptance probability over both forward
             and reversed steps is less than 0.5.
          2. To allow for integrators for which a step may fail, for example due
             to a convergence error in an iterative solver, the step size is also
             considered too big if any step size tried in the search results in a
             failed integrator step. In that case the step size is always
             decreased on subsequent steps, irrespective of the initial
             Hamiltonian error, until an integrator step successfully completes
             and the absolute value of the change in Hamiltonian is below the
             threshold of log(2) (corresponding to a minimum acceptance
             probability over forward and reversed steps of 0.5).

        Returns:
            The step size found, which is also left set on the integrator.

        Raises:
            AdaptationError: If the Hamiltonian is NaN at the initial state or no
                suitable step size is found within
                :code:`max_init_step_size_iters` iterations.
        """
        start_state = state.copy()
        h_start = system.h(start_state)
        if np.isnan(h_start):
            msg = "Hamiltonian evaluating to NaN at initial state."
            raise AdaptationError(msg)
        integrator.step_size = 1
        max_delta_h = log(2)
        for attempt in range(self.max_init_step_size_iters):
            try:
                stepped_state = integrator.step(start_state)
                delta_h = abs(h_start - system.h(stepped_state))
                delta_h_is_nan = np.isnan(delta_h)
                # The search direction is fixed by the first successful step but
                # is forced to 'decrease' whenever a NaN change is encountered.
                if attempt == 0 or delta_h_is_nan:
                    step_size_too_big = delta_h_is_nan or delta_h > max_delta_h
                # The search ends once the change in Hamiltonian crosses the
                # threshold. Both comparisons are false for a NaN change, so a
                # NaN never ends the search.
                if step_size_too_big:
                    search_finished = delta_h <= max_delta_h
                else:
                    search_finished = delta_h > max_delta_h
                if search_finished:
                    return integrator.step_size
                if step_size_too_big:
                    integrator.step_size /= 2
                else:
                    integrator.step_size *= 2
            except IntegratorError:  # noqa: PERF203
                # A failed step counts as the step size being too big.
                step_size_too_big = True
                integrator.step_size /= 2
        msg = (
            f"Could not find reasonable initial step size in "
            f"{self.max_init_step_size_iters} iterations (final step size "
            f"{integrator.step_size}). A very large final step size may indicate that "
            f"the target distribution is improper such that the negative log density "
            f"is flat in one or more directions while a very small final step size may "
            f"indicate that the density function is insufficiently smooth at the point "
            f"initialized at."
        )
        raise AdaptationError(msg)

    def update(
        self,
        adapt_state: AdapterState,
        chain_state: ChainState,  # noqa: ARG002
        trans_stats: TransitionStatistics,
        transition: Transition,
    ) -> None:
        """Perform one dual averaging update and set the integrator step size.

        The adapter state entries are updated in-place. The integrator step size
        is set to the exponential of the current, unsmoothed, log step size.
        """
        n_iter = adapt_state["iter"] + 1
        adapt_state["iter"] = n_iter
        # Running average of the shortfall of the statistic from its target.
        error_weight = 1 / (self.iter_offset + n_iter)
        adapt_state["adapt_stat_error"] *= 1 - error_weight
        stat_shortfall = self.adapt_stat_target - self.adapt_stat_func(trans_stats)
        adapt_state["adapt_stat_error"] += error_weight * stat_shortfall
        scaled_error = (
            adapt_state["adapt_stat_error"]
            * n_iter**0.5
            / self.log_step_size_reg_coefficient
        )
        log_step_size = adapt_state["log_step_size_reg_target"] - scaled_error
        # Weighted average of the log step size iterates.
        smoothing_weight = (1 / n_iter) ** self.iter_decay_coeff
        adapt_state["smoothed_log_step_size"] = (
            adapt_state["smoothed_log_step_size"] * (1 - smoothing_weight)
            + smoothing_weight * log_step_size
        )
        transition.integrator.step_size = exp(log_step_size)

    def finalize(
        self,
        adapt_states: AdapterState | Iterable[AdapterState],
        chain_states: ChainState | Iterable[ChainState],  # noqa: ARG002
        transition: Transition,
        rngs: Generator | Iterable[Generator],  # noqa: ARG002
    ) -> None:
        """Set the integrator step size from the smoothed log step size(s).

        A single adapter state dictionary is exponentiated directly. Otherwise
        the per-chain smoothed log step sizes are passed as a list to
        :code:`log_step_size_reducer`.
        """
        if isinstance(adapt_states, dict):
            step_size = exp(adapt_states["smoothed_log_step_size"])
        else:
            smoothed_log_step_sizes = [
                chain_adapt_state["smoothed_log_step_size"]
                for chain_adapt_state in adapt_states
            ]
            step_size = self.log_step_size_reducer(smoothed_log_step_sizes)
        transition.integrator.step_size = step_size
class OnlineVarianceMetricAdapter(Adapter):
    """Diagonal metric adapter using online variance estimates.

    Uses Welford's algorithm (Welford, 1962) to stably compute an online estimate
    of the sample variances of the chain state position components during
    sampling. If online estimates are available from multiple independent chains,
    the final variance estimate is calculated from the per-chain statistics using
    the parallel / batched incremental variance algorithm described by Chan et
    al. (1979). The variance estimates are optionally regularized towards a common
    scalar value, with increasing weight for small number of samples, to decrease
    the effect of noisy estimates for small sample sizes, following the approach
    in Stan (Carpenter et al., 2017). The metric matrix representation is set to a
    diagonal matrix with diagonal elements corresponding to the reciprocal of the
    (regularized) variance estimates.

    References:

      1. Welford, B. P. (1962). Note on a method for calculating corrected sums
         of squares and products. Technometrics, 4(3), pp. 419-420.
      2. Chan, T. F., Golub, G. H. and LeVeque, R. J. (1979). Updating formulae
         and a pairwise algorithm for computing sample variances. Technical
         Report STAN-CS-79-773, Department of Computer Science, Stanford
         University.
      3. Carpenter, B., Gelman, A., Hoffman, M.D., Lee, D., Goodrich, B.,
         Betancourt, M., Brubaker, M., Guo, J., Li, P. and Riddell, A. (2017).
         Stan: A probabilistic programming language. Journal of Statistical
         Software, 76(1).
    """

    is_fast = False

    def __init__(self, reg_iter_offset: int = 5, reg_scale: float = 1e-3) -> None:
        """
        Args:
            reg_iter_offset: Iteration offset used for calculating iteration
                dependent weighting between regularisation target and current
                variance estimate. Higher values cause stronger regularisation
                during initial iterations. A value of zero corresponds to no
                regularisation; this should only be used if the sample variances
                are guaranteed to be positive.
            reg_scale: Positive scalar defining value variance estimates are
                regularized towards.
        """
        self.reg_iter_offset = reg_iter_offset
        self.reg_scale = reg_scale

    def initialize(
        self,
        chain_state: ChainState,
        transition: Transition,  # noqa: ARG002
    ) -> AdapterState:
        """Create zeroed running statistics shaped like the chain position."""
        return {
            "iter": 0,
            "mean": np.zeros_like(chain_state.pos),
            "sum_diff_sq": np.zeros_like(chain_state.pos),
        }

    def update(
        self,
        adapt_state: AdapterState,
        chain_state: ChainState,
        trans_stats: TransitionStatistics,  # noqa: ARG002
        transition: Transition,  # noqa: ARG002
    ) -> None:
        """Fold the current chain position into the running statistics in-place."""
        # Use Welford (1962) incremental algorithm to update statistics to
        # calculate online variance estimate
        # https://en.wikipedia.org/wiki/
        # Algorithms_for_calculating_variance#Welford's_online_algorithm
        adapt_state["iter"] += 1
        pos_minus_mean = chain_state.pos - adapt_state["mean"]
        adapt_state["mean"] += pos_minus_mean / adapt_state["iter"]
        # Product of the deviations from the mean before and after its update.
        adapt_state["sum_diff_sq"] += pos_minus_mean * (
            chain_state.pos - adapt_state["mean"]
        )

    def _regularize_var_est(self, var_est: ArrayLike, n_iter: int) -> None:
        """Update variance estimates by regularizing towards common scalar.

        Performed in place to prevent further array allocations. Nothing is done
        when :code:`reg_iter_offset` is :code:`None` or zero.
        """
        if self.reg_iter_offset is not None and self.reg_iter_offset != 0:
            var_est *= n_iter / (self.reg_iter_offset + n_iter)
            var_est += self.reg_scale * (
                self.reg_iter_offset / (self.reg_iter_offset + n_iter)
            )

    def finalize(
        self,
        adapt_states: AdapterState | Iterable[AdapterState],
        chain_states: ChainState | Iterable[ChainState],
        transition: Transition,
        rngs: Generator | Iterable[Generator],
    ) -> None:
        """Set a diagonal metric from the variance estimate and resample momenta.

        The :code:`"sum_diff_sq"` buffer of the single adapter state, or the
        :code:`"mean"` and :code:`"sum_diff_sq"` buffers of the first per-chain
        adapter state, are removed from their dictionary and reused as working
        storage. The system metric is set to the inverse of a
        :code:`PositiveDiagonalMatrix` built from the (regularized) variance
        estimate and the momentum of every chain state is then resampled.

        Raises:
            AdaptationError: If fewer than two chain samples are available in
                total. This includes :code:`adapt_states` being an empty iterable
                and every chain having recorded zero iterations.
        """
        if isinstance(adapt_states, dict):
            n_iter = adapt_states["iter"]
            var_est = adapt_states.pop("sum_diff_sq")
            chain_states = [chain_states]
            rngs = [rngs]
        else:
            # Use Chan et al. (1979) parallel variance estimation algorithm
            # to combine per-chain statistics
            # https://en.wikipedia.org/wiki/
            # Algorithms_for_calculating_variance#Parallel_algorithm
            # Start from an empty pool so that an empty iterable of adapter states
            # reaches the sample count check below rather than failing on an
            # unbound name.
            n_iter = 0
            mean_est = var_est = None
            for adapt_state in adapt_states:
                n_iter_chain = adapt_state["iter"]
                if var_est is None:
                    # Recycle the buffers of the first chain as accumulators.
                    n_iter = n_iter_chain
                    mean_est = adapt_state.pop("mean")
                    var_est = adapt_state.pop("sum_diff_sq")
                    continue
                if n_iter_chain == 0:
                    # A chain without samples contributes nothing, and merging it
                    # into an equally empty pool would divide zero by zero.
                    continue
                n_iter_prev = n_iter
                n_iter += n_iter_chain
                mean_diff = mean_est - adapt_state["mean"]
                mean_est *= n_iter_prev
                mean_est += n_iter_chain * adapt_state["mean"]
                mean_est /= n_iter
                var_est += adapt_state["sum_diff_sq"]
                var_est += mean_diff**2 * (n_iter_chain * n_iter_prev) / n_iter
        if n_iter < 2:  # noqa: PLR2004
            msg = "At least two chain samples required to compute a variance estimates."
            raise AdaptationError(msg)
        # Convert the sum of squared deviations to the unbiased sample variance.
        var_est /= n_iter - 1
        self._regularize_var_est(var_est, n_iter)
        transition.system.metric = PositiveDiagonalMatrix(var_est).inv
        # Resample momentum to account for altered distribution due to new metric
        for chain_state, rng in zip(chain_states, rngs, strict=True):
            chain_state.mom = transition.system.sample_momentum(chain_state, rng)
class OnlineCovarianceMetricAdapter(Adapter):
    """Dense metric adapter using online covariance estimates.

    Uses Welford's algorithm (Welford, 1962) to stably compute an online estimate
    of the sample covariance matrix of the chain state position components during
    sampling. If online estimates are available from multiple independent chains,
    the final covariance matrix estimate is calculated from the per-chain
    statistics using a covariance variant due to Schubert and Gertz (2018) of the
    parallel / batched incremental variance algorithm described by Chan et al.
    (1979). The covariance matrix estimates are optionally regularized towards a
    scaled identity matrix, with increasing weight for small number of samples, to
    decrease the effect of noisy estimates for small sample sizes, following the
    approach in Stan (Carpenter et al., 2017). The metric matrix representation is
    set to a dense positive definite matrix corresponding to the inverse of the
    (regularized) covariance matrix estimate.

    References:

      1. Welford, B. P. (1962). Note on a method for calculating corrected sums
         of squares and products. Technometrics, 4(3), pp. 419-420.
      2. Schubert, E. and Gertz, M. (2018). Numerically stable parallel
         computation of (co-)variance. ACM. p. 10. doi:10.1145/3221269.3223036.
      3. Chan, T. F., Golub, G. H. and LeVeque, R. J. (1979). Updating formulae
         and a pairwise algorithm for computing sample variances. Technical
         Report STAN-CS-79-773, Department of Computer Science, Stanford
         University.
      4. Carpenter, B., Gelman, A., Hoffman, M.D., Lee, D., Goodrich, B.,
         Betancourt, M., Brubaker, M., Guo, J., Li, P. and Riddell, A. (2017).
         Stan: A probabilistic programming language. Journal of Statistical
         Software, 76(1).
    """

    is_fast = False

    def __init__(self, reg_iter_offset: int = 5, reg_scale: float = 1e-3) -> None:
        """
        Args:
            reg_iter_offset: Iteration offset used for calculating iteration
                dependent weighting between regularisation target and current
                covariance estimate. Higher values cause stronger regularisation
                during initial iterations. A value of zero (or :code:`None`)
                corresponds to no regularisation.
            reg_scale: Positive scalar defining value variance estimates are
                regularized towards.
        """
        self.reg_iter_offset = reg_iter_offset
        self.reg_scale = reg_scale

    def initialize(
        self,
        chain_state: ChainState,
        transition: Transition,  # noqa: ARG002
    ) -> AdapterState:
        """Create zeroed running statistics sized from the chain position."""
        dim_pos = chain_state.pos.shape[0]
        dtype = chain_state.pos.dtype
        return {
            "iter": 0,
            "mean": np.zeros(shape=(dim_pos,), dtype=dtype),
            "sum_diff_outer": np.zeros(shape=(dim_pos, dim_pos), dtype=dtype),
        }

    def update(
        self,
        adapt_state: AdapterState,
        chain_state: ChainState,
        trans_stats: TransitionStatistics,  # noqa: ARG002
        transition: Transition,  # noqa: ARG002
    ) -> None:
        """Fold the current chain position into the running statistics in-place."""
        # Use Welford (1962) incremental algorithm to update statistics to
        # calculate online covariance estimate
        # https://en.wikipedia.org/wiki/
        # Algorithms_for_calculating_variance#Online
        adapt_state["iter"] += 1
        pos_minus_mean = chain_state.pos - adapt_state["mean"]
        adapt_state["mean"] += pos_minus_mean / adapt_state["iter"]
        # Outer product of the deviations from the mean before and after its
        # update.
        adapt_state["sum_diff_outer"] += (
            pos_minus_mean[None, :] * (chain_state.pos - adapt_state["mean"])[:, None]
        )

    def _regularize_covar_est(self, covar_est: ArrayLike, n_iter: int) -> None:
        """Update covariance estimate by regularising towards identity.

        Performed in place to prevent further array allocations. Nothing is done
        when :code:`reg_iter_offset` is :code:`None` or zero, matching the
        behaviour of :code:`OnlineVarianceMetricAdapter`.
        """
        if self.reg_iter_offset is None or self.reg_iter_offset == 0:
            return
        covar_est *= n_iter / (self.reg_iter_offset + n_iter)
        # The "ii->i" einsum gives a view of the diagonal, so the in-place
        # addition below modifies covar_est itself.
        covar_est_diagonal = np.einsum("ii->i", covar_est)
        covar_est_diagonal += self.reg_scale * (
            self.reg_iter_offset / (self.reg_iter_offset + n_iter)
        )

    def finalize(
        self,
        adapt_states: AdapterState | Iterable[AdapterState],
        chain_states: ChainState | Iterable[ChainState],
        transition: Transition,
        rngs: Generator | Iterable[Generator],
    ) -> None:
        """Set a dense metric from the covariance estimate and resample momenta.

        The :code:`"sum_diff_outer"` buffer of the single adapter state, or the
        :code:`"mean"` and :code:`"sum_diff_outer"` buffers of the first per-chain
        adapter state, are removed from their dictionary and reused as working
        storage. The system metric is set to the inverse of a
        :code:`DensePositiveDefiniteMatrix` built from the (regularized)
        covariance estimate and the momentum of every chain state is then
        resampled.

        Raises:
            AdaptationError: If fewer than two chain samples are available in
                total. This includes :code:`adapt_states` being an empty iterable
                and every chain having recorded zero iterations.
        """
        if isinstance(adapt_states, dict):
            n_iter = adapt_states["iter"]
            covar_est = adapt_states.pop("sum_diff_outer")
            chain_states = [chain_states]
            rngs = [rngs]
        else:
            # Use Schubert and Gertz (2018) parallel covariance estimation
            # algorithm to combine per-chain statistics
            # Start from an empty pool so that an empty iterable of adapter states
            # reaches the sample count check below rather than failing on an
            # unbound name.
            n_iter = 0
            mean_est = covar_est = None
            for adapt_state in adapt_states:
                n_iter_chain = adapt_state["iter"]
                if covar_est is None:
                    # Recycle the buffers of the first chain as accumulators.
                    n_iter = n_iter_chain
                    mean_est = adapt_state.pop("mean")
                    covar_est = adapt_state.pop("sum_diff_outer")
                    continue
                if n_iter_chain == 0:
                    # A chain without samples contributes nothing, and merging it
                    # into an equally empty pool would divide zero by zero.
                    continue
                n_iter_prev = n_iter
                n_iter += n_iter_chain
                mean_diff = mean_est - adapt_state["mean"]
                mean_est *= n_iter_prev
                mean_est += n_iter_chain * adapt_state["mean"]
                mean_est /= n_iter
                covar_est += adapt_state["sum_diff_outer"]
                covar_est += (
                    np.outer(mean_diff, mean_diff)
                    * (n_iter_chain * n_iter_prev)
                    / n_iter
                )
        if n_iter < 2:  # noqa: PLR2004
            msg = "At least two chain samples required to compute a variance estimates."
            raise AdaptationError(msg)
        # Convert the sum of outer products to the unbiased sample covariance.
        covar_est /= n_iter - 1
        self._regularize_covar_est(covar_est, n_iter)
        transition.system.metric = DensePositiveDefiniteMatrix(covar_est).inv
        # Resample momentum to account for altered distribution due to new metric
        for chain_state, rng in zip(chain_states, rngs, strict=True):
            chain_state.mom = transition.system.sample_momentum(chain_state, rng)