"""Markov transition kernels."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import TYPE_CHECKING, NamedTuple
import numpy as np
from mici.errors import (
ConvergenceError,
Error,
HamiltonianDivergenceError,
IntegratorError,
NonReversibleStepError,
)
from mici.utils import LogRepFloat
if TYPE_CHECKING:
from typing import Optional
from numpy.random import Generator
from numpy.typing import ArrayLike, DTypeLike
from mici.integrators import Integrator
from mici.states import ChainState
from mici.systems import System
from mici.types import ScalarLike, TerminationCriterion
logger = logging.getLogger(__name__)
def _process_integrator_error(exception: Error, stats: dict[str, ScalarLike]):
    """Log an integrator error and flag it in the transition statistics.

    Args:
        exception: Error raised while integrating a trajectory.
        stats: Dictionary of transition statistics, updated in place. At most one of
            the entries `diverging`, `non_reversible_step` or `convergence_error` is
            set to `True`, chosen by the first exception type that matches.
    """
    logger.info(f"Terminating trajectory due to error:\n{exception!s}")
    # Flags are only ever raised here, never cleared: transitions which may hit
    # these errors are expected to initialise the matching entries to False.
    flag_for_error_type = (
        (HamiltonianDivergenceError, "diverging"),
        (NonReversibleStepError, "non_reversible_step"),
        (ConvergenceError, "convergence_error"),
    )
    for error_type, flag in flag_for_error_type:
        if isinstance(exception, error_type):
            stats[flag] = True
            # Mirror an if/elif chain: only the first matching type is recorded.
            break
class Transition(ABC):
    """Base class for Markov transition kernels.

    Defines expected interface for transitions by sampler classes. Concrete
    subclasses must provide :py:attr:`state_variables` and :py:meth:`sample`.
    """

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking `property` on
    # top of `abstractmethod` is the supported equivalent.
    @property
    @abstractmethod
    def state_variables(self) -> set[str]:
        """A set of names of state variables accessed by this transition."""

    @property
    def statistic_types(self) -> Optional[dict[str, tuple[DTypeLike, ScalarLike]]]:
        """A dictionary describing the statistics computed during transition.

        Either :code:`None` if no statistics are returned by :py:meth:`sample` method or
        a dictionary with string keys and tuple values, with the keys defining the keys
        of the statistics returned in the :code:`trans_stats` return value of the
        :py:meth:`sample` method and the first entry of the value tuples an appropriate
        NumPy `dtype` for the array used to store the corresponding statistic values and
        second entry the default value to initialize this array with.

        The base implementation returns :code:`None`.
        """
        return None

    @abstractmethod
    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Sample a new chain state from the Markov transition kernel.

        Args:
            state: Current chain state to condition transition kernel on.
            rng: Numpy random number generator.

        Returns:
            Tuple of updated state object and any statistics computed during the
            transition or :code:`None` if no statistics.
        """
class MomentumTransition(Transition):
    """Abstract base for transitions that only refresh the momentum.

    A kernel of this kind updates just the momentum component of the chain state and
    leaves invariant the conditional distribution on the momentum under the canonical
    distribution.
    """

    def __init__(self, system: System):
        """
        Args:
            system: Hamiltonian system whose conditional distribution on the momentum
                the transition leaves invariant.
        """
        self.system = system

    @property
    def state_variables(self) -> set[str]:
        """Names of the state variables accessed by this transition."""
        return {"mom"}

    @abstractmethod
    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Draw a new momentum for the state.

        The new momentum is sampled from a Markov transition kernel which leaves
        invariant the conditional distribution on the momentum under the canonical
        distribution defined by the Hamiltonian system.

        Args:
            state: Current chain state the kernel is conditioned on.
            rng: Numpy random number generator.

        Returns:
            Pair of the updated state object and a dictionary of statistics computed
            during the transition, or `None` in place of the dictionary if there are
            no statistics.
        """
class IndependentMomentumTransition(MomentumTransition):
    """Momentum transition using a full independent refresh.

    The momentum component of the state is redrawn from its conditional distribution
    given the rest of the state, without reference to the previous momentum.
    """

    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Overwrite the momentum of `state` with a fresh draw from the system.

        Args:
            state: Current chain state; its `mom` attribute is replaced in place.
            rng: Numpy random number generator.

        Returns:
            Pair of the same (mutated) state object and `None`, as this transition
            records no statistics.
        """
        resampled_mom = self.system.sample_momentum(state, rng)
        state.mom = resampled_mom
        return state, None
