Source code for mici.progressbars

"""Progress bar classes for tracking progress of chains."""

from __future__ import annotations

import abc
import html
import importlib
import sys
from timeit import default_timer as timer
from typing import TYPE_CHECKING

try:
    from IPython import get_ipython
    from IPython.display import display as ipython_display

    IPYTHON_AVAILABLE = True
except ImportError:
    IPYTHON_AVAILABLE = False

if TYPE_CHECKING:
    from collections.abc import Collection, Generator
    from queue import Queue
    from typing import Any, Optional, TextIO


# Flag recording whether the `google.colab` package can be located; it is used
# below as the indicator that an updateable (notebook) display is available.
# The parent `google` package is tested first and the `and` short-circuits, so
# `find_spec("google.colab")` is only attempted when `google` itself was found.
# NOTE(review): only `import importlib` is visible above; this relies on the
# `importlib.util` submodule already being loaded — confirm or import it
# explicitly.
ON_COLAB = (
    importlib.util.find_spec("google") is not None
    and importlib.util.find_spec("google.colab") is not None
)


def _in_zmq_interactive_shell() -> bool:
    """Report whether output is going to a frontend with updateable displays.

    Returns:
        `False` when IPython could not be imported, `True` when running on Colab,
        and otherwise `True` only if the active IPython shell object is an instance
        of a class named `ZMQInteractiveShell`.
    """
    if not IPYTHON_AVAILABLE:
        return False
    if ON_COLAB:
        return True
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # `get_ipython` is not defined in this namespace.
        return False
    # Any other shell (terminal IPython, plain interpreter, ...) does not qualify.
    return shell_name == "ZMQInteractiveShell"


def _create_display(obj, position: tuple[int, int]):
    """Build a display object whose content can be updated after creation.

    Args:
        obj: Object to show initially.
        position: Pair `(index, total)` giving the zero-indexed position of this
            display within a sequence of displays and the total number of displays.

    Returns:
        Object exposing an `update` method used to replace the displayed content:
        an IPython display handle when in a ZMQ interactive shell, otherwise a
        `FileDisplay` created for `position`.
    """
    if not _in_zmq_interactive_shell():
        # No rich frontend available: fall back to a text display.
        return FileDisplay(position)
    return ipython_display(obj, display_id=True)


def _format_time(total_seconds: float) -> str:
    """Format a time interval in seconds as a colon-delimited string [h:]m:s."""
    total_mins, seconds = divmod(int(total_seconds), 60)
    hours, mins = divmod(total_mins, 60)
    if hours != 0:
        return f"{hours:d}:{mins:02d}:{seconds:02d}"
    else:
        return f"{mins:02d}:{seconds:02d}"


def _update_stats_running_means(
    iter_: int,
    means: dict[str, float],
    new_vals: dict[str, float],
):
    """Update dictionary of running statistics means with latest values."""
    if iter_ == 1:
        means.update({key: float(val) for key, val in new_vals.items()})
    else:
        for key, val in new_vals.items():
            means[key] += (float(val) - means[key]) / iter_


