Source code for mici.autodiff
"""Automatic differentation fallback for constructing derivative functions."""
from __future__ import annotations
from typing import TYPE_CHECKING
from mici import autograd_wrapper
if TYPE_CHECKING:
from typing import Callable, Optional
"""List of names of valid differential operators.
Any automatic differentiation framework wrapper module will need to provide all of these
operators as callables (with a single function as argument) to fully support all of the
required derivative functions.
"""
# `autodiff_fallback` rejects any operator name that is not listed here, and resolves
# accepted names as attributes of the `autograd_wrapper` module.
DIFF_OPS = [
    # Vector-Jacobian product, plus the function value.
    "vjp_and_value",
    # Scalar-valued functions: gradient, plus the function value.
    "grad_and_value",
    # Scalar-valued functions: Hessian matrix and gradient, plus the function value.
    "hessian_grad_and_value",
    # Scalar-valued functions: matrix-Tressian product and gradient, plus the
    # function value.
    "mtp_hessian_grad_and_value",
    # Vector-valued functions: Jacobian matrix, plus the function value.
    "jacobian_and_value",
    # Vector-valued functions: matrix-Hessian product and Jacobian matrix, plus the
    # function value.
    "mhp_jacobian_and_value",
]
def autodiff_fallback(
    diff_func: Optional[Callable],
    func: Callable,
    diff_op_name: str,
    name: str,
) -> Callable:
    """Return a derivative function, deriving one automatically when none is given.

    An explicitly supplied `diff_func` always takes precedence and is returned
    unchanged. Otherwise the differential operator called `diff_op_name` is looked up
    on the `autograd_wrapper` module and applied to `func`.

    Args:
        diff_func: Callable implementing the required derivative function, or `None`
            if the caller did not supply one.
        func: Function to differentiate.
        diff_op_name: Name of the differential operator used to construct the
            derivative function. Must be one of the entries of `DIFF_OPS`.
        name: Name of the derivative function, used only in the error message raised
            when Autograd is not available.

    Returns:
        `diff_func` if it is not `None`, otherwise the result of applying the named
        differential operator to `func`.

    Raises:
        ValueError: If `diff_func` is `None` and either `diff_op_name` is not in
            `DIFF_OPS` or Autograd is not available.
    """
    # A user-provided implementation wins outright; the operator name is not
    # validated in that case.
    if diff_func is not None:
        return diff_func
    if diff_op_name not in DIFF_OPS:
        msg = f"Differential operator {diff_op_name} is not defined."
        raise ValueError(msg)
    if not autograd_wrapper.AUTOGRAD_AVAILABLE:
        msg = f"Autograd not available therefore {name} must be provided."
        raise ValueError(msg)
    differential_operator = getattr(autograd_wrapper, diff_op_name)
    return differential_operator(func)