mici.autodiff.jax_wrapper module
JAX differential operators and helper functions.
- mici.autodiff.jax_wrapper.grad_and_value(func)
Makes a function that returns both the gradient and value of a function.
- Parameters:
func (ScalarFunction)
- Return type:
GradientFunction
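The following is a minimal sketch of the documented behaviour written directly against JAX. It is not necessarily how mici implements it. It assumes the returned pair is ordered (gradient, value), as the function name suggests. `grad_and_value_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def grad_and_value_sketch(func):
    """Hypothetical stand-in for grad_and_value (not mici's code).

    Assumes the output order (gradient, value) implied by the name.
    """

    def grad_and_value_func(x):
        # jax.value_and_grad returns (value, gradient); swap to (gradient, value).
        value, grad = jax.value_and_grad(func)(x)
        return grad, value

    return grad_and_value_func


def func(x):
    # Scalar function f(x) = sum(x**2), whose gradient is 2 * x.
    return jnp.sum(x**2)


grad, value = grad_and_value_sketch(func)(jnp.array([1.0, 2.0, 3.0]))
# grad is [2., 4., 6.] and value is 14.
```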
- mici.autodiff.jax_wrapper.hessian_grad_and_value(func)
Makes a function that returns the Hessian, gradient and value of a function.
- Parameters:
func (ScalarFunction)
- Return type:
HessianFunction
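One way to obtain all three quantities in a single pass is to take the forward-mode Jacobian of the gradient and carry the gradient and value along as auxiliary output. The sketch below assumes the output order (Hessian, gradient, value) implied by the name. `hessian_grad_and_value_sketch` is hypothetical and not mici's implementation.

```python
import jax
import jax.numpy as jnp


def hessian_grad_and_value_sketch(func):
    """Hypothetical stand-in for hessian_grad_and_value (not mici's code).

    Assumes the output order (Hessian, gradient, value) implied by the name.
    """

    def grad_with_aux(x):
        value, grad = jax.value_and_grad(func)(x)
        # First output is differentiated again; (grad, value) ride along as aux.
        return grad, (grad, value)

    def hessian_grad_and_value_func(x):
        # Forward-mode Jacobian of the gradient gives the Hessian.
        hessian, (grad, value) = jax.jacfwd(grad_with_aux, has_aux=True)(x)
        return hessian, grad, value

    return hessian_grad_and_value_func


def func(x):
    # f(x) = sum(x**3): gradient 3 * x**2, Hessian diag(6 * x).
    return jnp.sum(x**3)


hessian, grad, value = hessian_grad_and_value_sketch(func)(jnp.array([1.0, 2.0]))
# hessian is [[6., 0.], [0., 12.]], grad is [3., 12.], value is 9.
```

Calling `jax.hessian` and `jax.value_and_grad` separately would also work. However, that approach recomputes the gradient instead of reusing it.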
- mici.autodiff.jax_wrapper.jacobian_and_value(func)
Makes a function that returns both the Jacobian and value of a function.
- Parameters:
func (ArrayFunction)
- Return type:
JacobianFunction
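The sketch below shows one way to get a Jacobian and value together in JAX, assuming the output order (Jacobian, value). The function value is returned as auxiliary data from `jax.jacfwd`, so `func` is not evaluated a second time. `jacobian_and_value_sketch` is a hypothetical name, not mici's code.

```python
import jax
import jax.numpy as jnp


def jacobian_and_value_sketch(func):
    """Hypothetical stand-in for jacobian_and_value (not mici's code).

    Assumes the output order (Jacobian, value) implied by the name.
    """

    def value_with_aux(x):
        value = func(x)
        # Differentiate the first output; return the value unchanged as aux.
        return value, value

    def jacobian_and_value_func(x):
        jacobian, value = jax.jacfwd(value_with_aux, has_aux=True)(x)
        return jacobian, value

    return jacobian_and_value_func


def func(x):
    # Vector-valued f(x) = [x0**2, x0 * x1].
    return jnp.stack([x[0] ** 2, x[0] * x[1]])


jacobian, value = jacobian_and_value_sketch(func)(jnp.array([3.0, 2.0]))
# jacobian[i, k] = df[i] / dx[k] is [[6., 0.], [2., 3.]]; value is [9., 6.]
```

Forward mode (`jax.jacfwd`) tends to be cheaper for "tall" Jacobians with more outputs than inputs. `jax.jacrev` also accepts `has_aux` and suits the opposite case.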
- mici.autodiff.jax_wrapper.jit_and_return_numpy_arrays(function, **jit_kwargs)
Wrap a function returning JAX arrays so that it is JIT compiled with jax.jit and returns NumPy arrays instead.
- Parameters:
function (Callable[[~P], Array | Callable[[~P], JaxArray | JaxArrayFunction | tuple[JaxArray | JaxArrayFunction, ...]] | tuple[Array | Callable[[~P], JaxArray | JaxArrayFunction | tuple[JaxArray | JaxArrayFunction, ...]], ...]]) – Function to wrap. Should return one of: a single JAX array; a callable returning one or more JAX arrays; or a tuple of one or more JAX arrays or callables returning one or more JAX arrays.
**jit_kwargs – Keyword arguments to pass to the jax.jit operator.
- Returns:
Wrapped function. Any values returned by the original function which are JAX arrays will instead be NumPy arrays, and any returned callables which return JAX arrays will instead return NumPy arrays.
- Return type:
Callable[[~P], ndarray | Callable[[~P], np.ndarray | NumPyArrayFunction | tuple[np.ndarray | NumPyArrayFunction, ...]] | tuple[ndarray | Callable[[~P], np.ndarray | NumPyArrayFunction | tuple[np.ndarray | NumPyArrayFunction, ...]], ...]]
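The sketch below composes `jax.jit` with a conversion step. It is a simplified illustration, not mici's implementation: the hypothetical `jit_and_return_numpy_arrays_sketch` handles only a single array or a tuple of arrays, not the callable outputs described above. Extra keyword arguments are forwarded to `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jit_and_return_numpy_arrays_sketch(function, **jit_kwargs):
    """Hypothetical simplification: JIT compile, then convert outputs to NumPy.

    Only a single array or a tuple of arrays is handled here.
    """
    jitted = jax.jit(function, **jit_kwargs)

    def wrapped(*args, **kwargs):
        outputs = jitted(*args, **kwargs)
        if isinstance(outputs, tuple):
            return tuple(np.asarray(out) for out in outputs)
        return np.asarray(outputs)

    return wrapped


def scaled_sum_and_square(x, scale):
    return scale * jnp.sum(x), x**2


# `scale` is marked static, so jax.jit recompiles for each distinct value of it.
f = jit_and_return_numpy_arrays_sketch(scaled_sum_and_square, static_argnames="scale")
total, squares = f(jnp.array([1.0, 2.0]), scale=3.0)
# total is a 0-d NumPy array equal to 9.; squares is the NumPy array [1., 4.]
```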
- mici.autodiff.jax_wrapper.mhp_jacobian_and_value(func)
Makes a function that returns the MHP, Jacobian and value of a function.
For a vector-valued function fun, the matrix-Hessian-product (MHP) is here defined as a function of a matrix m corresponding to
mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
where h is the vector-Hessian of f = fun(x) with respect to x, i.e. the rank-3 tensor of second-order partial derivatives of the vector-valued function, such that
h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k])
- Parameters:
func (ArrayFunction)
- Return type:
MatrixHessianProductFunction
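Since h[i, j, k] is the derivative of the Jacobian entry j[i, j] with respect to x[k], the MHP defined above equals a vector-Jacobian product of the Jacobian map with cotangent m. The sketch below uses this identity with `jax.vjp` and checks it against the explicit sum. It assumes the output order (mhp, Jacobian, value). `mhp_jacobian_and_value_sketch` is hypothetical, not mici's code.

```python
import jax
import jax.numpy as jnp


def mhp_jacobian_and_value_sketch(func):
    """Hypothetical stand-in for mhp_jacobian_and_value (not mici's code).

    Assumes the output order (mhp, Jacobian, value) implied by the name.
    """

    def value_with_aux(x):
        value = func(x)
        return value, value

    def jacobian_with_aux(x):
        # Returns (Jacobian of func at x, func(x)).
        return jax.jacfwd(value_with_aux, has_aux=True)(x)

    def mhp_jacobian_and_value_func(x):
        # Linearise the Jacobian map around x; the value is carried as aux.
        jacobian, jacobian_vjp, value = jax.vjp(jacobian_with_aux, x, has_aux=True)

        def mhp(m):
            # Cotangent m has the Jacobian's shape, so the VJP gives
            # the sum over i, j of m[i, j] * h[i, j, :].
            return jacobian_vjp(m)[0]

        return mhp, jacobian, value

    return mhp_jacobian_and_value_func


def func(x):
    # f(x) = [x0**2 * x1, x0 * x1**3]
    return jnp.stack([x[0] ** 2 * x[1], x[0] * x[1] ** 3])


x = jnp.array([1.0, 2.0])
m = jnp.array([[1.0, -1.0], [0.5, 2.0]])
mhp, jacobian, value = mhp_jacobian_and_value_sketch(func)(x)

# Compare with the definition using the explicit vector-Hessian h[i, j, k].
h = jax.jacfwd(jax.jacfwd(func))(x)
mhp_explicit = jnp.sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
# Both mhp(m) and mhp_explicit evaluate to [26., 32.] here.
```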
- mici.autodiff.jax_wrapper.mtp_hessian_grad_and_value(func)
Makes a function that returns the MTP, Hessian, gradient and value of a function.
For a scalar-valued function fun, the matrix-Tressian-product (MTP) is here defined as a function of a matrix m corresponding to
mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2))
where t is the ‘Tressian’ of f = fun(x) with respect to x, i.e. the 3D array of third-order partial derivatives of the scalar-valued function, such that
t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k])
- Parameters:
func (ScalarFunction)
- Return type:
MatrixTressianProductFunction
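The Tressian is the derivative of the Hessian, and it is symmetric in its three indices for smooth f. As a result, the MTP above can be computed as a vector-Jacobian product of the Hessian map with cotangent m. The sketch below assumes the output order (mtp, Hessian, gradient, value) implied by the name and checks the result against the explicit sum. `mtp_hessian_grad_and_value_sketch` is hypothetical, not mici's code.

```python
import jax
import jax.numpy as jnp


def mtp_hessian_grad_and_value_sketch(func):
    """Hypothetical stand-in for mtp_hessian_grad_and_value (not mici's code).

    Assumes the output order (mtp, Hessian, gradient, value) implied by the name.
    """

    def grad_with_aux(x):
        value, grad = jax.value_and_grad(func)(x)
        return grad, (grad, value)

    def hessian_with_aux(x):
        # Returns (Hessian, (gradient, value)) via forward-over-reverse.
        return jax.jacfwd(grad_with_aux, has_aux=True)(x)

    def mtp_hessian_grad_and_value_func(x):
        hessian, hessian_vjp, (grad, value) = jax.vjp(hessian_with_aux, x, has_aux=True)

        def mtp(m):
            # VJP of the Hessian map with cotangent m: sum over j, k of m[j, k] * t[:, j, k].
            return hessian_vjp(m)[0]

        return mtp, hessian, grad, value

    return mtp_hessian_grad_and_value_func


def func(x):
    # f(x) = x0**3 * x1 + x1**3
    return x[0] ** 3 * x[1] + x[1] ** 3


x = jnp.array([1.0, 2.0])
m = jnp.array([[1.0, -1.0], [0.5, 2.0]])
mtp, hessian, grad, value = mtp_hessian_grad_and_value_sketch(func)(x)

# Compare with the definition; t is symmetric, so jacfwd's axis order does not matter.
t = jax.jacfwd(jax.hessian(func))(x)
mtp_explicit = jnp.sum(m[:, :] * t[:, :, :], axis=(-1, -2))
# Both mtp(m) and mtp_explicit evaluate to [9., 18.] here.
```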
- mici.autodiff.jax_wrapper.return_numpy_arrays(function)
Wrap a function returning JAX arrays to instead return NumPy arrays.
- Parameters:
function (Callable[[~P], Array | Callable[[~P], JaxArray | JaxArrayFunction | tuple[JaxArray | JaxArrayFunction, ...]] | tuple[Array | Callable[[~P], JaxArray | JaxArrayFunction | tuple[JaxArray | JaxArrayFunction, ...]], ...]]) – Function to wrap. Should return one of: a single JAX array; a callable returning one or more JAX arrays; or a tuple of one or more JAX arrays or callables returning one or more JAX arrays.
- Returns:
Wrapped function. Any values returned by the original function which are JAX arrays will instead be NumPy arrays, and any returned callables which return JAX arrays will instead return NumPy arrays.
- Return type:
Callable[[~P], ndarray | Callable[[~P], np.ndarray | NumPyArrayFunction | tuple[np.ndarray | NumPyArrayFunction, ...]] | tuple[ndarray | Callable[[~P], np.ndarray | NumPyArrayFunction | tuple[np.ndarray | NumPyArrayFunction, ...]], ...]]
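The conversion described above can be sketched as a small recursive wrapper. This is an illustration under assumptions, not mici's code. The hypothetical `return_numpy_arrays_sketch` converts arrays with `np.asarray`, recurses into tuples, and re-wraps any returned callable so that its outputs are converted too.

```python
import jax.numpy as jnp
import numpy as np


def return_numpy_arrays_sketch(function):
    """Hypothetical helper: wrap `function` so JAX array outputs become NumPy arrays."""

    def convert(output):
        if isinstance(output, tuple):
            # Convert each element of a tuple output.
            return tuple(convert(item) for item in output)
        if callable(output):
            # Wrap returned callables so that their outputs are converted too.
            return return_numpy_arrays_sketch(output)
        return np.asarray(output)

    def wrapped(*args, **kwargs):
        return convert(function(*args, **kwargs))

    return wrapped


def scale_and_make_shift(x):
    # Returns a JAX array and a function that returns a JAX array.
    return 2.0 * x, lambda y: x + y


doubled, shift = return_numpy_arrays_sketch(scale_and_make_shift)(jnp.array([1.0, 2.0]))
shifted = shift(jnp.array([10.0, 20.0]))
# doubled is the NumPy array [2., 4.]; shifted is the NumPy array [11., 22.]
```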
- mici.autodiff.jax_wrapper.vjp_and_value(func)
Makes a function that returns the vector-Jacobian-product and value of a function.
For a vector-valued function fun, the vector-Jacobian-product (VJP) is here defined as a function of a vector v corresponding to
vjp(v) = v @ j
where j is the Jacobian of f = fun(x) with respect to x, i.e. the rank-2 tensor of first-order partial derivatives of the vector-valued function, such that
j[i, k] = ∂f[i] / ∂x[k]
- Parameters:
func (ArrayFunction)
- Return type:
VectorJacobianProductFunction
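The following is a minimal sketch of this in plain JAX, assuming the output order (vjp, value) implied by the name. `vjp_and_value_sketch` is hypothetical, not mici's code. It checks the VJP against `v @ j` computed from an explicit Jacobian.

```python
import jax
import jax.numpy as jnp


def vjp_and_value_sketch(func):
    """Hypothetical stand-in for vjp_and_value (not mici's code).

    Assumes the output order (vjp, value) implied by the name.
    """

    def vjp_and_value_func(x):
        value, vjp_tuple_func = jax.vjp(func, x)

        def vjp(v):
            # jax.vjp's function returns one cotangent per primal input; there is one.
            return vjp_tuple_func(v)[0]

        return vjp, value

    return vjp_and_value_func


def func(x):
    # f(x) = [x0**2, x0 * x1, x1**3]
    return jnp.stack([x[0] ** 2, x[0] * x[1], x[1] ** 3])


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, -1.0, 0.5])
vjp, value = vjp_and_value_sketch(func)(x)

# Compare with v @ j for the explicit Jacobian j[i, k] = df[i] / dx[k].
j = jax.jacfwd(func)(x)
# vjp(v) and v @ j both evaluate to [0., 5.] here.
```

A reverse-mode VJP never materialises j. Its cost is a small constant multiple of one evaluation of func, whatever the output dimension.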