class ProgressBar(abc.ABC):
    """Abstract interface that every progress bar implementation provides."""

    def __init__(
        self,
        sequence: Collection,
        description: Optional[str],
        position: tuple[int, int] = (0, 1),
    ):
        """
        Args:
            sequence: Sequence to iterate over. Must be iterable _and_ have a
                defined length such that `len(sequence)` is valid.
            description: Description of task to prefix progress bar with.
            position: Tuple specifying position of progress bar within a sequence
                with first entry corresponding to zero-indexed position and the
                second entry the total number of progress bars.
        """
        self._active = False
        self._position = position
        self._description = description
        self._sequence = sequence
        self._n_iter = len(sequence)

    @property
    def sequence(self) -> Collection:
        """Sequence iterated over."""
        return self._sequence

    @sequence.setter
    def sequence(self, value: Collection):
        # The sequence may only be swapped while the bar is not in use.
        if self._active:
            raise RuntimeError("Cannot set sequence of active progress bar.")
        self._sequence = value
        self._n_iter = len(value)

    @property
    def n_iter(self) -> int:
        """Length of the sequence iterated over."""
        return self._n_iter

    def __iter__(self) -> Generator[tuple[Any, dict[str, float]], None, None]:
        """Yield `(item, stats)` pairs, updating the bar after each is consumed.

        `stats` is a fresh dictionary the consumer may fill in before requesting
        the next item; it is then passed on to `update`.
        """
        for n_done, item in enumerate(self.sequence, start=1):
            stats: dict[str, float] = {}
            yield item, stats
            self.update(n_done, stats, refresh=True)

    def __len__(self) -> int:
        return self._n_iter

    @abc.abstractmethod
    def update(
        self,
        iter_count: int,
        iter_dict: Optional[dict[str, float]],
        *,
        refresh: bool = True,
    ):
        """Update progress bar state.

        Args:
            iter_count: New value for iteration counter.
            iter_dict: Dictionary of iteration statistics key-value pairs to use
                to update postfix stats.
            refresh: Whether to refresh display(s).
        """

    def __enter__(self):
        """Mark the progress bar as active and return it."""
        self._active = True
        return self

    def __exit__(self, *args) -> bool:
        """Mark the progress bar as inactive; exceptions are not suppressed."""
        self._active = False
        return False
class DummyProgressBar(ProgressBar):
    """Stand-in progress bar that accepts updates but displays nothing."""

    def update(
        self,
        iter_count: int,
        iter_dict: Optional[dict[str, float]],
        *,
        refresh: bool = True,
    ):
        """Accept a progress update and discard it."""
        return None
