Source code for mpl_template.template

import copy
import io
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import matplotlib.pyplot as plt
from matplotlib import axes, figure, gridspec

__all__ = ["insert_image", "Template"]


try:
    import requests
except ImportError:  # pragma: no cover
    if TYPE_CHECKING:
        import requests
    else:
        requests = None


try:
    from PIL import Image
    from PIL.ExifTags import TAGS
    from PIL.Image import Resampling, Transpose  # type:ignore
except ImportError:  # pragma: no cover
    if TYPE_CHECKING:
        from PIL import Image
        from PIL.ExifTags import TAGS
        from PIL.Image import Resampling, Transpose  # type:ignore
    else:
        TAGS, Image = None, None


def _calc_extents(size: int, scale: float) -> Tuple[float, float]:
    """
    Return the view limits that show a centered image at ``scale``.

    An image ``size`` pixels long normally spans ``[0, size]``. To make it
    appear ``scale`` times as large, the limits are padded by the same
    amount on both sides, e.g. for ``scale = 0.5``::

        before:  |________________________|
                 0                        size

        after:   |......____________......|
                 low    0           size  high

    Parameters
    ----------
    size : int
        Number of pixels in the x or y dimension.
    scale : float
        The relative scale of the desired output (must be non-zero).

    Returns
    -------
    (low, high) : tuple of float
        ``(-pad, size + pad)`` with ``pad = (1 - scale) * size / scale / 2``.
        ``pad`` is negative, i.e. the view narrows, when ``scale > 1``.
    """
    pad = (1 - scale) * size * (1 / scale) / 2.0
    low, high = -pad, size + pad
    return low, high


def _image_path_or_url(path: str) -> Union[str, io.BytesIO]:
    """
    Prepares an image path or url in a form ready for
    loading by PIL.Image.open()

    Parameters
    ----------
    path : string
        filepath or url to an image file. Only the file extensions
        '.png', '.jpg', and '.jpeg' (in any letter case) are supported.
        ``path`` is treated as a url only when it has an ``http://`` or
        ``https://`` scheme; any query string or fragment of a url is
        ignored when checking the extension.

    Returns
    -------
    filepath (absolute, resolved) for images from a file or
    io.BytesIO object holding the downloaded bytes for web images

    Raises
    ------
    ValueError
        If the file extension is not supported.
    ImportError
        If a url is given but the `requests` library is not installed.
    requests.HTTPError
        If the server answers the download with an error status.
    """
    valid_types = ["png", "jpg", "jpeg"]

    parsed = urlparse(path)
    # Decide by scheme, not by substring: local paths may contain "http".
    is_url = parsed.scheme.lower() in ("http", "https")
    # For urls only the path component carries the file name (no "?query").
    name_part = parsed.path if is_url else path
    ext = Path(name_part).suffix.lstrip(".").lower()

    if ext not in valid_types:
        raise ValueError("Supported image types include: {}".format(valid_types))

    if is_url:
        if requests is None:  # pragma: no cover
            raise ImportError(
                "the `requests` library is required to load images via url."
            )
        r = requests.get(path)
        # Fail loudly on 404s etc. rather than handing an HTML page to PIL.
        r.raise_for_status()
        return io.BytesIO(r.content)

    return str(Path(path).resolve())