class IntegrationTransition(Transition):
    """Abstract base for transitions driven by integrating Hamiltonian dynamics.

    Kernels of this kind leave the canonical distribution invariant and update the
    position and momentum components of the chain state jointly, using simulated
    Hamiltonian trajectories to propose new values for the state.
    """

    def __init__(self, system: System, integrator: Integrator):
        """
        Args:
            system: Hamiltonian system to be simulated.
            integrator: Symplectic integrator appropriate to the specified Hamiltonian
                system.
        """
        self.system = system
        self.integrator = integrator
        # Each entry maps a statistic name to a (dtype, default value) pair.
        # Subclasses extend this dictionary with their own statistics.
        float_stat = (np.float64, np.nan)
        flag_stat = (bool, False)
        self._statistic_types = {
            "n_step": (np.int64, -1),
            "accept_stat": float_stat,
            "non_reversible_step": flag_stat,
            "convergence_error": flag_stat,
            "step_size": float_stat,
        }

    @property
    def state_variables(self) -> set[str]:
        """Names of the state variables accessed by this transition."""
        return {"pos", "mom", "dir"}

    @property
    def statistic_types(self) -> dict[str, tuple[DTypeLike, ScalarLike]]:
        """Dictionary mapping statistic names to (dtype, default value) pairs."""
        return self._statistic_types

    @abstractmethod
    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Draw a new position-momentum pair via integration based proposal(s).

        The new position and momentum are sampled from a Markov transition kernel
        which leaves invariant the canonical distribution on the state space
        corresponding to the Hamiltonian system.

        Args:
            state: Current chain state.
            rng: Numpy random number generator.

        Returns:
            Pair of the updated state object and a dictionary of statistics computed
            during the transition, or `None` in place of the dictionary if there are
            no statistics.
        """
class MetropolisIntegrationTransition(IntegrationTransition):
    """Shared machinery for HMC transitions that end in a Metropolis accept step.

    Each transition simulates a trajectory by applying the integrator a number of
    times, starting from the current state and moving in the current integration time
    direction.

    The final state of the trajectory, with its integration direction negated (which
    makes the proposal map an involution), is the proposal in a Metropolis acceptance
    step. Whatever the outcome of that step, the integration direction of the
    resulting state is then negated once more, so an accepted move ends with the
    initial direction and a rejected move with the initial direction reversed.
    """

    def __init__(self, system: System, integrator: Integrator):
        """
        Args:
            system: Hamiltonian system to be simulated.
            integrator: Symplectic integrator appropriate to the specified Hamiltonian
                system.
        """
        super().__init__(system, integrator)
        self._statistic_types["metrop_accept_prob"] = (np.float64, np.nan)

    def _sample_n_step(
        self,
        state: ChainState,
        n_step: int,
        rng: Generator,
    ) -> tuple[ChainState, dict[str, ScalarLike]]:
        """Run one Metropolis-adjusted transition with a trajectory of `n_step` steps.

        Args:
            state: Current chain state.
            n_step: Number of integrator steps to attempt.
            rng: Numpy random number generator, used for the accept decision.

        Returns:
            Pair of the next chain state and a dictionary of transition statistics.
        """
        initial_energy = self.system.h(state)
        stats = {
            "convergence_error": False,
            "non_reversible_step": False,
            "step_size": self.integrator.step_size,
        }
        proposal = state
        failed = False
        try:
            for completed_steps in range(n_step):
                proposal = self.integrator.step(proposal)
        except IntegratorError as error:
            failed = True
            # The loop index at the point of failure equals the number of steps
            # which completed successfully.
            stats["n_step"] = completed_steps
            _process_integrator_error(error, stats)
        else:
            stats["n_step"] = n_step
        # Flip the direction of the proposal so that the proposal map is an involution
        proposal.dir *= -1
        if proposal is state:
            # No integrator step produced a new state object
            accept_prob = 0.0
        else:
            energy_drop = initial_energy - self.system.h(proposal)
            # NaN needs an explicit check because min(0, NaN) evaluates to 0
            accept_prob = (
                0.0 if np.isnan(energy_drop) else np.exp(min(0, energy_drop))
            )
        stats["metrop_accept_prob"] = accept_prob
        stats["accept_stat"] = 0.0 if failed else accept_prob
        # A failed trajectory is always rejected and consumes no random draw
        if not failed and rng.uniform() < accept_prob:
            state = proposal
        # Flip the direction of whichever state was selected. The extended target
        # distribution is symmetric in the direction indicator, so this flip always
        # leaves it invariant.
        state.dir *= -1
        return state, stats
class MetropolisStaticIntegrationTransition(MetropolisIntegrationTransition):
    """Metropolis-adjusted integration transition with a fixed trajectory length.

    Every trajectory is produced by applying the integrator the same fixed number of
    times. This is the originally proposed Hybrid Monte Carlo algorithm, now more
    often called Hamiltonian Monte Carlo (Duane et al., 1987; Neal, 2011).

    References:

      1. Duane, S., Kennedy, A.D., Pendleton, B.J. and Roweth, D. (1987). Hybrid Monte
         Carlo. Physics letters B, 195(2), pp.216-222.
      2. Neal, R.M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov Chain
         Monte Carlo, 2(11), p.2.
    """

    def __init__(self, system: System, integrator: Integrator, n_step: int):
        """
        Args:
            system: Hamiltonian system to be simulated.
            integrator: Symplectic integrator appropriate to the specified Hamiltonian
                system.
            n_step: Number of integrator steps to simulate in each transition.

        Raises:
            ValueError: If `n_step` is zero or negative.
        """
        super().__init__(system, integrator)
        if n_step <= 0:
            error_message = "Number of integrator steps must be positive."
            raise ValueError(error_message)
        self.n_step = n_step

    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Sample the next state using a trajectory of `self.n_step` integrator steps.

        Args:
            state: Current chain state.
            rng: Numpy random number generator.

        Returns:
            Pair of the next chain state and a dictionary of transition statistics.
        """
        trajectory_length = self.n_step
        return self._sample_n_step(state, trajectory_length, rng)
