Source code for mici.states

"""Objects for recording state of a Markov chain and caching computations."""

from __future__ import annotations

import copy
from collections import Counter
from functools import wraps
from typing import TYPE_CHECKING

from mici.errors import ReadOnlyStateError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from typing import Any, Callable, Optional

    from numpy.typing import ArrayLike

    from mici.systems import System


def _cache_key_func(system: System, method: Callable | str) -> tuple[str, int]:
    """Construct cache key for a given system and method pair.

    Args:
        system: System object the method is associated with.
        method: Either the method itself (any object exposing a `__name__`
            attribute) or the name of the method as a string.

    Returns:
        Tuple `(qualified_name, system_id)` where `qualified_name` has the form
        `"{type(system).__name__}.{method_name}"` and `system_id` is `id(system)`,
        so identically named methods of different system instances map to
        different keys.
    """
    method_name = method if isinstance(method, str) else method.__name__
    return (f"{type(system).__name__}.{method_name}", id(system))


def cache_in_state(
    *depends_on: str,
) -> Callable[
    [Callable[[System, ChainState], ArrayLike]],
    Callable[[System, ChainState], ArrayLike],
]:
    """Memoizing decorator for system methods.

    Intended for `mici.systems.System` methods computing a function of one or more
    chain state variable(s). The value returned by the wrapped method is stored in
    the cache of the `ChainState` object passed to it, so that later calls with the
    same state reuse the stored value rather than recomputing it, for as long as
    none of the state variables the value depends on have been changed.

    If the state's `_call_counts` attribute is not `None`, the counter entry for the
    method is incremented each time the wrapped method is actually executed (that
    is, each time no valid cached value is available).

    Args:
        *depends_on: One or more strings corresponding to the names of any state
            variables the value returned by the method depends on, e.g. `pos` or
            `mom`, such that the cache in the state object is correctly cleared
            when the value of any of these variables (attributes) of the state
            object changes.

    Returns:
        Decorator which wraps a method with signature `method(self, state)` in a
        memoizing version with the same signature.
    """

    def cache_in_state_decorator(method):
        @wraps(method)
        def wrapper(self, state):
            key = _cache_key_func(self, method)
            if key not in state._cache:
                # First time this method is seen for this cache: record which
                # state variables invalidate its cached value.
                for variable_name in depends_on:
                    state._dependencies[variable_name].add(key)
            elif state._cache[key] is not None:
                # Valid cached value available, no need to recompute.
                return state._cache[key]
            state._cache[key] = method(self, state)
            if state._call_counts is not None:
                state._call_counts[key] += 1
            return state._cache[key]

        return wrapper

    return cache_in_state_decorator