def _apply_exif_rotation(im):
    """Apply the exif orientation tag of a PIL image object.

    Parameters
    ----------
    im : PIL.Image
        image object that may contain orientation data

    Returns
    -------
    PIL.Image
        ``im`` itself when it carries no exif data; otherwise a new image,
        transposed so that it displays upright (a plain copy when the
        Orientation tag is absent or already 1).

    Raises
    ------
    ImportError
        If the `pillow` library is not installed.
    ValueError
        If the Orientation tag holds a value other than 1-8.
    """
    if TAGS is None or Image is None:  # pragma: no cover
        raise ImportError("The `pillow` library is required to manipulate images.")

    try:
        # `_getexif()` returns None for images without exif data, in which
        # case `.items()` raises AttributeError.
        exif = {TAGS.get(tag): value for tag, value in im._getexif().items()}
    except (AttributeError, KeyError, IndexError):
        return im

    orientation = exif.get("Orientation")
    if orientation is None or orientation == 1:
        # exif present but nothing to undo; many files omit the tag entirely
        return im.copy()

    # The single transpose that undoes each exif orientation. This is the
    # same table PIL.ImageOps.exif_transpose uses; note that 5 and 7 need
    # TRANSPOSE / TRANSVERSE respectively.
    undo = {
        2: Transpose.FLIP_LEFT_RIGHT,  # mirrored horizontally
        3: Transpose.ROTATE_180,  # rotated 180°
        4: Transpose.FLIP_TOP_BOTTOM,  # mirrored vertically
        5: Transpose.TRANSPOSE,  # mirrored horizontally + rotated 270° CW
        6: Transpose.ROTATE_270,  # rotated 90° CW
        7: Transpose.TRANSVERSE,  # mirrored horizontally + rotated 90° CW
        8: Transpose.ROTATE_90,  # rotated 270° CW
    }
    try:
        method = undo[orientation]
    except (KeyError, TypeError):  # pragma: no cover
        raise ValueError("Invalid EXIF Orientation Value") from None
    return im.transpose(method)


def insert_image(
    ax: axes.Axes,
    image_path: str,
    scale: float = 1.0,
    dpi: float = 300.0,
    expand: bool = False,
    **kwargs,
) -> axes.Axes:
    """
    Centers an image within an axes object

    Parameters
    ----------
    ax : ``matplotlib.Axes``
        the axes into which the image should be inserted.
    image_path : str
        Path to an existing image file, or an http(s) url to one.
    scale : float, optional
        The relative scale of the desired output. Values must be
        positive floats. E.g. scale = 2 will double the image's size
        relative to the given matplotlib.Axes object. scale = 0.5 will
        scale the image to half of the given matplotlib.Axes object.
    dpi : float, optional (default=300)
        The Dots (pixel) Per Inch of the image.
    expand : bool, optional (default=False)
        If true, the image will expand to fill the axes in which it is
        embedded. Use expand = True if the boundary of the enclosing axes
        is the desired crop boundary. Use expand = False if the image
        should be scaled in-place with its original aspect ratio. This
        option only affects images that have been zoomed (scale > 1).
    kwargs :
        keyword arguments to pass to the figure.add_axes() constructor.
        ``xticks`` and ``yticks`` default to ``[]`` and ``zorder`` to 1.

    Returns
    -------
    ax : matplotlib.Axes
        The newly added axes that holds the image.

    Raises
    ------
    ImportError
        If the `pillow` library is not installed.
    ValueError
        If ``scale`` is not positive or the image type is unsupported.

    Notes
    -----
    Use scale parameter to zoom image relative to axes boundary.

    Examples
    --------
    Images are inserted centered, scaled, and filling the parent axes
    object.

    Zoomed out 2x with ``scale=0.5``. Note that the enclosing axes is
    square, and that the inserted axes is a landscape rectangle that
    matches the shape of the source file.

    .. plot::
        :context: reset
        :include-source: True

        >>> import matplotlib.pyplot as plt
        >>> from mpl_template import insert_image
        >>> file = "img/polar_bar_demo.png"
        >>> fig, ax = plt.subplots(figsize=(3, 3))
        >>> img_ax = insert_image(ax, file, scale=0.5)

    Zoomed in 2x with ``scale=2``. Note that the use of the ``expand``
    kwarg causes the zoomed image to fill the enclosing square axes object.

    .. plot::
        :context: reset
        :include-source: True

        >>> import matplotlib.pyplot as plt
        >>> from mpl_template import insert_image
        >>> file = "img/polar_bar_demo.png"
        >>> fig, ax = plt.subplots(figsize=(3, 3))
        >>> img_ax = insert_image(ax, file, scale=2, expand=True)
    """
    if TAGS is None or Image is None:  # pragma: no cover
        raise ImportError("The `pillow` library is required to manipulate images.")
    # `not scale > 0` also rejects NaN; validate before any axes is added.
    if not scale > 0:
        raise ValueError("`scale` must be a positive number, got {!r}".format(scale))

    # Tick-free axes at zorder 1 unless the caller says otherwise.
    kwargs.setdefault("xticks", [])
    kwargs.setdefault("yticks", [])
    kwargs.setdefault("zorder", 1)
    imgaxes = ax.figure.add_axes(ax.get_position(), **kwargs)

    # Size of the parent axes in inches, then in pixels at the requested dpi.
    bbox = ax.get_window_extent().transformed(
        ax.get_figure().dpi_scale_trans.inverted()  # type: ignore
    )
    width = bbox.width * dpi
    height = bbox.height * dpi

    img_fp_or_obj = _image_path_or_url(image_path)
    with Image.open(img_fp_or_obj) as image:
        image = _apply_exif_rotation(image)
        image = image.convert("RGBA")
        wpx, hpx = image.size
        aspect_im = wpx / hpx
        aspect_ax = width / height

        # Half-extents of the central region kept when zooming in.
        adjy = (hpx / scale) / 2
        adjx = (wpx / scale) / 2

        if scale > 1:
            # Zoom in by cropping the central 1/scale of the image. With
            # `expand`, the crop is stretched to the axes' aspect ratio so
            # the result fills the whole axes.
            if expand and (aspect_im < aspect_ax):
                adjx = (width / height) * adjy
            if expand and (aspect_im > aspect_ax):
                adjy = (height / width) * adjx

            image = image.crop(
                (
                    int(wpx / 2 - adjx),
                    int(hpx / 2 - adjy),
                    int(wpx / 2 + adjx),
                    int(hpx / 2 + adjy),
                )
            )
            if not expand:
                # keep the image's own aspect ratio inside the axes
                if width >= height:
                    width = int(wpx * (height / hpx))
                else:
                    height = int(hpx * (width / wpx))
            image = image.resize((int(width), int(height)), Resampling.BICUBIC)
            # The crop already applied the zoom; the default imshow limits
            # show the whole cropped image, so no view-limit padding here
            # (padding with `scale` again would zoom a second time).
        else:
            if width >= height:
                width = int(wpx * (height / hpx))
            else:
                height = int(hpx * (width / wpx))
            image = image.resize(
                (int(width * scale), int(height * scale)), Resampling.LANCZOS
            )
            # Zoom out by padding the view limits around the shrunken image.
            left, right = _calc_extents(image.size[0], scale)
            imgaxes.set_xlim(left, right)
            # larger limit first -> inverted y axis, matching image row order
            bottom, top = reversed(_calc_extents(image.size[1], scale))
            imgaxes.set_ylim(bottom, top)

        imgaxes.imshow(image, aspect="equal")

    return imgaxes