class MetropolisRandomIntegrationTransition(MetropolisIntegrationTransition):
    """Random integration transition with Metropolis sampling of new state.

    In each transition a trajectory is generated by integrating the state in the current
    integration direction in time a random integer number of integrator steps sampled
    from the uniform distribution on an integer interval. The randomisation of the
    number of integration steps avoids the potential of the chain mixing poorly due to
    using an integration time close to the period of (near) periodic systems
    (Neal, 2011; Mackenzie, 1989).

    References:

      1. Neal, R.M. (2011). MCMC using Hamiltonian dynamics. Handbook of Markov Chain
         Monte Carlo, 2(11), p.2.
      2. Mackenzie, P.B. (1989). An improved hybrid Monte Carlo method. Physics Letters
         B, 226(3-4), pp.369-371.
    """

    def __init__(
        self,
        system: System,
        integrator: Integrator,
        n_step_range: tuple[int, int],
    ):
        """
        Args:
            system: Hamiltonian system to be simulated.
            integrator: Symplectic integrator appropriate to the specified Hamiltonian
                system.
            n_step_range: Tuple :code:`(lower, upper)` with two positive integer entries
                :code:`lower` and :code:`upper` (with :code:`upper > lower`) specifying
                respectively the lower and upper bounds (inclusive) of integer interval
                to uniformly draw random number integrator steps to simulate in each
                transition.

        Raises:
            ValueError: If :code:`lower` is not positive or is not strictly less than
                :code:`upper`.
        """
        super().__init__(system, integrator)
        n_step_lower, n_step_upper = n_step_range
        if not (0 < n_step_lower < n_step_upper):
            # The check requires a strictly positive lower bound, so say "positive"
            # rather than "non-negative" (zero is rejected).
            msg = "Range bounds must be positive and first entry less than last."
            raise ValueError(msg)
        self.n_step_range = n_step_range

    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, Optional[dict[str, ScalarLike]]]:
        """Sample the next state using a trajectory of random length.

        Args:
            state: Current chain state.
            rng: Numpy random number generator, used both to draw the number of
                integrator steps and for the accept decision.

        Returns:
            Pair of the next chain state and a dictionary of transition statistics.
        """
        # `Generator.integers` excludes the upper bound by default, whereas the
        # documented interval is inclusive at both ends, hence `endpoint=True`.
        n_step = rng.integers(*self.n_step_range, endpoint=True)
        return self._sample_n_step(state, n_step, rng)