class SequenceProgressBar(ProgressBar):
    """Iterable object for tracking progress of an iterative task.

    Implements both string and HTML representations to allow richer display in
    interfaces which support HTML output, for example Jupyter notebooks or
    interactive terminals.
    """

    GLYPHS = " ▏▎▍▌▋▊▉█"
    """Characters used to create string representation of progress bar."""

    def __init__(
        self,
        sequence: Collection,
        description: Optional[str] = None,
        position: tuple[int, int] = (0, 1),
        displays: Optional[Collection] = None,
        n_col: int = 10,
        unit: str = "it",
        min_refresh_time: float = 0.25,
    ):
        """
        Args:
            sequence: Sequence to iterate over. Must be iterable **and** have a
                defined length such that `len(sequence)` is valid.
            description: Description of task to prefix progress bar with.
            position: Tuple specifying position of progress bar within a sequence
                with first entry corresponding to zero-indexed position and the
                second entry the total number of progress bars.
            displays: List of objects to use to display visual representation(s)
                of progress bar. Each object must have an `update` method which
                will be passed a single argument corresponding to the current
                progress bar. If `None` a display is created on entering the
                progress bar as a context manager.
            n_col: Number of columns (characters) to use in string representation
                of progress bar.
            unit: String describing unit of per-iteration tasks.
            min_refresh_time: Minimum time in seconds between each refresh of
                progress bar visual representation.
        """
        super().__init__(sequence, description, position)
        self._n_col = n_col
        self._unit = unit
        self._counter = 0
        self._start_time = None
        self._elapsed_time = 0
        self._stats_dict = {}
        self._displays = displays
        self._min_refresh_time = min_refresh_time
        # Also set in `reset`; defined here so the attribute always exists.
        self._last_refresh_time = -float("inf")

    @property
    def description(self) -> Optional[str]:
        """Description of task being tracked."""
        return self._description

    @property
    def counter(self) -> int:
        """Progress iteration count."""
        return self._counter

    @counter.setter
    def counter(self, value: int):
        # Clamp to the valid range [0, n_iter].
        self._counter = max(0, min(value, self.n_iter))

    @property
    def prop_complete(self) -> float:
        """Proportion complete (float value in [0, 1])."""
        if self.n_iter == 0:
            # Nothing to iterate over: treat as complete rather than divide by zero.
            return 1.0
        return self.counter / self.n_iter

    @property
    def perc_complete(self) -> str:
        """Percentage complete formatted as string."""
        return f"{int(self.prop_complete * 100):3d}%"

    @property
    def elapsed_time(self) -> str:
        """Elapsed time formatted as string."""
        return _format_time(self._elapsed_time)

    @property
    def iter_rate(self) -> str:
        """Mean iteration rate if ≥ 1 `it/s` or reciprocal `s/it` as string."""
        if self.counter == 0 or self._elapsed_time <= 0:
            # No iterations completed or no measurable time elapsed yet, so a
            # rate cannot be computed (and would divide by zero).
            return "?"
        mean_time = self._elapsed_time / self.counter
        return (
            f"{mean_time:.2f}s/{self._unit}"
            if mean_time > 1
            else f"{1 / mean_time:.2f}{self._unit}/s"
        )

    @property
    def est_remaining_time(self) -> str:
        """Estimated remaining time to completion formatted as string."""
        if self.counter == 0:
            return "?"
        return _format_time((1 / self.prop_complete - 1) * self._elapsed_time)

    @property
    def n_block_filled(self) -> int:
        """Number of filled blocks in progress bar."""
        return int(self._n_col * self.prop_complete)

    @property
    def n_block_empty(self) -> int:
        """Number of empty blocks in progress bar."""
        return self._n_col - self.n_block_filled

    @property
    def prop_partial_block(self) -> float:
        """Proportion filled in partial block in progress bar."""
        return self._n_col * self.prop_complete - self.n_block_filled

    @property
    def filled_blocks(self) -> str:
        """Filled blocks string."""
        return self.GLYPHS[-1] * self.n_block_filled

    @property
    def empty_blocks(self) -> str:
        """Empty blocks string."""
        if self.prop_partial_block == 0:
            return self.GLYPHS[0] * self.n_block_empty
        # One column is taken up by the partial block character.
        return self.GLYPHS[0] * (self.n_block_empty - 1)

    @property
    def partial_block(self) -> str:
        """Partial block character."""
        if self.prop_partial_block == 0:
            return ""
        return self.GLYPHS[int(len(self.GLYPHS) * self.prop_partial_block)]

    @property
    def progress_bar(self) -> str:
        """Progress bar string."""
        return f"|{self.filled_blocks}{self.partial_block}{self.empty_blocks}|"

    @property
    def bar_color(self) -> str:
        """CSS color property for HTML progress bar."""
        if self.counter == self.n_iter:
            return "var(--jp-success-color1, #4caf50)"
        if self._active:
            return "var(--jp-brand-color1, #2196f3)"
        # Inactive but incomplete.
        return "var(--jp-error-color1, #f44336)"

    @property
    def stats(self) -> str:
        """Comma-delimited string list of statistic key=value pairs."""
        return ", ".join(f"{k}={v:#.3g}" for k, v in self._stats_dict.items())

    @property
    def prefix(self) -> str:
        """Text to prefix progress bar with."""
        label = f"{self.description}: " if self.description else ""
        return f"{label}{self.perc_complete}"

    @property
    def postfix(self) -> str:
        """Text to postfix progress bar with."""
        stats = f", {self.stats}" if self._stats_dict else ""
        return (
            f"{self.counter}/{self.n_iter} "
            f"[{self.elapsed_time}<{self.est_remaining_time}, "
            f"{self.iter_rate}{stats}]"
        )

    def reset(self):
        """Reset progress bar state."""
        self._counter = 0
        self._start_time = timer()
        # Guarantees that the next refresh request is not throttled.
        self._last_refresh_time = -float("inf")
        self._stats_dict = {}

    def update(
        self,
        iter_count: int,
        iter_dict: Optional[dict[str, float]] = None,
        *,
        refresh: bool = True,
    ):
        """Update progress bar state.

        Args:
            iter_count: New value for iteration counter. A value of zero resets
                the progress bar.
            iter_dict: Dictionary of iteration statistics key-value pairs to use
                to update postfix stats.
            refresh: Whether to refresh display(s). When `True` the displays are
                refreshed on the final iteration or if more than the minimum
                refresh time has passed since the last refresh. When `False`
                the displays are never refreshed.
        """
        if iter_count == 0:
            self.reset()
        else:
            self.counter = iter_count
            if iter_dict is not None:
                _update_stats_running_means(iter_count, self._stats_dict, iter_dict)
            self._elapsed_time = timer() - self._start_time
        # Parenthesised so that `refresh=False` suppresses refreshes entirely;
        # `and` binds tighter than `or` so without the parentheses a refresh
        # would still happen whenever the throttle interval had elapsed.
        if refresh and (
            iter_count == self.n_iter
            or (timer() - self._last_refresh_time > self._min_refresh_time)
        ):
            self.refresh()
            self._last_refresh_time = timer()

    def refresh(self):
        """Refresh visual display(s) of progress bar."""
        for display in self._displays:
            display.update(self)

    def __str__(self) -> str:
        return f"{self.prefix}{self.progress_bar}{self.postfix}"

    def __repr__(self) -> str:
        return self.__str__()

    def _repr_html_(self) -> str:
        return f"""
        <div style="line-height: 28px; width: 100%; display: flex;
                    flex-flow: row wrap; align-items: center;
                    position: relative; margin: 0px;">
          <label style="margin-right: 8px; flex-shrink: 0;
                        font-size: var(--jp-code-font-size, 13px);
                        font-family: var(--jp-code-font-family, monospace);">
            {html.escape(self.prefix).replace(' ', '&nbsp;')}
          </label>
          <div role="progressbar" aria-valuenow="{self.prop_complete}"
               aria-valuemin="0" aria-valuemax="1"
               style="position: relative; flex-grow: 1; align-self: stretch;
                      margin-top: 4px; margin-bottom: 4px; height: initial;
                      background-color: #eee;">
            <div style="background-color: {self.bar_color}; position: absolute;
                        bottom: 0; left: 0; width: {self.perc_complete};
                        height: 100%;"></div>
          </div>
          <div style="margin-left: 8px; flex-shrink: 0;
                      font-family: var(--jp-code-font-family, monospace);
                      font-size: var(--jp-code-font-size, 13px);">
            {html.escape(self.postfix)}
          </div>
        </div>
        """

    def __enter__(self):
        super().__enter__()
        self.reset()
        if self._displays is None:
            self._displays = [_create_display(self, self._position)]
        return self

    def __exit__(self, *args) -> bool:
        ret_val = super().__exit__(*args)
        if self.counter != self.n_iter:
            # Exited before completion: redraw so the bar reflects its final,
            # now inactive, state.
            self.refresh()
        return ret_val