def _get_default_tb_spans(
    rows: Tuple[int, ...], cols: Tuple[int, ...]
) -> List[Dict[str, Any]]:
    """
    Build the five default title block spans from the row/column rulers.

    Each span is ``[row_start, row_stop, col_start, col_stop]`` in title
    block grid cells: a full-width top strip, a second-row band over the
    first two columns, two bottom cells under columns one and two, and a
    tall cell in the last column below the top strip.
    """
    row_a, row_b = rows[0], rows[0] + rows[1]
    col_a, col_b = cols[0], cols[0] + cols[1]
    row_end, col_end = sum(rows), sum(cols)
    layout = (
        (0, row_a, 0, col_end),
        (row_a, row_b, 0, col_b),
        (row_b, row_end, 0, col_a),
        (row_b, row_end, col_a, col_b),
        (row_a, row_end, col_b, col_end),
    )
    return [{"span": list(bounds)} for bounds in layout]


def _validate_margins(
    margins: Optional[Tuple[int, int, int, int]] = None, base=10.0
) -> Tuple[int, int, int, int]:
    """
    Return ``margins`` unchanged if it holds exactly four ints; when it is
    None, return four equal margins of ``int(0.4 * base)`` grid cells.

    Raises
    ------
    ValueError
        If ``margins`` is given but is not four integers.
    """
    if margins is not None:
        if len(margins) == 4 and all(isinstance(m, int) for m in margins):
            return margins
        raise ValueError("`margins` must contain four integers")
    default = int(0.4 * base)
    return (default, default, default, default)
