__all__ = ["insert_image", "Template"]
import os
import io
import copy
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
def _import_requests():
import requests
return requests
def _import_PIL_TAGS():
from PIL.ExifTags import TAGS
return TAGS
def _import_PIL_Image():
from PIL import Image
return Image
def _calc_extents(size, scale):
"""
Calculates the view limits needed to see a
centered image at the desired scale.
E.g.
Before: image spans whole width
scale = 0.5
|________________________|
0 size
After: upper and lower boundaries are
identified to make the original size
appear scaled and centered.
|......____________......|
lower 0 size upper
Parameters
----------
size : int
Number of pixels in x or y dimension
scale : float
The relative scale of the desired output.
"""
lower = (1 - scale) * size * (1 / scale) / 2.
upper = size + lower
return -lower, upper
def _image_path_or_url(path):
"""
Prepares image path or url for loading into format ready for
loading by PIL.Image.open()
Parameters
----------
path : string
filepath or url to an image file. For images from the web, only
the file extensions '.png', '.jpg', and '.jpeg' are supported.
Returnns
--------
filepath for images from a file or
io.BytesIO object for web images
"""
if ('http' in path):
requests = _import_requests()
valid_types = ['.png', '.jpg', '.jpeg']
if any(ftype in path for ftype in valid_types):
r = requests.get(path)
img_file_obj = io.BytesIO(r.content)
else:
raise ValueError(
"Supported web image types include: {}".format(valid_types))
return img_file_obj
else:
return os.path.realpath(path)
def _apply_exif_rotation(im):
    """Apply exif rotation tag to a PIL image object

    Parameters
    ----------
    im : PIL.Image
        image object that may contain rotation data

    Returns
    -------
    PIL.Image
        A re-oriented image when an ``Orientation`` tag is present, a copy
        of `im` when the tag is absent or equal to 1, or `im` itself when
        the exif data cannot be read.

    Raises
    ------
    ValueError
        If the ``Orientation`` tag holds a value outside 1-8.
    """
    TAGS = _import_PIL_TAGS()
    Image = _import_PIL_Image()

    try:
        # `_getexif` is missing on some image types and returns None when
        # the file has no exif block; both cases raise AttributeError and
        # are handled below as "nothing to rotate".
        exif = {TAGS.get(tag): value for tag, value in im._getexif().items()}

        if 'Orientation' not in exif:
            return im.copy()
        orientation = exif['Orientation']

        # Transposes to apply, in order, for each exif orientation value.
        # this section adapted from the following SO post:
        # https://stackoverflow.com/a/1608846/7486933
        steps_by_orientation = {
            1: (),                                          # Nothing
            2: (Image.FLIP_LEFT_RIGHT,),                    # Mirror left-right
            3: (Image.ROTATE_180,),                         # Rotation 180°
            4: (Image.FLIP_TOP_BOTTOM,),                    # Mirror top-bottom
            # 5 is a flip across the main diagonal (transpose) and 7 a flip
            # across the anti-diagonal (transverse); the rotation directions
            # below are what produce exactly those two results.
            5: (Image.FLIP_TOP_BOTTOM, Image.ROTATE_270),   # Transpose
            6: (Image.ROTATE_270,),                         # Rotation 270°
            7: (Image.FLIP_TOP_BOTTOM, Image.ROTATE_90),    # Transverse
            8: (Image.ROTATE_90,),                          # Rotation 90°
        }
        # Explicit membership test: a KeyError from a plain lookup would be
        # swallowed by the handler below.
        if orientation not in steps_by_orientation:
            raise ValueError('Invalid EXIF Orientation Value')

        steps = steps_by_orientation[orientation]
        if not steps:
            return im.copy()

        rotated = im
        for step in steps:
            rotated = rotated.transpose(step)
        return rotated

    except (AttributeError, KeyError, IndexError):
        # best effort: no usable exif data, so use the image as-is
        return im
