# Source code for mici.utils

"""Utility functions and classes."""

from __future__ import annotations

from math import exp, expm1, inf, log, log1p, nan
from typing import TYPE_CHECKING

import numpy as np

try:
    import xxhash

    XXHASH_AVAILABLE = True
except ImportError:
    XXHASH_AVAILABLE = False

if TYPE_CHECKING:
    from typing import Optional

    from mici.types import ScalarLike


def hash_array(array: np.ndarray) -> int:
    """Compute hash of a NumPy array by hashing data as a byte sequence.

    The array dtype, shape and strides are combined with the data bytes so
    that arrays sharing identical underlying bytes but differing in type or
    layout hash differently.

    Args:
        array: NumPy array to compute hash of.

    Returns:
        Computed hash as an integer.
    """
    if XXHASH_AVAILABLE:
        # If fast Python wrapper of fast xxhash implementation is available use
        # in preference to built in hash function
        h = xxhash.xxh64()
        # Hash the array data viewed as a flat byte sequence. For arrays that
        # are already C-contiguous `ascontiguousarray` and `reshape(-1)` return
        # views, so no copy is made; non-contiguous arrays are copied, as
        # viewing them as bytes directly is not supported. Flattening first also
        # handles 0-d arrays, whose dtype itemsize cannot be changed by `view`.
        flat_data = np.ascontiguousarray(array).reshape(-1)
        h.update(flat_data.view(np.byte).data)
        # Also update hash by array dtype, shape and strides to avoid clashes
        # between different views of same array
        h.update(bytes(f"{array.dtype}{array.shape}{array.strides}", "utf-8"))
        return h.intdigest()
    else:
        # Evaluate built-in hash function on *copy* of data as a byte sequence,
        # combined with dtype, shape and strides so that arrays with identical
        # bytes but different type or layout do not collide
        return hash((array.tobytes(), str(array.dtype), array.shape, array.strides))
# Precomputed natural logarithm of two, ln(2) ~= 0.6931.
LOG_2: float = log(2)
def log1p_exp(val: float) -> float:
    """Numerically stable implementation of `log(1 + exp(val))`.

    For positive `val` the identity `log(1 + exp(val)) = val + log(1 + exp(-val))`
    is used, so `exp` is never evaluated at a large positive argument (which
    would overflow).

    Args:
        val: Point at which to evaluate the function.

    Returns:
        Value of `log(1 + exp(val))`.
    """
    return val + log1p(exp(-val)) if val > 0.0 else log1p(exp(val))
def log1m_exp(val: float) -> float:
    """Numerically stable implementation of `log(1 - exp(val))`.

    For `-log(2) < val < 0` the expression `log(-expm1(val))` is used, since
    there `1 - exp(val)` is small and computing it directly loses precision
    (or rounds to exactly zero). For `val <= -log(2)` the expression
    `log1p(-exp(val))` is used, since there `exp(val)` is small.

    Args:
        val: Argument, expected to be non-positive.

    Returns:
        Value of `log(1 - exp(val))`: `-inf` when `val == 0` and `nan` when
        `val > 0` (where `1 - exp(val)` is negative).
    """
    if val == 0.0:
        # 1 - exp(0) is exactly zero, so its logarithm is -inf.
        return -inf
    elif val > 0.0:
        # 1 - exp(val) < 0 so the logarithm is undefined.
        return nan
    elif val > -LOG_2:
        # The threshold must be -log(2): with +log(2) this branch is unreachable
        # for negative val, and for val near zero exp(val) rounds to 1 so that
        # log1p(-exp(val)) raises a math domain error.
        return log(-expm1(val))
    else:
        return log1p(-exp(val))
def log_sum_exp(val1: float, val2: float) -> float:
    """Numerically stable implementation of `log(exp(val1) + exp(val2))`.

    The larger argument is factored out, so the remaining term passed to
    `log1p_exp` has a non-positive argument.

    Args:
        val1: Logarithm of the first summand.
        val2: Logarithm of the second summand.

    Returns:
        Logarithm of `exp(val1) + exp(val2)`.
    """
    if val1 == val2 and abs(val1) == inf:
        # Both arguments are the same infinity. Factoring one out would compute
        # `inf - inf = nan`, so return the exact result (+inf or -inf) directly.
        return val1
    elif val1 > val2:
        return val1 + log1p_exp(val2 - val1)
    else:
        return val2 + log1p_exp(val1 - val2)
def log_diff_exp(val1: float, val2: float) -> float:
    """Numerically stable implementation of `log(exp(val1) - exp(val2))`.

    Args:
        val1: Logarithm of the minuend.
        val2: Logarithm of the subtrahend.

    Returns:
        Logarithm of `exp(val1) - exp(val2)`: `nan` if `val1 < val2` (negative
        difference), `-inf` if `val1 == val2` (zero difference, which includes
        both arguments being `-inf`), otherwise
        `val1 + log1m_exp(val2 - val1)`.
    """
    if val1 < val2:
        # Difference is negative so its logarithm is undefined.
        return nan
    if val1 == val2:
        # Equal arguments (including both -inf) give a zero difference.
        return -inf
    # Factor out exp(val1), leaving log(1 - exp(val2 - val1)) with a negative
    # argument.
    return val1 + log1m_exp(val2 - val1)