class Template:
    """
    Class to construct a report figure template using matplotlib which
    includes a figure border, script path, and title block.

    Parameters
    ----------
    scriptname : str
        Path to the script or notebook that is creating the figure. It is
        printed as a "Source:" footnote below the frame.
    margins : tuple of int, optional (default=(4, 4, 4, 4))
        A length-4 tuple specifying the left, right, top, and bottom
        margins of the page, respectively, in gridspec cells (``base``
        cells per inch).
    titleblock_content : list of dicts, optional
        Title block elements where each element is itself a dictionary
        with a `span` key that determines which rows and columns each
        element will occupy in the titleblock. An element without a
        `span` uses the default span at the same position. E.g. ::

            tbk = [
                {
                    'name': 'Title',
                    # `text` must be a dict or a list of dicts
                    'text': [
                        {
                            's': 'Figure Title',
                            'weight': 'bold',
                        },
                        {
                            's': 'Figure Subtitle',
                            'weight': 'light',
                        },
                    ],
                    # `image` must refer to dict with `path` key (required),
                    # and optional keys `scale` and `expand` which default
                    # to 1 and False, respectively.
                    'image': {
                        'path': 'img/logo.png',
                        'scale': 1,
                    },
                    # `span` must be a list of four integers,
                    # [row_start, row_stop, col_start, col_stop], giving the
                    # gridspec rows and columns that the titleblock element
                    # will span in tenths of an inch. The following span
                    # will give a titleblock element that is 0.8 inches tall
                    # and 3.2 inches wide. It will be the top left element
                    # of the block because its height and width begin at zero.
                    'span': [0, 8, 0, 32],
                },
                {
                    # specify keys for next tbk element
                },
            ]

    titleblock_cols : tuple of int, optional (default=(16, 16, 8))
        The specification (in tenths of an inch) of the rulers for each
        column in the title block.
    titleblock_rows : tuple of int, optional (default=(8, 5, 3))
        The specification (in tenths of an inch) of the rulers for each
        row in the title block.
    draft : bool, optional (default=True)
        Toggles the inclusion of a draft watermark.
    base : int, optional (default=10)
        Number of gridspec rows and columns per inch.
    dpi : int, optional (default=300)
        Resolution of the final figure in dots per inch.
    **figkwargs
        Additional keyword arguments passed to ``plt.figure``

    Raises
    ------
    ValueError
        If `scriptname` is missing or `margins` is not four integers.

    Examples
    --------
    To produce an empty figure containing a border object, and 5 title
    block objects:

    .. plot::
        :context: reset
        :include-source: True

        >>> from mpl_template import Template
        >>> report_fig = Template(figsize=(8.5, 11), scriptname="path/to/script.py")
        >>> fig = report_fig.blank()
    """

    def __init__(
        self,
        margins: Optional[Tuple[int, int, int, int]] = None,
        titleblock_content=None,
        titleblock_cols=None,
        titleblock_rows=None,
        scriptname=None,
        draft=True,
        base=None,
        dpi: float = 300,
        **figkwargs,
    ):
        if scriptname is None:  # pragma: no cover
            raise ValueError("Must enter name of calling script for footnote")
        self.script_name = scriptname

        self._base = 10.0 if base is None else float(base)
        self._margins = _validate_margins(margins, self._base)
        self.left, self.right, self.top, self.bottom = self.margins

        # default rulers: 1.6", 1.6", 0.8" wide columns; 0.8", 0.5", 0.3" rows
        if titleblock_cols is None:  # pragma: no branch
            titleblock_cols = (
                int(1.6 * self._base),
                int(1.6 * self._base),
                int(0.8 * self._base),
            )
        if titleblock_rows is None:  # pragma: no branch
            titleblock_rows = (
                int(0.8 * self._base),
                int(0.5 * self._base),
                int(0.3 * self._base),
            )
        self.default_spans = _get_default_tb_spans(titleblock_rows, titleblock_cols)
        self.t_w = sum(titleblock_cols)
        self.t_h = sum(titleblock_rows)
        self.is_draft = draft

        # Figure components are built lazily by the matching properties.
        self._fig = None
        self._gsfig = None
        self._watermark = None
        self._path_text = None
        self._gstitleblock = None
        self._gstitleblock_subspec = None

        self._fig_options = figkwargs
        self._fig_options["dpi"] = dpi
        self._titleblock_content = titleblock_content

    @property
    def base(self):
        """Number of gridspec cells per inch (float)."""
        return self._base

    @property
    def margins(self):
        """The (left, right, top, bottom) margins in gridspec cells."""
        return self._margins

    @margins.setter
    def margins(self, value):  # pragma: no cover
        # Use this template's base so `None` gives the same defaults as
        # __init__ (the helper otherwise assumes base=10).
        self._margins = _validate_margins(value, self._base)
        self.left, self.right, self.top, self.bottom = self._margins
        # The cached title block grids were sliced with the old margins.
        self._gstitleblock = None
        self._gstitleblock_subspec = None

    @property
    def titleblock_content(self):
        """Title block element specs; the default spans when none were given."""
        if self._titleblock_content is None:
            self._titleblock_content = self.default_spans
        return self._titleblock_content

    @titleblock_content.setter  # pragma: no cover
    def titleblock_content(self, value):
        self._titleblock_content = value

    @property
    def fig(self) -> figure.Figure:
        """The figure, created on first access (with the draft watermark
        when ``draft=True``)."""
        if self._fig is None:
            self._fig = plt.figure(**self._fig_options)
            if self.is_draft:  # pragma: no branch
                self.add_watermark()
        return self._fig

    @fig.setter
    def fig(self, value):  # pragma: no cover
        self._fig = value

    @property
    def gsfig(self):
        """A gap-free GridSpec over the whole figure with `base` cells per inch."""
        if self._gsfig is None:  # pragma: no branch
            row = int(self.fig.get_figheight() * self.base)
            col = int(self.fig.get_figwidth() * self.base)
            self._gsfig = gridspec.GridSpec(
                row, col, left=0, right=1, bottom=0, top=1, wspace=0, hspace=0
            )
        return self._gsfig

    @gsfig.setter
    def gsfig(self, value):  # pragma: no cover
        self._gsfig = value

    @property
    def watermark(self):
        """The watermark text artist, added on first access if absent."""
        if self._watermark is None:
            # Creating a draft template's figure already adds the watermark;
            # do that first so it is not added a second time below.
            _ = self.fig
            if self._watermark is None:
                self._watermark = self.add_watermark()
        return self._watermark

    @watermark.setter
    def watermark(self, value):
        self._watermark = value

    @property
    def path_text(self):
        """The script path shown in the footnote, relative to the cwd."""
        if self._path_text is None:
            self._path_text = str(Path.cwd() / self.script_name)
        return self._path_text

    @path_text.setter
    def path_text(self, value):
        self._path_text = value

    @property
    def gstitleblock(self):
        """The t_h x t_w cell region in the bottom-right corner of `gsfig`,
        inset by the right and bottom margins."""
        if self._gstitleblock is None:
            # `-0 or None` -> None, so a zero margin slices to the figure edge
            self._gstitleblock = self.gsfig[
                -(self.bottom + self.t_h) or None : -self.bottom or None,
                -(self.right + self.t_w) or None : -self.right or None,
            ]
        return self._gstitleblock

    @gstitleblock.setter
    def gstitleblock(self, value):
        self._gstitleblock = value
        self._gstitleblock_subspec = None

    @property
    def gstitleblock_subspec(self):
        """A t_h x t_w gap-free sub-grid of `gstitleblock` indexed by `span`."""
        if self._gstitleblock_subspec is None:
            self._gstitleblock_subspec = gridspec.GridSpecFromSubplotSpec(
                self.t_h,
                self.t_w,
                self.gstitleblock,
                wspace=0.0,
                hspace=0.0,
            )
        return self._gstitleblock_subspec

    @gstitleblock_subspec.setter
    def gstitleblock_subspec(self, value):  # pragma: no cover
        self._gstitleblock_subspec = value

    def add_frame(self):
        """Add a transparent, tick-free axes outlining the page margins."""
        fheight = self.base * self.fig.get_figheight()
        fwidth = self.base * self.fig.get_figwidth()
        # margins converted from grid cells to figure fractions
        _left = self.left / fwidth
        _right = self.right / fwidth
        _bottom = self.bottom / fheight
        _top = self.top / fheight
        rect = [_left, _bottom, 1 - (_left + _right), 1 - (_bottom + _top)]
        frame = self.fig.add_axes(
            rect, zorder=100, facecolor="none", xticks=[], yticks=[], label="frame"
        )
        return frame

    def add_watermark(self, text=None):
        """Add `text` (default "DRAFT") in red, 0.5 inch from the left edge
        and vertically centered on the top margin line; return the artist."""
        if text is None:  # pragma: no branch
            text = "DRAFT"
        x = (0.5 * self.base) / (self.base * self.fig.get_figwidth())
        y = 1 - self.top / (self.base * self.fig.get_figheight())
        watermark = self.fig.text(
            x,
            y,
            text,
            fontsize=24,
            color="r",
            fontname="Arial",
            fontweight="bold",
            zorder=1000,
            horizontalalignment="left",
            verticalalignment="center",
        )
        self.watermark = watermark
        return self.watermark

    def add_titleblock(self):
        """Add one transparent, tick-free axes per title block element and
        return them in element order.

        Each axes is labeled with the element's ``name`` (or ``"b_<i>"``)
        and placed on the element's ``span`` (or the i-th default span).
        """
        axlist = []
        for i, dct in enumerate(self.titleblock_content):
            if "span" in dct:
                r0, r, c0, c = dct["span"]
            else:
                r0, r, c0, c = self.default_spans[i]["span"]
            label = dct["name"] if "name" in dct else "b_{}".format(i)
            ax = self.fig.add_subplot(
                self.gstitleblock_subspec[r0:r, c0:c],
                label=label,
                zorder=100,
                facecolor="none",
                xticks=[],
                yticks=[],
                aspect="equal",
                adjustable="datalim",
            )
            axlist.append(ax)
        return axlist

    def add_page(self):  # pragma: no cover
        """Add a transparent, tick-free axes covering the whole figure."""
        ax = self.fig.add_axes(
            [0, 0, 1, 1],
            zorder=1000,
            facecolor="none",
            xticks=[],
            yticks=[],
            label="page",
        )
        return ax

    def add_path_text(self):
        """Add the "Source: <path>" footnote at the left margin, 0.15 inch
        below the bottom margin line; return the text artist."""
        x = self.left / (self.base * self.fig.get_figwidth())
        y = abs(
            (self.bottom - 0.15 * self.base) / (self.base * self.fig.get_figheight())
        )
        text = "Source: " + self.path_text
        textobj = self.fig.text(
            x,
            y,
            text,
            fontsize=5,
            horizontalalignment="left",
            verticalalignment="center",
        )
        return textobj

    def populate_titleblock(self):
        """Draw each element's `text` and `image` into its title block axes.

        Axes are matched to elements by label (the element's ``name``, or
        ``"b_<i>"`` when it has none).

        Raises
        ------
        ValueError
            If an element's ``text`` is neither a dict nor a list of dicts.
        """
        for ax in self.fig.get_axes():
            label = ax.get_label()
            for i, dct in enumerate(self.titleblock_content):
                if dct.get("name", "b_{}".format(i)) != label:
                    continue

                content = dct.get("text")
                if content is not None:
                    if isinstance(content, dict):
                        content = [content]
                    if not isinstance(content, list):  # pragma: no cover
                        raise ValueError(
                            "`text` key must map to dict or list of dicts"
                        )
                    for elem in content:
                        # copy so the caller's spec is never mutated
                        kwargs = copy.deepcopy(elem)
                        # positions default to axes-fraction coordinates
                        if "transform" not in kwargs:  # pragma: no branch
                            kwargs["transform"] = ax.transAxes
                        ax.text(**kwargs)

                image = dct.get("image")
                if image is not None:
                    img_ax = insert_image(
                        ax,
                        image["path"],
                        scale=image.get("scale", 1),
                        dpi=ax.get_figure().get_dpi(),
                        expand=image.get("expand", False),
                    )
                    img_ax.set_label("img_b_{}".format(i))
                    img_ax.axis("off")

    def setup_figure(self) -> figure.Figure:
        """Draw the frame, title block, and source footnote, fill in the
        title block content, and return the figure."""
        self.add_frame()
        self.add_titleblock()
        self.add_path_text()
        self.populate_titleblock()
        return self.fig

    def blank(self) -> figure.Figure:
        """Draw the frame and title block with each title block axes showing
        its own label, remove the watermark, and return the figure."""
        self.add_frame()
        for ax in self.add_titleblock():
            ax.text(
                0.5,
                0.5,
                '"{}"'.format(ax.get_label()),
                va="center",
                ha="center",
                size=12,
            )
        self.watermark.remove()
        return self.fig