def insert_image(ax, image_path, scale=1, dpi=300, expand=False, **kwargs):
    """
    Centers an image within an axes object

    A new axes is added to the figure at the position of `ax`, and the
    image is drawn into that new axes.

    Parameters
    ----------
    ax : ``matplotlib.Axes``
        the axes into which the image should be inserted.
    image_path : str
        Path to an existing image file.
    scale : float, optional
        The relative scale of the desired output.
        Values should be positive floats.
        E.g. scale = 2 will double the image's size
        relative to the given matplotlib.Axes object.
        scale = 0.5 will scale the image to half of the
        given matplotlib.Axes object.
    dpi : int, optional (default=300)
        The Dots (pixel) Per Inch of the image.
    expand : bool, optional (default=False)
        If true, the image will expand to fill the axes
        in which it is embedded. Use expand = True if the
        boundary of the enclosing axes is the desired crop boundary.
        Use expand = False if the image should be scaled in-place
        with its original aspect ratio. This option only affects
        images that have been zoomed (scale > 1).
    kwargs : keyword arguments to pass to the figure.add_axes()
        constructor. ``xticks`` and ``yticks`` default to empty lists and
        ``zorder`` defaults to 1 when not given.

    Returns
    -------
    ax : matplotlib.Axes
        The newly added axes that holds the image.

    Notes
    -----
    Use scale parameter to zoom image relative to axes boundary.

    Examples
    --------
    Images are inserted centered, scaled, and filling the parent
    axes object.

    Zoomed out 2x with ``scale=0.5``. Note that the enclosing axes is
    square, and that the inserted axes is a landscape rectangle that
    matches the shape of the source file.

    .. plot::
        :context: reset
        :include-source: True

        >>> import matplotlib.pyplot as plt
        >>> from mpl_template import insert_image
        >>> file = "img/polar_bar_demo.png"
        >>> fig, ax = plt.subplots(figsize=(3, 3))
        >>> img_ax = insert_image(ax, file, scale=0.5)

    Zoomed in 2x with ``scale=2``. Note that the use of the ``expand``
    kwarg causes the zoomed image to fill the enclosing square axes object.

    .. plot::
        :context: reset
        :include-source: True

        >>> import matplotlib.pyplot as plt
        >>> from mpl_template import insert_image
        >>> file = "img/polar_bar_demo.png"
        >>> fig, ax = plt.subplots(figsize=(3, 3))
        >>> img_ax = insert_image(ax, file, scale=2, expand=True)

    """
    Image = _import_PIL_Image()

    # defaults for the new axes; caller-supplied values win
    for key, default in (('xticks', []), ('yticks', []), ('zorder', 1)):
        kwargs.setdefault(key, default)

    imgaxes = ax.figure.add_axes(ax.get_position(), **kwargs)

    # size of the parent axes in pixels at the requested dpi
    extent_inches = ax.get_window_extent().transformed(
        ax.get_figure().dpi_scale_trans.inverted())
    width = extent_inches.width * dpi
    height = extent_inches.height * dpi

    source = _image_path_or_url(image_path)

    with Image.open(source) as opened:
        image = _apply_exif_rotation(opened).convert('RGBA')

        wpx, hpx = image.size
        aspect_im = wpx / hpx
        aspect_ax = width / height

        # half-sizes of the centered crop window, in source pixels
        half_h = (hpx / scale) / 2
        half_w = (wpx / scale) / 2

        zoomed = scale > 1

        if zoomed:
            if expand:
                # reshape the crop window to the aspect ratio of the axes
                if aspect_im < aspect_ax:
                    half_w = (width / height) * half_h
                elif aspect_im > aspect_ax:
                    half_h = (height / width) * half_w

            center_x, center_y = wpx / 2, hpx / 2
            crop_box = (
                int(center_x - half_w),
                int(center_y - half_h),
                int(center_x + half_w),
                int(center_y + half_h),
            )
            image = image.crop(crop_box)

        if not (zoomed and expand):
            # keep the source aspect ratio: derive the longer axes
            # dimension from the shorter one
            if width >= height:
                width = int(wpx * (height / hpx))
            else:
                height = int(hpx * (width / wpx))

        if zoomed:
            target_size = (int(width), int(height))
            resample = Image.BICUBIC
        else:
            target_size = (int(width * scale), int(height * scale))
            resample = Image.LANCZOS
        image = image.resize(target_size, resample)

        imgaxes.set_xlim(_calc_extents(image.size[0], scale))
        # reversed so that the y-axis runs top-down, as images do
        imgaxes.set_ylim(reversed(_calc_extents(image.size[1], scale)))
        imgaxes.imshow(image, aspect='equal')

    return imgaxes