def cache_in_state_with_aux(
    depends_on: Iterable[str],
    auxiliary_outputs: Iterable[str],
) -> Callable[
    [Callable[[System, ChainState], ArrayLike]],
    Callable[[System, ChainState], ArrayLike],
]:
    """Memoizing decorator for system methods with possible auxiliary outputs.

    Used to decorate `System` methods which compute a function of one or more chain
    state variable(s), with the decorated method caching the value or values
    returned by the method being wrapped in the `ChainState` object to prevent the
    need for recomputation on future calls if the state variables the returned
    value(s) depend on have not been changed in between the calls.

    Compared to the `cache_in_state` decorator, this variant allows for methods
    which may optionally also return additional auxiliary outputs, such as
    intermediate results computed while computing the primary output, which
    correspond to the output of another system method decorated with the
    `cache_in_state` or `cache_in_state_with_aux` decorators. If such auxiliary
    outputs are returned they are also used to update the cache entry for the
    corresponding decorated method, potentially saving recomputation in subsequent
    calls to that method. A common instance of this pattern is in derivative values
    computed using automatic differentiation (AD), with the primal value being
    differentiated usually either calculated alongside the derivative (in
    forward-mode AD) or calculated first in a forward-pass before the derivatives
    are calculated in a reverse-pass (in reverse-mode AD). By caching the value of
    the primal computed as part of the derivative calculation, a subsequent call to
    a method corresponding to calculation of the primal itself will retrieve the
    cached value and not recompute the primal, providing the relevant state
    variables the primal (and derivative) depend on have not been changed in
    between.

    Additionally, if the state's `_call_counts` attribute is not `None`, the
    memoized method will update a counter for the method in the `_call_counts`
    attribute every time the method being decorated is called (i.e. when there
    isn't a valid cached value available).

    Args:
        depends_on: A string or iterable of strings, with each string corresponding
            to the name of a state variable the value(s) returned by the method
            depends on, e.g. 'pos' or 'mom', such that the cache in the state
            object is correctly cleared when the value of any of these variables
            (attributes) of the state object changes.
        auxiliary_outputs: A string or iterable of strings, with each string
            defining an auxiliary output the wrapped method may additionally return
            in addition to the primary output. If auxiliary outputs are returned,
            the returned value should be a tuple with first entry the 'primary'
            output corresponding to the value associated with the name of the
            method and the subsequent entries in the tuple corresponding to the
            auxiliary outputs in the order specified by the entries in the
            `auxiliary_outputs` argument. If the primary output is itself a tuple,
            it must be wrapped in another tuple even when no auxiliary outputs are
            being returned.

    Returns:
        Decorator which wraps a method with signature `method(self, state)` in a
        memoizing version with the same signature that returns the primary output.
    """
    # Normalize to tuples once, at decoration time. Both arguments are iterated on
    # every call of the wrapper, so a one-shot iterable (e.g. a generator) would
    # otherwise be exhausted after the first call.
    depends_on = (depends_on,) if isinstance(depends_on, str) else tuple(depends_on)
    auxiliary_outputs = (
        (auxiliary_outputs,)
        if isinstance(auxiliary_outputs, str)
        else tuple(auxiliary_outputs)
    )

    def cache_in_state_with_aux_decorator(method):
        @wraps(method)
        def wrapper(self, state):
            prim_key = _cache_key_func(self, method)
            keys = [prim_key, *(_cache_key_func(self, a) for a in auxiliary_outputs)]
            # Register dependencies for the primary and every auxiliary key not yet
            # known to the cache, so all of them are invalidated together.
            for key in keys:
                if key not in state._cache:
                    for dep in depends_on:
                        state._dependencies[dep].add(key)
            if prim_key not in state._cache or state._cache[prim_key] is None:
                vals = method(self, state)
                if isinstance(vals, tuple):
                    # Tuple return: first entry is the primary output, remaining
                    # entries are auxiliary outputs in the declared order.
                    for k, v in zip(keys, vals):
                        state._cache[k] = v
                else:
                    state._cache[prim_key] = vals
                if state._call_counts is not None:
                    state._call_counts[prim_key] += 1
            return state._cache[prim_key]

        return wrapper

    return cache_in_state_with_aux_decorator