class LabelledSequenceProgressBar(ProgressBar):
    """Iterable object for tracking progress of a sequence of labelled tasks."""

    def __init__(
        self,
        labelled_sequence: dict[str, Any],
        description: Optional[str] = None,
        position: tuple[int, int] = (0, 1),
        displays: Optional[Collection] = None,
    ):
        """
        Args:
            labelled_sequence: Ordered dictionary with string keys corresponding
                to labels for stages represented by sequence and values the
                entries in the sequence being iterated over.
            description: Description of task to prefix progress bar with.
            position: Tuple specifying position of progress bar within a sequence
                with first entry corresponding to zero-indexed position and the
                second entry the total number of progress bars.
            displays: List of objects to use to display visual representation(s)
                of progress bar. Each object must have an `update` method which
                will be passed a single argument corresponding to the current
                progress bar. If `None` a display is created on entering the
                progress bar as a context manager.
        """
        # The base class stores the description and position.
        super().__init__(list(labelled_sequence.values()), description, position)
        self._labels = list(labelled_sequence.keys())
        self._counter = 0
        self._prev_time = None
        self._iter_times = [None] * self.n_iter
        self._stats_dict = {}
        self._displays = displays

    @property
    def counter(self) -> int:
        """Progress iteration count."""
        return self._counter

    @counter.setter
    def counter(self, value: int):
        # Clamp to the valid range [0, n_iter].
        self._counter = max(0, min(value, self.n_iter))

    @property
    def description(self) -> Optional[str]:
        """Description of task being tracked."""
        return self._description

    @property
    def stats(self) -> str:
        """Comma-delimited string list of statistic key=value pairs."""
        return ", ".join(f"{k}={v:#.3g}" for k, v in self._stats_dict.items())

    @property
    def prefix(self) -> str:
        """Text to prefix progress bar with."""
        return f"{self.description}: " if self.description else ""

    @property
    def postfix(self) -> str:
        """Text to postfix progress bar with."""
        return f" [{self.stats}]" if self._stats_dict else ""

    @property
    def completed_labels(self) -> list[str]:
        """Labels corresponding to completed iterations.

        Each label is suffixed with the time recorded for that iteration. If no
        time was recorded (the iteration was skipped over by an update) the bare
        label is used.
        """
        return [
            label
            if iter_time is None
            else f"{label} [{_format_time(iter_time)}]"
            for label, iter_time in zip(
                self._labels[: self._counter],
                self._iter_times[: self._counter],
            )
        ]

    @property
    def current_label(self) -> str:
        """Label corresponding to current iteration."""
        return self._labels[self._counter] if self.counter < self.n_iter else ""

    @property
    def unstarted_labels(self) -> list[str]:
        """Labels corresponding to unstarted iterations."""
        return self._labels[self._counter + 1 :]

    @property
    def progress_bar(self) -> str:
        """Progress bar string."""
        labels = self.completed_labels
        if self.counter < self.n_iter:
            labels.append(self.current_label)
        return " > ".join(labels)

    def reset(self):
        """Reset progress bar state."""
        self._counter = 0
        self._prev_time = timer()
        self._iter_times = [None] * self.n_iter
        self._stats_dict = {}

    def update(
        self,
        iter_count: int,
        iter_dict: Optional[dict[str, float]] = None,
        *,
        refresh: bool = True,
    ):
        """Update progress bar state.

        Args:
            iter_count: New value for iteration counter. A value of zero resets
                the progress bar.
            iter_dict: Dictionary of iteration statistics key-value pairs to use
                to update postfix stats.
            refresh: Whether to refresh display(s).
        """
        if iter_count == 0:
            self.reset()
        else:
            self.counter = iter_count
            if iter_dict is not None:
                # The iteration count (not the builtin `iter`) is the divisor
                # of the running mean update.
                _update_stats_running_means(iter_count, self._stats_dict, iter_dict)
            curr_time = timer()
            # Time taken by the iteration just completed.
            self._iter_times[iter_count - 1] = curr_time - self._prev_time
            self._prev_time = curr_time
        if refresh:
            self.refresh()

    def refresh(self):
        """Refresh visual display(s) of status bar."""
        for display in self._displays:
            display.update(self)

    def __str__(self) -> str:
        return f"{self.prefix}{self.progress_bar}{self.postfix}"

    def __repr__(self) -> str:
        return self.__str__()

    def _repr_html_(self) -> str:
        html_string = f"""
        <div style="line-height: 24px; width: 100%; display: flex;
                    flex-flow: row wrap; align-items: center;
                    position: relative; margin: 0px;">
          <label style="flex-shrink: 0;
                        font-size: var(--jp-code-font-size, 13px);
                        font-family: var(--jp-code-font-family, monospace);">
            {html.escape(self.prefix).replace(' ', '&nbsp;')}
          </label>
        """
        template_string = """
          <div style="position: relative; flex-grow: 1; align-self: stretch;
                      margin: 1px; padding: 0px; text-align: center;
                      height: initial; background-color: {background_color};
                      color: {foreground_color}; border-radius: 5px;
                      border: 1px solid {foreground_color}; font-size: 90%;">
            {label}
          </div>
        """
        # Labels are caller-supplied text so are escaped before being embedded
        # in markup, in the same way as the prefix and postfix.
        for label in self.completed_labels:
            html_string += template_string.format(
                label=html.escape(label),
                foreground_color="white",
                background_color="#4caf50",
            )
        if self.counter != self.n_iter:
            html_string += template_string.format(
                label=html.escape(self.current_label),
                foreground_color="white",
                background_color="#2196f3" if self._active else "#f44336",
            )
        for label in self.unstarted_labels:
            html_string += template_string.format(
                label=html.escape(label),
                foreground_color="#aaa",
                background_color="white",
            )
        if self.postfix != "":
            html_string += f"""
          <div style="margin-left: 8px; flex-shrink: 0;
                      font-family: var(--jp-code-font-family, monospace);
                      font-size: var(--jp-code-font-size, 13px);">
            {html.escape(self.postfix)}
          </div>
        """
        html_string += "</div>"
        return html_string

    def __enter__(self):
        super().__enter__()
        self.reset()
        if self._displays is None:
            self._displays = [_create_display(self, self._position)]
        # Draw the initial state immediately.
        self.refresh()
        return self

    def __exit__(self, *args) -> bool:
        ret_val = super().__exit__(*args)
        if self.counter != self.n_iter:
            # Exited before completion: redraw so the current stage is shown
            # in its inactive state.
            self.refresh()
        return ret_val