def euclidean_no_u_turn_criterion(
    system: System,
    state_1: ChainState,
    state_2: ChainState,
    _sum_mom: ArrayLike,
) -> bool:
    """No-U-turn termination criterion for Euclidean manifolds.

    Signals termination when the velocity at either terminal state of the trajectory
    has a negative dot product with the vector pointing from the position of the
    first terminal state to the position of the second terminal state, i.e. when
    continuing the trajectory would bring the two terminal positions closer together.

    Args:
        system: Hamiltonian system being integrated.
        state_1: First terminal state of trajectory.
        state_2: Second terminal state of trajectory.
        _sum_mom: Sum of momentums of trajectory states (unused).

    Returns:
        Whether termination criterion is satisfied.

    References:

      1. Hoffman, M.D. and Gelman, A. (2014). The No-U-turn sampler: adaptively setting
         path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research,
         15(1), pp.1593-1623.
    """

    def _velocity_opposes_displacement(terminal_state):
        # Velocity is the derivative of the Hamiltonian with respect to momentum
        velocity = system.dh_dmom(terminal_state)
        return np.sum(velocity * (state_2.pos - state_1.pos)) < 0

    # The second terminal state is only checked if the first does not already
    # satisfy the criterion.
    return _velocity_opposes_displacement(
        state_1,
    ) or _velocity_opposes_displacement(state_2)
def riemannian_no_u_turn_criterion(
    system: System,
    state_1: ChainState,
    state_2: ChainState,
    sum_mom: ArrayLike,
) -> bool:
    """Generalized no-U-turn termination criterion on Riemannian manifolds.

    Signals termination when the velocity at either terminal state of the trajectory
    has a negative dot product with the sum of the momentums over the trajectory
    between the two terminal states (Betancourt, 2013). This generalizes the no-U-turn
    criterion of Hoffman and Gelman (2014) to Riemannian manifolds, where because of
    the intrinsic curvature of the space the geodesic between two points is in general
    no longer a straight line.

    Args:
        system: Hamiltonian system being integrated.
        state_1: First terminal state of trajectory.
        state_2: Second terminal state of trajectory.
        sum_mom: Sum of momentums of trajectory states.

    Returns:
        Whether termination criterion is satisfied.

    References:

      1. Hoffman, M.D. and Gelman, A. (2014). The No-U-turn sampler: adaptively setting
         path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research,
         15(1), pp.1593-1623.
      2. Betancourt, M. (2013). Generalizing the no-U-turn sampler to Riemannian
         manifolds. arXiv preprint arXiv:1304.1920.
    """

    def _velocity_opposes_sum_mom(terminal_state):
        # Velocity is the derivative of the Hamiltonian with respect to momentum
        velocity = system.dh_dmom(terminal_state)
        return np.sum(velocity * sum_mom) < 0

    # The second terminal state is only checked if the first does not already
    # satisfy the criterion.
    return _velocity_opposes_sum_mom(state_1) or _velocity_opposes_sum_mom(state_2)
class _SubTree(NamedTuple):
    """Sub-tree of binary trajectory tree for dynamic integration transitions."""

    # Terminal state at the negative edge of the sub-tree.
    negative: ChainState
    # Terminal state at the positive edge of the sub-tree.
    positive: ChainState
    # Sum of the momentums of the states in the sub-tree.
    sum_mom: ArrayLike
    # Total weight of the sub-tree; weights are added when sub-trees are merged.
    weight: ScalarLike
    # Depth of the sub-tree: 0 for a single-state leaf, incremented on each merge.
    depth: int