class ChainState:
    """Markov chain state.

    As well as recording the chain state variable values, the state object is also
    used to cache derived quantities to avoid recalculation if these values are
    subsequently reused.

    Additionally, any memoized system methods (i.e. those decorated with
    `cache_in_state` or `cache_in_state_with_aux`) will update a counter for the
    method in the state `_call_counts` attribute every time the decorated method is
    called (i.e. when there isn't a valid cached value available).
    """

    def __init__(
        self,
        *,
        _call_counts: Optional[dict[str, int]] = None,
        _read_only: bool = False,
        _dependencies: Optional[dict[str, set[str]]] = None,
        _cache: Optional[dict[str, Any]] = None,
        **variables: ArrayLike,
    ):
        """Create a new `ChainState` instance.

        Any keyword arguments passed to the constructor (with names not starting
        with an underscore) will be used to set state variable attributes of state
        object for example

            state = ChainState(pos=pos_val, mom=mom_val, dir=dir_val)

        will return a `ChainState` instance `state` with variable attributes
        `state.pos`, `state.mom` and `state.dir` with initial values set to
        `pos_val`, `mom_val` and `dir_val` respectively. Keyword arguments with a
        leading underscore in the name are reserved for additional arguments to the
        constructor not corresponding to state variables. Additionally the name
        `copy` should not be used as attribute access to this name will be blocked
        by the `copy` method.

        Args:
            **variables: Keyword arguments corresponding to state variables. All
                names must not begin with an underscore and no name can be `copy`.
                See description above for details.
            _call_counts: Used to store counts of the number of calls of system
                methods decorated with `cache_in_state` or `cache_in_state_with_aux`
                when called on this state object and when no cached value for the
                method is available so that the wrapped method is called. If a
                `Counter` is passed it is used directly; any other value (a plain
                dictionary or `None`) is used to initialize a new `Counter`. The
                `_call_counts` attribute is shared between all copies of a state
                so will count any decorated method calls on copies of the state as
                well - e.g. all copies of a state in a sampled Markov chain,
                allowing the `_call_counts` attribute to be used to monitor the
                number of method calls while sampling a chain.
            _read_only: If `True` a `mici.errors.ReadOnlyStateError` exception will
                be raised when attempting to set any attributes of the state object
                after construction. Defaults to `False`.
            _dependencies: Intended for internal use only. If not `None` this
                should be a dictionary with string keys corresponding to the state
                variable names and values which are sets of strings indicating any
                dependencies of the relevant state variable in the cache.
            _cache: Intended for internal use only. If not `None` this should be a
                dictionary with keys corresponding to unique identifiers for
                methods decorated with the `cache_in_state` or
                `cache_in_state_with_aux` decorators and values corresponding to
                cached computed outputs of these methods or `None` for when a
                cached output is not available.
        """
        # Set attributes by directly writing to __dict__ to ensure set before
        # any call to __setattr__
        self.__dict__["_variables"] = variables
        if _dependencies is None:
            _dependencies = {name: set() for name in variables}
        self.__dict__["_dependencies"] = _dependencies
        if _cache is None:
            _cache = {}
        self.__dict__["_cache"] = _cache
        # Counter(None) gives an empty counter, so a single conversion covers both
        # the `None` and the plain dictionary cases.
        self.__dict__["_call_counts"] = (
            _call_counts if isinstance(_call_counts, Counter) else Counter(_call_counts)
        )
        self.__dict__["_read_only"] = _read_only

    def __getattr__(self, name: str) -> ArrayLike:
        """Return the value of state variable `name`.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If `name` is not a state variable.
        """
        # Read `_variables` straight from the instance dictionary: accessing
        # `self._variables` while it is not yet set (e.g. on an instance created
        # with `__new__` before `__init__` / `__setstate__` has run) would re-enter
        # this method and recurse without bound instead of raising AttributeError.
        variables = self.__dict__.get("_variables")
        if variables is not None and name in variables:
            return variables[name]
        msg = f"'{type(self).__name__}' object has no attribute '{name}'"
        raise AttributeError(msg)

    def __setattr__(self, name: str, value: ArrayLike):
        """Set state variable or attribute `name` to `value`.

        Setting a state variable invalidates (sets to `None`) every cache entry
        registered as depending on that variable.

        Raises:
            ReadOnlyStateError: If the state is read-only.
        """
        if self._read_only:
            msg = "ChainState instance is read-only."
            raise ReadOnlyStateError(msg)
        if name not in self._variables:
            # Not a state variable: fall back to ordinary attribute assignment.
            super().__setattr__(name, value)
            return
        self._variables[name] = value
        # clear any dependent cached values
        for dep in self._dependencies[name]:
            self._cache[dep] = None

    def __contains__(self, name: str) -> bool:
        """Return whether `name` is the name of a state variable."""
        return name in self._variables

    def copy(self, *, read_only: bool = False) -> ChainState:
        """Create a copy of the state object.

        Each variable value is copied with `copy.copy` and the cache dictionary is
        shallow-copied, while the dependencies dictionary and the call counts are
        shared with the original state object.

        Args:
            read_only: Whether the state copy should be read-only.

        Returns:
            A copy of the state object with variable attributes that are copies
            (made with `copy.copy`) of the original state object's variables.
        """
        return type(self)(
            _dependencies=self._dependencies,
            _cache=self._cache.copy(),
            _call_counts=self._call_counts,
            _read_only=read_only,
            **{name: copy.copy(val) for name, val in self._variables.items()},
        )

    def __str__(self) -> str:
        """Return the state variables formatted as `name=value` pairs."""
        return (
            "(\n "
            + ",\n ".join([f"{k}={v}" for k, v in self._variables.items()])
            + ")"
        )

    def __repr__(self) -> str:
        """Return the class name followed by the `__str__` representation."""
        return type(self).__name__ + str(self)

    def __getstate__(self) -> dict[str, Any]:
        """Return the picklable state of the object as a dictionary."""
        return {
            "variables": self._variables,
            "dependencies": self._dependencies,
            # Don't pickle callable cached 'variables' such as derivative
            # functions
            "cache": {k: v for k, v in self._cache.items() if not callable(v)},
            "call_counts": self._call_counts,
            "read_only": self._read_only,
        }

    def __setstate__(self, state: dict[str, Any]):
        """Restore the object from a dictionary produced by `__getstate__`."""
        # Write to __dict__ directly to bypass __setattr__, which requires the
        # attributes being restored here to already exist.
        self.__dict__["_variables"] = state["variables"]
        self.__dict__["_dependencies"] = state["dependencies"]
        self.__dict__["_cache"] = state["cache"]
        self.__dict__["_call_counts"] = state["call_counts"]
        self.__dict__["_read_only"] = state["read_only"]