class FileDisplay:
    """Updatable one-line display written to a file supporting ANSI escapes."""

    CURSOR_UP = "\x1b[A"
    """ANSI escape sequence to move cursor up one line."""

    CURSOR_DOWN = "\x1b[B"
    """ANSI escape sequence to move cursor down one line."""

    def __init__(
        self,
        position: tuple[int, int] = (0, 1),
        file: Optional[TextIO] = None,
    ):
        r"""
        Args:
            position: Tuple specifying position of display line within a sequence
                of lines with first entry corresponding to zero-indexed line and
                the second entry the total number of lines.
            file: File object to write updates to. Must support ANSI escape
                sequences `\x1b[A` (cursor up) and `\x1b[B` (cursor down) for
                manipulating write position. Defaults to `sys.stdout` if `None`.
        """
        self._position = position
        self._file = sys.stdout if file is None else file
        self._last_string_length = 0
        line_index = self._position[0]
        n_lines = self._position[1]
        if line_index == 0:
            # The first display in a group writes one newline per display line.
            self._file.write("\n" * n_lines)
            self._file.flush()

    def _move_line(self, offset: int):
        """Move the cursor `offset` lines down, or up if `offset` is negative."""
        if offset >= 0:
            escape_codes = self.CURSOR_DOWN * offset
        else:
            escape_codes = self.CURSOR_UP * -offset
        self._file.write(escape_codes)
        self._file.flush()

    def update(self, obj):
        """Update display with string representation of object.

        Args:
            obj: Object to display.
        """
        line_index = self._position[0]
        n_lines = self._position[1]
        # Move up to this display's line.
        self._move_line(line_index - n_lines)
        text = str(obj)
        # Pad with spaces to the previous length so stale characters are
        # overwritten, then return to the start of the line.
        self._file.write(text.ljust(self._last_string_length) + "\r")
        self._last_string_length = len(text)
        # Move back down by the same number of lines.
        self._move_line(n_lines - line_index)
        self._file.flush()
class _ProxySequenceProgressBar: """Proxy progress bar that outputs progress updates to a queue. Intended for communicating progress updates from a child to parent process when distributing tasks across multiple processes. """ def __init__(self, sequence: Collection, job_id: int, iter_queue: Queue): """ Args: sequence: Sequence to iterate over. Must be iterable _and_ have a defined length such that `len(sequence)` is valid. job_id: Unique integer identifier for progress bar amongst other progress bars sharing same `iter_queue` object. iter_queue: Shared queue object that progress updates are pushed to. """ self._sequence = sequence self._n_iter = len(sequence) self._job_id = job_id self._iter_queue = iter_queue def __len__(self) -> int: return self._n_iter def __enter__(self): self._iter_queue.put((self._job_id, 0, None)) return self def __exit__(self, *args) -> bool: return False def __iter__(self) -> Generator[tuple[Any, dict[str, float]], None, None]: for i, val in enumerate(self._sequence): iter_dict = {} yield val, iter_dict self._iter_queue.put((self._job_id, i + 1, iter_dict))