def _get_default_tb_spans(rows, cols):
spans = [
{'span': [0, rows[0], 0, sum(cols)]},
{'span': [rows[0], rows[0] + rows[1], 0, cols[0] + cols[1]]},
{'span': [rows[0] + rows[1], sum(rows), 0, cols[0]]},
{'span': [rows[0] + rows[1], sum(rows), cols[0], cols[0] + cols[1]]},
{'span': [rows[0], sum(rows), cols[0] + cols[1], sum(cols)]},
]
return spans
def _validate_margins(margins):
if margins is None:
margins = (4, 4, 4, 4)
elif (len(margins) != 4 or not all(isinstance(x, int) for x in margins)):
raise ValueError("`margins` must contain four integers")
return margins
class Template(object):
    """
    Class to construct a report figure template using matplotlib
    which includes a figure border, script path, and title block.

    Parameters
    ----------
    margins : tuple of int, optional (default = (4, 4, 4, 4))
        A length-4 tuple specifying the left, right, top, and bottom
        margins on the page, respectively, in tenths of an inch.
    titleblock_content : list of dicts, optional
        Title block elements where each element is itself a
        dictionary with a `span` key that determines which rows and
        columns each element will occupy in the titleblock.
        E.g. ::

            tbk = [
                {
                    # optional; an element without a `name` is
                    # labeled 'b_<index>'
                    'name': 'Title',
                    #`text` must be a dict or a list of dicts. Each
                    # dict is passed as keyword arguments to
                    # ``Axes.text``, so it must hold `x`, `y` and `s`.
                    # Positions are in axes coordinates unless a
                    # `transform` key is given.
                    'text': [
                        {
                            'x': 0.5,
                            'y': 0.7,
                            's': 'Figure Title',
                            'weight': 'bold',
                        },
                        {
                            'x': 0.5,
                            'y': 0.3,
                            's': 'Figure Subtitle',
                            'weight': 'light',
                        },
                    ],
                    #`image` must refer to dict with `path` key (required),
                    # and optional keys `scale` and `expand` which default
                    # to 1 and False, respectively.
                    'image': {
                        'path': 'img//logo.png',
                        'scale': 1,
                    },
                    #`span` must be a list of integers for the
                    # gridspec columns that the titleblock element will
                    # span in tenths of an inch. The following span
                    # will give a titleblock element that is 0.8 inches tall
                    # and 3.2 inches wide. It will be the top left element
                    # of the block because its height and width begin at zero.
                    'span': [0, 8, 0, 32],
                },
                { # specify keys for next tbk element
                },
            ]

    titleblock_cols : tuple of int, optional (default=(16, 16, 8))
        The specification (in tenths of an inch) of the rulers for
        each column in the title block.
    titleblock_rows : tuple of int, optional (default=(8, 5, 3))
        The specification (in tenths of an inch) of the rulers for
        each row in the title block.
    scriptname : str
        Path to the script or notebook that is creating the figure.
        Required.
    draft : bool, optional (default=True)
        Toggles the inclusion of a draft watermark.
    dpi : int, optional (default=300)
        Resolution of the final figure in dots per inch.
    **figkwargs
        Additional keyword arguments passed to ``plt.figure``

    Raises
    ------
    ValueError
        If `scriptname` is not given, or if `margins` is not four integers.

    Examples
    --------
    To produce an empty figure containing a border object,
    and 5 title block objects:

    .. plot::
        :context: reset
        :include-source: True

        >>> from mpl_template import Template
        >>> report_fig = Template(figsize=(8.5, 11), scriptname="path/to/script.py")
        >>> fig = report_fig.blank()

    """

    def __init__(self,
                 margins=None,
                 titleblock_content=None,
                 titleblock_cols=None,
                 titleblock_rows=None,
                 scriptname=None,
                 draft=True,
                 dpi=300,
                 **figkwargs):

        if scriptname is None:
            raise ValueError('Must enter name of calling script for footnote')
        self.script_name = scriptname

        self._margins = _validate_margins(margins)
        self.left, self.right, self.top, self.bottom = self.margins

        if titleblock_cols is None:
            titleblock_cols = (16, 16, 8)
        if titleblock_rows is None:
            titleblock_rows = (8, 5, 3)

        self.default_spans = _get_default_tb_spans(
            titleblock_rows, titleblock_cols)
        # overall title block size, in gridspec cells
        self.t_w = sum(titleblock_cols)
        self.t_h = sum(titleblock_rows)

        self.is_draft = draft

        # lazily created by the properties of the same name
        self._fig = None
        self._gsfig = None
        self._watermark = None
        self._path_text = None
        self._gstitleblock = None
        self._gstitleblock_subspec = None

        self._fig_options = figkwargs
        self._fig_options['dpi'] = dpi
        self._titleblock_content = titleblock_content

    @staticmethod
    def _element_label(index, dct):
        """Axes label of a title block element: its `name`, else 'b_<index>'."""
        if 'name' in dct:
            return dct['name']
        return 'b_{}'.format(index)

    @property
    def margins(self):
        """The (left, right, top, bottom) margins."""
        return self._margins

    @margins.setter
    def margins(self, value):
        self._margins = _validate_margins(value)
        self.left, self.right, self.top, self.bottom = self._margins

    @property
    def titleblock_content(self):
        """Title block elements; falls back to the default spans."""
        if self._titleblock_content is None:
            self._titleblock_content = self.default_spans
        return self._titleblock_content

    @titleblock_content.setter
    def titleblock_content(self, value):
        self._titleblock_content = value

    @property
    def fig(self):
        """The figure, created on first access (with watermark if draft)."""
        if self._fig is None:
            self.fig = plt.figure(**self._fig_options)
            if self.is_draft:
                self.add_watermark()
        return self._fig

    @fig.setter
    def fig(self, value):
        self._fig = value

    @property
    def gsfig(self):
        """GridSpec covering the whole figure with ten cells per inch."""
        if self._gsfig is None:
            row = int(self.fig.get_figheight() * 10)
            col = int(self.fig.get_figwidth() * 10)
            self._gsfig = gridspec.GridSpec(row, col, left=0, right=1, bottom=0,
                                            top=1, wspace=0, hspace=0)
        return self._gsfig

    @gsfig.setter
    def gsfig(self, value):
        self._gsfig = value

    @property
    def watermark(self):
        """The watermark text artist, added to the figure on first access."""
        if self._watermark is None:
            self._watermark = self.add_watermark()
        return self._watermark

    @watermark.setter
    def watermark(self, value):
        self._watermark = value

    @property
    def path_text(self):
        """Script path shown in the footnote; defaults to cwd + scriptname."""
        if self._path_text is None:
            self._path_text = os.path.join(os.getcwd(), self.script_name)
        return self._path_text

    @path_text.setter
    def path_text(self, value):
        self._path_text = value

    @property
    def gstitleblock(self):
        """Slice of `gsfig` occupied by the title block (bottom-right)."""
        if self._gstitleblock is None:
            # `-0 or None` turns a zero margin into an open-ended slice;
            # a literal `:-0` stop would select nothing.
            self._gstitleblock = self.gsfig[
                -(self.bottom + self.t_h) or None:-self.bottom or None,
                -(self.right + self.t_w) or None:-self.right or None]
        return self._gstitleblock

    @gstitleblock.setter
    def gstitleblock(self, value):
        self._gstitleblock = value
        # the sub-grid is derived from the title block, so rebuild it lazily
        self._gstitleblock_subspec = None

    @property
    def gstitleblock_subspec(self):
        """Grid of `t_h` x `t_w` cells nested inside the title block."""
        if self._gstitleblock_subspec is None:
            self._gstitleblock_subspec = gridspec.GridSpecFromSubplotSpec(
                self.t_h, self.t_w, self.gstitleblock, wspace=0.0, hspace=0.0,
            )
        return self._gstitleblock_subspec

    @gstitleblock_subspec.setter
    def gstitleblock_subspec(self, value):
        self._gstitleblock_subspec = value

    def add_frame(self):
        """Add and return the border axes, inset by the page margins."""
        # margins are tenths of an inch; convert to figure fractions
        _left = self.left / (10. * self.fig.get_figwidth())
        _right = self.right / (10. * self.fig.get_figwidth())
        _bottom = self.bottom / (10. * self.fig.get_figheight())
        _top = self.top / (10. * self.fig.get_figheight())
        rect = [_left, _bottom, 1 - (_left + _right), 1 - (_bottom + _top)]
        frame = self.fig.add_axes(rect, zorder=100, facecolor='none',
                                  xticks=[], yticks=[], label='frame')
        return frame

    def add_watermark(self, text=None):
        """Add a watermark (default 'DRAFT') to the figure and return it."""
        if text is None:
            text = 'DRAFT'
        x = 5 / (10. * self.fig.get_figwidth())
        y = 1 - self.top / (10. * self.fig.get_figheight())
        watermark = self.fig.text(x, y, text, fontsize=24,
                                  color='r', fontname='Arial',
                                  fontweight='bold', zorder=1000,
                                  horizontalalignment='left',
                                  verticalalignment='center')
        self.watermark = watermark
        return self.watermark

    def add_titleblock(self):
        """Add one axes per title block element and return them as a list.

        An element without a `span` key uses the default span at the same
        index.
        """
        axlist = []
        for i, dct in enumerate(self.titleblock_content):
            if 'span' in dct:
                r0, r, c0, c = dct['span']
            else:
                r0, r, c0, c = self.default_spans[i]['span']

            label = self._element_label(i, dct)

            ax = self.fig.add_subplot(self.gstitleblock_subspec[r0:r, c0:c],
                                      label=label, zorder=100, facecolor='none',
                                      xticks=[], yticks=[], aspect='equal',
                                      adjustable='datalim')
            axlist.append(ax)
        return axlist

    def add_page(self):
        """Add and return a transparent axes covering the whole figure."""
        ax = self.fig.add_axes([0, 0, 1, 1], zorder=1000, facecolor='none',
                               xticks=[], yticks=[], label='page')
        return ax

    def add_path_text(self):
        """Add the 'Source: <path>' footnote and return the text artist."""
        x = self.left / (10. * self.fig.get_figwidth())
        y = abs((self.bottom - 1.5) / (10 * self.fig.get_figheight()))
        text = 'Source: ' + self.path_text
        textobj = self.fig.text(x, y, text, fontsize=5,
                                horizontalalignment='left',
                                verticalalignment='center')
        return textobj

    def populate_titleblock(self):
        """Draw the `text` and `image` content of each title block element.

        Elements are matched to axes by label, using the same label that
        `add_titleblock` assigns (the element's `name`, else 'b_<index>').

        Raises
        ------
        ValueError
            If a `text` entry is neither a dict nor a list of dicts.
        """
        for ax in self.fig.get_axes():
            label = ax.get_label()
            for i, dct in enumerate(self.titleblock_content):
                if self._element_label(i, dct) != label:
                    continue

                content = dct.get('text')
                image = dct.get('image')

                if content is not None:
                    if isinstance(content, dict):
                        content = [content]
                    if not isinstance(content, list):
                        raise ValueError(
                            '`text` key must map to dict or list of dicts')
                    for elem in content:
                        # copy so the caller's dicts are never modified
                        kwargs = copy.deepcopy(elem)
                        if 'transform' not in kwargs:
                            kwargs['transform'] = ax.transAxes
                        ax.text(**kwargs)

                if image is not None:
                    scale = image.get('scale', 1)
                    expand = image.get('expand', False)
                    img_ax = insert_image(ax, image['path'],
                                          scale=scale,
                                          dpi=ax.get_figure().get_dpi(),
                                          expand=expand)
                    img_ax.set_label('img_b_{}'.format(i))
                    img_ax.axis('off')

    def setup_figure(self):
        """Add frame, title block and footnote, fill the block, return fig."""
        self.add_frame()
        self.add_titleblock()
        self.add_path_text()
        self.populate_titleblock()
        return self.fig

    def blank(self):
        """Return a figure with the frame and title block axes, each block
        showing its own label, and with the watermark removed."""
        self.add_frame()
        for ax in self.add_titleblock():
            ax.text(0.5, 0.5, '"{}"'.format(ax.get_label()),
                    va="center", ha="center", size=12)
        self.watermark.remove()
        return self.fig