class DynamicIntegrationTransition(IntegrationTransition):
    """Base class for dynamic integration transitions.

    In each transition a binary tree of states is recursively computed by integrating
    randomly forward and backward in time by a number of steps equal to the previous
    tree size until a termination criteria on the tree's subtrees is met. The next chain
    state is chosen from the candidate states using a progressive sampling scheme based
    on relative weights of the different candidate states, with the sampling biased
    towards states further from the current state.

    Subclasses define how candidate states are weighted by implementing
    :py:meth:`_weight_function`, :py:meth:`_weight_ratio` and
    :py:meth:`_check_divergence`.

    References:

      1. Hoffman, M.D. and Gelman, A., 2014. The No-U-turn sampler: adaptively setting
         path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research,
         15(1), pp.1593-1623.
      2. Betancourt, M., 2017. A conceptual introduction to Hamiltonian Monte Carlo.
         arXiv preprint arXiv:1701.02434.
    """

    def __init__(
        self,
        system: System,
        integrator: Integrator,
        *,
        max_tree_depth: int = 10,
        max_delta_h: float = 1000.0,
        termination_criterion: TerminationCriterion = riemannian_no_u_turn_criterion,
        do_extra_subtree_checks: bool = True,
    ):
        """
        Args:
            system: Hamiltonian system to be simulated.
            integrator: Symplectic integrator appropriate to the specified Hamiltonian
                system.
            max_tree_depth: Maximum depth to expand trajectory binary tree to. The
                maximum number of integrator steps corresponds to
                :code:`2**max_tree_depth`. Must be positive.
            max_delta_h: Maximum change to tolerate in the Hamiltonian function over a
                trajectory before signalling a divergence.
            termination_criterion: Function computing criterion to use to determine when
                to terminate trajectory tree expansion. The function should take a
                Hamiltonian system as its first argument, a pair of states corresponding
                to the two edge nodes in the trajectory (sub-)tree being checked and an
                array containing the sum of the momentums over the trajectory
                (sub)-tree. Defaults to :py:func:`riemannian_no_u_turn_criterion`.
            do_extra_subtree_checks: Whether to perform additional termination criterion
                checks on overlapping subtrees of the current tree to improve robustness
                in systems with dynamics which are well approximated by independent
                system of simple harmonic oscillators. In such systems (corresponding to
                e.g. a standard normal target distribution and identity metric matrix
                representation) at certain step sizes a 'resonant' behaviour is seen by
                which the termination criterion fails to detect that the trajectory has
                expanded past a half-period i.e. has 'U-turned' resulting in
                trajectories continuing to expand, potentially up until the
                :code:`max_tree_depth` limit is hit. For more details see `this Stan
                Discourse discussion <kutt.it/yAkIES>`_. If
                :code:`do_extra_subtree_checks` is set to :code:`True` additional
                termination criterion checks are performed on overlapping subtrees which
                help to reduce this resonant behaviour at the cost of more conservative
                trajectory termination in some correlated models and some overhead from
                additional checks.

        Raises:
            ValueError: If :code:`max_tree_depth` is zero or negative.
        """
        super().__init__(system, integrator)
        if max_tree_depth <= 0:
            # Zero is rejected too, so the requirement is "positive".
            msg = "max_tree_depth must be positive."
            raise ValueError(msg)
        self.max_tree_depth = max_tree_depth
        self.max_delta_h = max_delta_h
        self.termination_criterion = termination_criterion
        self.do_extra_subtree_checks = do_extra_subtree_checks
        # Statistics reported in addition to those of the base integration transition
        self._statistic_types["av_metrop_accept_prob"] = (np.float64, np.nan)
        self._statistic_types["reject_prob"] = (np.float64, np.nan)
        self._statistic_types["tree_depth"] = (np.int64, -1)
        self._statistic_types["diverging"] = (bool, False)

    def _termination_criterion(
        self,
        tree: _SubTree,
        neg_subtree: _SubTree,
        pos_subtree: _SubTree,
    ) -> bool:
        """Check the termination criterion on a tree and optionally its subtrees.

        Args:
            tree: Tree formed by merging `neg_subtree` and `pos_subtree`.
            neg_subtree: Subtree on the negative side of `tree`.
            pos_subtree: Subtree on the positive side of `tree`.

        Returns:
            Whether the termination criterion is satisfied for the whole tree or, if
            extra subtree checks are enabled, for either overlapping subtree.
        """
        # If performing extra subtree checks evaluate lazily i.e. only evaluate
        # if initial whole tree check fails. Extra subtree checks also only
        # performed for trees of depth 2 and above (i.e. containing at least
        # 4 states) as for trees of depth 1 they are redundant.
        if self.termination_criterion(
            self.system,
            tree.negative,
            tree.positive,
            tree.sum_mom,
        ):
            return True
        if tree.depth > 1 and self.do_extra_subtree_checks:
            # Each check spans one subtree plus the nearest edge state of the other
            return self.termination_criterion(
                self.system,
                neg_subtree.negative,
                pos_subtree.negative,
                neg_subtree.sum_mom + pos_subtree.negative.mom,
            ) or self.termination_criterion(
                self.system,
                neg_subtree.positive,
                pos_subtree.positive,
                pos_subtree.sum_mom + neg_subtree.positive.mom,
            )
        return False

    def _new_leave(
        self,
        state: ChainState,
        h: ScalarLike,
        aux_vars: dict[str, ArrayLike],
    ) -> _SubTree:
        """Construct a depth zero tree containing the single state `state`.

        Args:
            state: State forming both edges of the new tree.
            h: Hamiltonian value for `state`, used to compute the tree weight.
            aux_vars: Auxiliary variables passed through to the weight function.

        Returns:
            Tree of depth zero with weight given by :py:meth:`_weight_function`.
        """
        return _SubTree(
            negative=state,
            positive=state,
            sum_mom=np.asarray(state.mom),
            weight=self._weight_function(h, aux_vars),
            depth=0,
        )

    def _merge_subtrees(self, neg_subtree: _SubTree, pos_subtree: _SubTree) -> _SubTree:
        """Merge two adjacent subtrees of equal depth into a tree one level deeper.

        Args:
            neg_subtree: Subtree providing the negative edge of the merged tree.
            pos_subtree: Subtree providing the positive edge of the merged tree.

        Returns:
            Merged tree with summed weights and momentum sums.

        Raises:
            ValueError: If the two subtrees have different depths.
        """
        if neg_subtree.depth != pos_subtree.depth:
            msg = "Cannot merge subtrees of different depths."
            raise ValueError(msg)
        return _SubTree(
            negative=neg_subtree.negative,
            positive=pos_subtree.positive,
            weight=neg_subtree.weight + pos_subtree.weight,
            sum_mom=neg_subtree.sum_mom + pos_subtree.sum_mom,
            depth=neg_subtree.depth + 1,
        )

    def _init_aux_vars(
        self,
        state: ChainState,
        rng: Generator,
    ) -> dict[str, ScalarLike]:
        """Initialise the auxiliary variables shared across one transition.

        Args:
            state: Current chain state at the start of the transition.
            rng: Numpy random number generator (unused by this base implementation).

        Returns:
            Dictionary with entry `h_init`, the Hamiltonian at `state`.
        """
        return {"h_init": self.system.h(state)}

    @abstractmethod
    def _weight_function(
        self,
        h: ScalarLike,
        aux_vars: dict[str, ScalarLike],
    ) -> ScalarLike:
        """Compute the weight of a candidate state with Hamiltonian value `h`."""

    @abstractmethod
    def _weight_ratio(
        self,
        numerator: ScalarLike,
        denominator: ScalarLike,
    ) -> ScalarLike:
        """Compute the acceptance probability from a pair of tree weights."""

    @abstractmethod
    def _check_divergence(self, h: float, aux_vars: dict[str, ScalarLike]):
        """Raise an integrator error if `h` indicates the trajectory has diverged."""

    def _build_tree(
        self,
        depth: int,
        state: ChainState,
        stats: dict[str, ScalarLike],
        rng: Generator,
        aux_vars: dict[str, ScalarLike],
    ) -> tuple[bool, Optional[_SubTree], Optional[ChainState]]:
        """Recursively build a subtree of depth `depth` extending from `state`.

        Args:
            depth: Depth of the subtree to build.
            state: Edge state to extend from, integrating in direction `state.dir`.
            stats: Transition statistics, updated in place.
            rng: Numpy random number generator.
            aux_vars: Auxiliary variables for the current transition.

        Returns:
            Tuple `(terminate, tree, proposal)`. If an integrator error occurred or a
            nested subtree signalled termination then `tree` and `proposal` are both
            `None` and `terminate` is `True`. Otherwise `tree` is the built subtree,
            `proposal` a state sampled from it and `terminate` whether the termination
            criterion was met on the built subtree.
        """
        if depth == 0:
            # recursion base case
            try:
                # integrate forward/backward one step depending on state.dir
                state = self.integrator.step(state)
                h = self.system.h(state)
                h = np.inf if np.isnan(h) else h
                tree = self._new_leave(state, h, aux_vars)
                proposal = state
                h_diff = aux_vars["h_init"] - h
                # Explicitly check if h_diff is NaN as min(0, NaN) = 0
                metrop_accept_prob = 0.0 if np.isnan(h_diff) else np.exp(min(0, h_diff))
                stats["sum_metrop_accept_prob"] += metrop_accept_prob
                stats["n_step"] += 1
                # default to assuming valid and then check for divergence
                terminate = False
                self._check_divergence(h, aux_vars)
            except IntegratorError as e:
                _process_integrator_error(e, stats)
                terminate, tree, proposal = True, None, None
            return terminate, tree, proposal
        # build 'inner' subtree, i.e. starting from current state
        terminate, inner_tree, inner_proposal = self._build_tree(
            depth - 1,
            state,
            stats,
            rng,
            aux_vars,
        )
        if terminate:
            return terminate, None, None
        # build 'outer' subtree, i.e. starting from terminus of inner subtree
        state = inner_tree.positive if state.dir == 1 else inner_tree.negative
        terminate, outer_tree, outer_proposal = self._build_tree(
            depth - 1,
            state,
            stats,
            rng,
            aux_vars,
        )
        if terminate:
            return terminate, None, None
        # merge two subtrees accounting for integration direction
        neg_subtree = inner_tree if state.dir == 1 else outer_tree
        pos_subtree = outer_tree if state.dir == 1 else inner_tree
        tree = self._merge_subtrees(neg_subtree, pos_subtree)
        # sample new proposal from two subtree proposals according to weights
        accept_outer_prob = self._weight_ratio(outer_tree.weight, tree.weight)
        proposal = (
            outer_proposal if rng.uniform() < accept_outer_prob else inner_proposal
        )
        # check termination criterion on tree and subtrees
        terminate = self._termination_criterion(tree, neg_subtree, pos_subtree)
        return terminate, tree, proposal

    def sample(
        self,
        state: ChainState,
        rng: Generator,
    ) -> tuple[ChainState, dict[str, ScalarLike]]:
        """Sample the next state from a dynamically expanded trajectory tree.

        Args:
            state: Current chain state.
            rng: Numpy random number generator.

        Returns:
            Pair of the next chain state and a dictionary of transition statistics
            with keys `n_step`, `reject_prob`, `diverging`, `convergence_error`,
            `non_reversible_step`, `step_size`, `av_metrop_accept_prob`, `accept_stat`
            and `tree_depth`.
        """
        stats = {
            "n_step": 0,
            "sum_metrop_accept_prob": 0.0,
            "reject_prob": 1.0,
            "diverging": False,
            "convergence_error": False,
            "non_reversible_step": False,
            "step_size": self.integrator.step_size,
        }
        aux_vars = self._init_aux_vars(state, rng)
        tree = self._new_leave(state, aux_vars["h_init"], aux_vars)
        next_state = state
        for depth in range(self.max_tree_depth):
            # uniformly sample direction to expand tree in
            direction = 2 * (rng.uniform() < 0.5) - 1
            state = tree.positive if direction == 1 else tree.negative
            state.dir = direction
            # expand tree by building new subtree of current depth
            terminate, new_tree, new_proposal = self._build_tree(
                depth,
                state,
                stats,
                rng,
                aux_vars,
            )
            if terminate:
                break
            # progressively sample new state by choosing between
            # current new state and proposal from new subtree, biasing
            # towards the new subtree proposal
            accept_proposal_prob = self._weight_ratio(new_tree.weight, tree.weight)
            if rng.uniform() < accept_proposal_prob:
                next_state = new_proposal
            # each proposal acceptance independent therefore overall probability
            # of 'rejecting' - i.e. not accepting all proposals is product of
            # probabilities of not accepting each proposal
            stats["reject_prob"] *= 1.0 - accept_proposal_prob
            # merge new subtree into current tree accounting for direction
            neg_subtree = tree if direction == 1 else new_tree
            pos_subtree = new_tree if direction == 1 else tree
            tree = self._merge_subtrees(neg_subtree, pos_subtree)
            # check termination criterion on new tree and subtrees
            if self._termination_criterion(tree, neg_subtree, pos_subtree):
                break
        # the running sum is an internal accumulator and is not reported directly
        sum_accept_prob = stats.pop("sum_metrop_accept_prob")
        if stats["n_step"] > 0:
            stats["av_metrop_accept_prob"] = sum_accept_prob / stats["n_step"]
        else:
            stats["av_metrop_accept_prob"] = 0.0
        if any(
            stats[key]
            for key in ["diverging", "convergence_error", "non_reversible_step"]
        ):
            stats["accept_stat"] = 0.0
        else:
            stats["accept_stat"] = stats["av_metrop_accept_prob"]
        # recorded value is the loop index of the last tree expansion attempted
        stats["tree_depth"] = depth
        return next_state, stats