class LogRepFloat:
    """Numerically stable logarithmic representation of positive float values.

    Stores logarithm of value and overloads arithmetic operators to use more
    numerically stable implementations where possible.

    - Operations between two `LogRepFloat` instances return a new
      `LogRepFloat` for addition, multiplication, division and subtraction of
      a value no larger than `self`.
    - Operations with other types, or subtraction of a larger value, return
      plain values computed from `val`.
    - In-place addition (`+=`) updates this instance's `log_val` where the
      result is representable.

    Args:
        val: Non-negative value to represent. Exactly one of `val` and
            `log_val` must be specified.
        log_val: Logarithm of value to represent. Exactly one of `val` and
            `log_val` must be specified.

    Raises:
        ValueError: If neither or both of `val` and `log_val` are specified, or
            if `val` is negative.
    """

    def __init__(self, val: Optional[float] = None, log_val: Optional[float] = None):
        if log_val is None:
            if val is None:
                msg = "One of val or log_val must be specified."
                raise ValueError(msg)
            if val > 0:
                self.log_val = log(val)
            elif val == 0.0:
                # math.log(0) raises rather than returning -inf
                self.log_val = -inf
            else:
                msg = "val must be non-negative."
                raise ValueError(msg)
        else:
            if val is not None:
                msg = "Specify only one of val and log_val."
                raise ValueError(msg)
            self.log_val = log_val

    @property
    def val(self) -> float:
        """Represented value, `exp(log_val)`, or `inf` if that overflows."""
        try:
            return exp(self.log_val)
        except OverflowError:
            # math.exp raises rather than returning inf for large arguments
            return inf

    def __float__(self) -> float:
        return self.val

    def __add__(self, other: ScalarLike) -> ScalarLike:
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=log_sum_exp(self.log_val, other.log_val))
        else:
            return self.val + other

    def __radd__(self, other: ScalarLike) -> ScalarLike:
        return self.__add__(other)

    def __iadd__(self, other: ScalarLike):
        if other == 0:
            return self
        elif isinstance(other, LogRepFloat):
            self.log_val = log_sum_exp(self.log_val, other.log_val)
        elif other < 0:
            # log(other) is undefined and the sum may be negative, so it cannot
            # be stored in log space. Returning NotImplemented makes Python fall
            # back to __add__, which returns a plain float, instead of raising
            # a math domain error here.
            return NotImplemented
        else:
            self.log_val = log_sum_exp(self.log_val, log(other))
        return self

    def __sub__(self, other: ScalarLike) -> ScalarLike:
        if isinstance(other, LogRepFloat):
            if self.log_val >= other.log_val:
                return LogRepFloat(log_val=log_diff_exp(self.log_val, other.log_val))
            else:
                # Result is negative so cannot be represented as a LogRepFloat
                return self.val - other.val
        else:
            return self.val - other

    def __rsub__(self, other: ScalarLike) -> ScalarLike:
        # Consistent with __rtruediv__: operate directly on the plain value
        return other - self.val

    def __mul__(self, other: ScalarLike) -> ScalarLike:
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val + other.log_val)
        else:
            return self.val * other

    def __rmul__(self, other: ScalarLike) -> ScalarLike:
        return self.__mul__(other)

    def __truediv__(self, other: ScalarLike) -> ScalarLike:
        if isinstance(other, LogRepFloat):
            return LogRepFloat(log_val=self.log_val - other.log_val)
        else:
            return self.val / other

    def __rtruediv__(self, other: ScalarLike) -> ScalarLike:
        return other / self.val

    def __neg__(self) -> float:
        return -self.val

    def __eq__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val == other.log_val
        else:
            return self.val == other

    def __ne__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val != other.log_val
        else:
            return self.val != other

    def __lt__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val < other.log_val
        else:
            return self.val < other

    def __gt__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val > other.log_val
        else:
            return self.val > other

    def __le__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val <= other.log_val
        else:
            return self.val <= other

    def __ge__(self, other: ScalarLike) -> bool:
        if isinstance(other, LogRepFloat):
            return self.log_val >= other.log_val
        else:
            return self.val >= other

    def __str__(self) -> str:
        return str(self.val)

    def __repr__(self) -> str:
        return f"LogRepFloat(val={self.val})"

    def __array__(self, dtype=None, copy: Optional[bool] = None) -> np.ndarray:
        # NumPy passes a requested dtype (e.g. from `np.asarray(x, dtype=...)`)
        # and, from NumPy 2.0, a `copy` keyword. A new array is always built
        # from `val` here, so `copy` is accepted for compatibility but unused.
        return np.array(self.val, dtype=dtype)