mici.autodiff package

Automatic differentiation support for constructing derivative functions.

Multiple automatic differentiation backends are supported:

  • jax: High-performance array computing framework with support for running computations on accelerator devices and just-in-time (JIT) compilation. To differentiate a function using the jax backend, the function must be defined in terms of JAX primitives, for example using the functions in the jax.numpy API. By default the derivative functions produced will be JIT compiled; a jax_nojit variant is available if the function (or its derivative) is not compatible with JIT compilation.

  • autograd: Autograd can automatically differentiate native Python and NumPy code. To differentiate a function using the autograd backend, it should be defined in terms of functions from the autograd.numpy and autograd.scipy APIs. Because Autograd lacks JIT compilation and features such as automatic vectorisation, it is generally slower than JAX, so JAX will usually be the better choice.

  • symnum: SymNum is a Python package that acts as a bridge between NumPy and SymPy, providing a NumPy-like interface that can be used to symbolically define functions which take arrays as arguments and return arrays or scalars as values. To differentiate a function using the symnum backend, it should be defined in terms of functions from the symnum.numpy API and decorated with symnum.numpify with the argument shapes specified. SymNum is intended for generating the derivatives of ‘simple’ functions which compose a relatively small number of operations and act on small array inputs. In such cases, by reducing interpreter overheads, it can produce code which is cheaper to evaluate than the corresponding Autograd or JAX functions (including JIT-compiled ones). The generated functions can also be serialised with the built-in Python pickle library, which allows them to be used, for example, with libraries that use multiprocessing to parallelise work across multiple processes.

class mici.autodiff.AutodiffBackend(module, available, function_wrapper=None)

Bases: NamedTuple

Automatic differentiation backend framework.

Consists of a module defining differential operators, a boolean flag indicating whether the backend is available in the current environment and, optionally, a function wrapper which applies any post-processing required to functions.

Parameters:
  • module (ModuleType)

  • available (bool)

  • function_wrapper (Callable | None)
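Because AutodiffBackend is a NamedTuple, an instance is simply an immutable record of its three fields. The sketch below defines an equivalent record locally rather than importing mici, and builds a name-to-backend registry; the toy_ops module and the BACKENDS dictionary are hypothetical illustrations, not mici's actual backends or registry.

```python
import types
from typing import Callable, NamedTuple, Optional


class AutodiffBackend(NamedTuple):
    """Local mirror of the documented fields (not imported from mici)."""

    module: types.ModuleType
    available: bool
    function_wrapper: Optional[Callable] = None  # defaults to no wrapper


# Hypothetical stand-in for a module defining differential operators.
toy_ops = types.ModuleType("toy_ops")

# Hypothetical registry keyed by backend name; mici's own registry may differ.
BACKENDS = {
    "toy": AutodiffBackend(module=toy_ops, available=True),
    "toy_missing": AutodiffBackend(module=toy_ops, available=False),
}

backend = BACKENDS["toy"]
print(backend.available, backend.function_wrapper)  # True None
```

Using a NamedTuple keeps each backend description lightweight and read-only, while still allowing fields to be accessed by name.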

mici.autodiff.autodiff_fallback(diff_func, func, diff_op_name, name, backend)

Generate derivative function automatically if not provided.

Uses automatic differentiation to generate a function corresponding to a differential operator applied to a function if an alternative implementation of the derivative function has not been provided.

Parameters:
  • diff_func (Callable | None) – Either a callable implementing the required derivative function or None if none was provided.

  • func (Callable) – Function to differentiate.

  • diff_op_name (str) – Name of the differential operator, from the automatic differentiation framework wrapper, to use to generate the required derivative function.

  • name (str) – Name of the derivative function, used in error messages.

  • backend (str | None) – Name of the automatic differentiation framework backend to use. If None, diff_func must be provided.

Returns:

diff_func if it is not None; otherwise the derivative of func generated by applying the named differential operator from the automatic differentiation backend.

Return type:

Callable
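The fallback logic described above can be sketched in plain Python. This is not mici's implementation: the registry, the toy backend, the error type and the error message are all assumptions, and the sketch does not apply any backend function wrapper to the generated derivative. So that it runs without JAX or Autograd installed, the toy backend's grad operator uses forward-mode automatic differentiation with dual numbers, restricted to scalar functions of one scalar argument built from addition and multiplication.

```python
import types
from typing import Callable, NamedTuple, Optional


class Dual:
    """Dual number value + tangent * eps with eps**2 == 0 (forward-mode autodiff)."""

    def __init__(self, value: float, tangent: float = 0.0):
        self.value = value
        self.tangent = tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a' eps)(b + b' eps) = ab + (a b' + a' b) eps
        return Dual(
            self.value * other.value,
            self.value * other.tangent + self.tangent * other.value,
        )

    __rmul__ = __mul__


def grad(func: Callable) -> Callable:
    """Derivative of a scalar function of a single scalar argument."""
    return lambda x: func(Dual(x, 1.0)).tangent


# Hypothetical toy backend module exposing one differential operator, `grad`.
toy_module = types.ModuleType("toy_module")
toy_module.grad = grad


class AutodiffBackend(NamedTuple):
    module: types.ModuleType
    available: bool
    function_wrapper: Optional[Callable] = None


# Hypothetical registry of backends keyed by name.
BACKENDS = {"toy": AutodiffBackend(toy_module, True)}


def autodiff_fallback(diff_func, func, diff_op_name, name, backend):
    """Sketch of the documented behaviour; not mici's actual implementation."""
    if diff_func is not None:
        return diff_func  # an explicitly provided derivative takes precedence
    if backend is None or not BACKENDS[backend].available:
        # Assumed error type and message; mici's may differ.
        raise ValueError(f"No usable autodiff backend, so {name} must be provided.")
    diff_op = getattr(BACKENDS[backend].module, diff_op_name)
    return diff_op(func)


def f(x):
    return 3 * x * x + 2 * x


grad_f = autodiff_fallback(None, f, "grad", "grad_f", "toy")
print(grad_f(2.0))  # 14.0, since f'(x) = 6x + 2
```

Looking up the operator by name with getattr is what allows a single diff_op_name string to select the corresponding operator from whichever backend module is chosen.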

mici.autodiff.wrap_function(function, backend)

Apply the function wrapper for an automatic differentiation backend to a function.

Backends may define a function wrapper which applies any post-processing required to functions using the framework, for example ensuring the function returns NumPy arrays or just-in-time compiling the function.

Parameters:
  • function (Callable) – Function to wrap.

  • backend (str | None) – Name of the automatic differentiation framework backend to use. If None, function is returned unchanged.

Returns:

Wrapped function.

Return type:

Callable
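The dispatch described above can be sketched as follows. This is an illustration under assumptions, not mici's code: the BACKENDS registry, the backend names "wrapped" and "plain", and the to_float_output wrapper are hypothetical. The wrapper stands in for real post-processing such as converting outputs to NumPy arrays or JIT compiling, and simply coerces return values to built-in floats.

```python
import functools
import types
from typing import Callable, NamedTuple, Optional


class AutodiffBackend(NamedTuple):
    module: types.ModuleType
    available: bool
    function_wrapper: Optional[Callable] = None


def to_float_output(function: Callable) -> Callable:
    """Hypothetical post-processing wrapper: coerce outputs to built-in floats."""

    @functools.wraps(function)
    def wrapped(*args, **kwargs):
        return float(function(*args, **kwargs))

    return wrapped


# Hypothetical registry: one backend with a wrapper, one without.
BACKENDS = {
    "wrapped": AutodiffBackend(types.ModuleType("a"), True, to_float_output),
    "plain": AutodiffBackend(types.ModuleType("b"), True),
}


def wrap_function(function: Callable, backend: Optional[str]) -> Callable:
    """Sketch of the documented behaviour; not mici's actual implementation."""
    if backend is None:
        return function  # no backend: return the function unchanged
    wrapper = BACKENDS[backend].function_wrapper
    # Backends without a wrapper also leave the function unchanged.
    return function if wrapper is None else wrapper(function)


def square(x):
    return x * x


print(wrap_function(square, None) is square)  # True
print(wrap_function(square, "wrapped")(3))  # 9.0
```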

Submodules