class MultinomialDynamicIntegrationTransition(DynamicIntegrationTransition):
    """Dynamic integration transition using multinomial sampling of the new state.

    Each transition recursively builds a binary tree of states by integrating randomly
    forward and backward in time, by a number of steps equal to the previous tree size
    (Hoffman and Gelman, 2014; Betancourt, 2017), until a termination criteria on the
    tree leaves is met. The next chain state is picked from the candidate states by a
    progressive multinomial sampling scheme (Betancourt, 2017) based on the relative
    probability densities of the candidates, with the sampling biased towards states
    further from the current state.

    References:

      1. Hoffman, M.D. and Gelman, A. (2014). The No-U-turn sampler: adaptively setting
         path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research,
         15(1), pp.1593-1623.
      2. Betancourt, M. (2017). A conceptual introduction to Hamiltonian Monte Carlo.
         arXiv preprint arXiv:1701.02434.
    """

    def _weight_function(
        self,
        h: ScalarLike,
        aux_vars: dict[str, ScalarLike],
    ) -> ScalarLike:
        """Return the weight of a state as a log-represented float with log `-h`."""
        negative_energy = -h
        return LogRepFloat(log_val=negative_energy)

    def _weight_ratio(
        self,
        numerator: ScalarLike,
        denominator: ScalarLike,
    ) -> ScalarLike:
        """Return the ratio of the two weights, capped above at one."""
        ratio = numerator / denominator
        return min(ratio, 1)

    def _check_divergence(self, h: ScalarLike, aux_vars: dict[str, ScalarLike]):
        """Raise if `h` exceeds the initial Hamiltonian by more than `max_delta_h`.

        Raises:
            HamiltonianDivergenceError: If the increase in the Hamiltonian relative to
                `aux_vars["h_init"]` is greater than `self.max_delta_h`.
        """
        diverged = h - aux_vars["h_init"] > self.max_delta_h
        if diverged:
            msg = f"delta_h = {h - aux_vars['h_init']}"
            raise HamiltonianDivergenceError(msg)
class SliceDynamicIntegrationTransition(DynamicIntegrationTransition):
    """Dynamic integration transition using slice sampling of the new state.

    Each transition recursively builds a binary tree of states by integrating randomly
    forward and backward in time, by a number of steps equal to the previous tree
    size, until a termination criteria on the tree leaves is met. The next chain state
    is picked from the candidate states by a progressive slice sampling scheme based
    on the relative probability densities of the candidates, with the slice sampler
    biased towards states further from the current state.

    When used with the :py:func:`euclidean_no_u_turn_criterion` this transition is
    equivalent to the transitions in 'Algorithm 3: Efficient No-U-Turn Sampler' in
    Hoffman and Gelman (2014).

    References:

      1. Hoffman, M.D. and Gelman, A. (2014). The No-U-turn sampler: adaptively setting
         path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research,
         15(1), pp.1593-1623.
    """

    def _init_aux_vars(
        self,
        state: ChainState,
        rng: Generator,
    ) -> dict[str, ScalarLike]:
        """Extend the base auxiliary variables with the log slice variable `log_u`.

        Args:
            state: Current chain state at the start of the transition.
            rng: Numpy random number generator used to draw the slice variable.

        Returns:
            Dictionary of auxiliary variables with the extra entry `log_u`, the log of
            a uniform draw minus the initial Hamiltonian `h_init`.
        """
        slice_vars = super()._init_aux_vars(state, rng)
        log_uniform_draw = np.log(rng.uniform())
        slice_vars["log_u"] = log_uniform_draw - slice_vars["h_init"]
        return slice_vars

    def _weight_function(
        self,
        h: ScalarLike,
        aux_vars: dict[str, ScalarLike],
    ) -> ScalarLike:
        """Return 1 if `-h` is at least `log_u` (state is in the slice), otherwise 0."""
        inside_slice = aux_vars["log_u"] <= -h
        # Multiplying by one converts the boolean indicator to an integer weight
        return inside_slice * 1

    def _weight_ratio(
        self,
        numerator: ScalarLike,
        denominator: ScalarLike,
    ) -> ScalarLike:
        """Return the ratio of the two weights, capped above at one.

        If `denominator` is not positive the numerator itself (capped at one) is
        returned instead, avoiding a division by zero.
        """
        if denominator > 0:
            uncapped = numerator / denominator
        else:
            uncapped = numerator
        return min(uncapped, 1)

    def _check_divergence(self, h: ScalarLike, aux_vars: dict[str, ScalarLike]):
        """Raise if `h` plus the log slice variable exceeds `max_delta_h`.

        Raises:
            HamiltonianDivergenceError: If `h + aux_vars["log_u"]` is greater than
                `self.max_delta_h`.
        """
        diverged = h + aux_vars["log_u"] > self.max_delta_h
        if diverged:
            msg = f"delta_h = {h + aux_vars['log_u']}"
            raise HamiltonianDivergenceError(msg)