# Source code for postpic.plotting.plotter_matplotlib

#
# This file is part of postpic.
#
# postpic is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# postpic is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with postpic. If not, see <http://www.gnu.org/licenses/>.
#
# Stephan Kuschel 2014-2017
# Alexander Blinne, 2017
"""
This module provides the MatplotlibPlotter class.

This class can be used to plot Field objects using the matplotlib interface.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import warnings


__all__ = ['MatplotlibPlotter']


class MatplotlibPlotter(object):
    '''
    Provides methods to modify figure and axes objects for convenient plotting.
    It also autogenerates savenames and annotates the plot if a reader is given.
    A reader can be a dumpreader or a simulationreader.
    '''
    import matplotlib.ticker
    # Kept as public class attributes for backward compatibility only. The
    # plotting methods below install a fresh formatter on every axis instead
    # (see `_newformatter`): a ScalarFormatter computes its offset/exponent from
    # the single Axis it was last attached to, so one shared instance produces
    # wrong tick labels as soon as several axes use it.
    axesformatterx = matplotlib.ticker.ScalarFormatter()
    axesformatterx.set_powerlimits((-2, 3))
    axesformattery = matplotlib.ticker.ScalarFormatter()
    axesformattery.set_powerlimits((-2, 3))

    from matplotlib.colors import LinearSegmentedColormap
    # blue (0) -> white (0.5) -> red (1); default colormap for signed 2D data
    efieldcdict = {'red': ((0, 0, 0), (0.5, 1, 1), (1.0, 1, 1)),
                   'green': ((0, 0, 0), (0.5, 1, 1), (1, 0, 0)),
                   'blue': ((0, 1, 1), (0.5, 1, 1), (1, 0, 0))}
    symmap = LinearSegmentedColormap('EField', efieldcdict, 1024)

    def __init__(self, reader, outdir='./', autosave=False, project=None, ext='png',
                 size_inches=(9, 7), dpi=160, facecolor=(1, 1, 1, 0.01),
                 transparent=False):
        '''
        Args:
            reader: dumpreader or simulationreader used for annotations; its
                ``name`` attribute is part of every savename.
            outdir (str): prefix prepended to every savename.
            autosave (bool): if True, every plot is saved and closed right away.
            project (str): project name used in savenames and annotations.
            ext (str): default file extension of saved plots.
            size_inches, dpi, facecolor, transparent: figure output settings.
        '''
        self._ext = ext
        self.autosave = autosave
        self.reader = reader
        self.outdir = outdir
        self._project = project
        self.size_inches = size_inches
        self.dpi = dpi
        self.facecolor = facecolor
        self.transparent = transparent
        # every savename handed out so far (without extension), in order
        self._savenamesused = []

    def __len__(self):
        '''Number of savenames generated so far.'''
        return len(self._savenamesused)

    @property
    def project(self):
        '''The project name, or '' if none was given.'''
        return self._project if self._project else ''

    @staticmethod
    def _newformatter():
        '''
        Return a new ScalarFormatter configured like `axesformatterx` and
        `axesformattery`. A new instance is needed for every Axis.
        '''
        import matplotlib.ticker
        formatter = matplotlib.ticker.ScalarFormatter()
        formatter.set_powerlimits((-2, 3))
        return formatter

    def savename(self, key, ext=None):
        '''
        Generate and register a new, unique savename for the plot `key`.

        The name is built as ``outdir + project_readername_counter_key``. Slashes
        and spaces in the part after `outdir` are replaced or removed. An
        ``_<i>`` suffix is added if the name was already used.

        Returns:
            str: the savename including ``'.' + ext`` (default: the `ext`
            given to the constructor).
        '''
        if not ext:
            ext = self._ext
        name = self.project + '_' + self.reader.name + \
            '_' + str(len(self._savenamesused)) + '_' + key
        name = name.replace('/', '_').replace(' ', '')
        name = self.outdir + name
        nametmp = name + '_%d'
        i = 0
        while name in self._savenamesused:
            i = i + 1
            name = nametmp % i
        self._savenamesused.append(name)
        return name + '.' + ext

    def lastsavename(self):
        '''
        Return the last savename, always WITHOUT file extension.

        If no savename has been generated yet, a new one with the key
        'lastsavename' is created first.
        '''
        if not self._savenamesused:
            # savename() returns the name with extension; this method's
            # contract is the bare name, so read it back from the registry.
            self.savename('lastsavename')
        return self._savenamesused[-1]

    def savefig(self, fig, key):
        '''Save `fig` under a new savename generated from `key`.'''
        savename = self.savename(key)
        fig.savefig(savename, dpi=self.dpi, facecolor=self.facecolor,
                    transparent=self.transparent)
        return

    @staticmethod
    def settext_fig(fig, title=None, ur=None, ur2=None, ul=None, ul2=None, center=None):
        '''Place the given texts at fixed positions in figure coordinates.'''
        if title:
            fig.suptitle(title)
        if ur:
            fig.text(0.92, 0.965, ur, ha='right')
        if ur2:
            fig.text(0.92, 0.93, ur2, ha='right')
        if ul:
            fig.text(0.03, 0.965, ul, horizontalalignment='left')
        if ul2:
            fig.text(0.03, 0.93, ul2, horizontalalignment='left')
        if center:
            fig.text(0.5, 0.87, center, horizontalalignment='center')

    @staticmethod
    def settext_ax(ax, title=None, ur=None, ur2=None, ul=None, ul2=None, center=None):
        '''Place the given texts at fixed positions in axes coordinates.'''
        if title:
            ax.set_title(title)
        if ur2:
            ax.text(0.99, 1.01, ur2, ha='right', transform=ax.transAxes)
        if ur:
            ax.text(0.99, 0.97, ur, ha='right', transform=ax.transAxes)
        if ul2:
            ax.text(0.01, 1.01, ul2, horizontalalignment='left', transform=ax.transAxes)
        if ul:
            ax.text(0.01, 0.97, ul, horizontalalignment='left', transform=ax.transAxes)
        if center:
            ax.text(0.5, 0.87, center, horizontalalignment='center',
                    transform=ax.transAxes)

    @staticmethod
    def annotate(figorax, title=None, time=None, step=None, project=None, dump=None,
                 infostring=None, infos=None):
        '''
        Annotate a Figure or Axes.

        Float `time` values are multiplied by 1e15 and shown as "fs"; other
        values are shown via str(). Numeric `step` values are appended as
        ", step: N". A time or step of 0 is shown as well. Empty `infos`
        ([] or ['']) are not drawn.
        '''
        import matplotlib.axes  # `import matplotlib` alone does not load matplotlib.axes
        ur = ''
        if time is not None:
            if isinstance(time, (float, np.floating)):
                ur = '{:.1f} fs'.format(1e15 * time)
            else:
                ur = str(time)
        if step is not None:
            if isinstance(step, (int, float, np.integer, np.floating)):
                ur += ', step: {:6.0f}'.format(step)
            else:
                ur += ' ' + str(step)
        ul = project
        ul2 = dump
        ur2 = infostring
        center = None if infos == [] or infos == [''] else infos
        func = MatplotlibPlotter.settext_ax if isinstance(figorax, matplotlib.axes.Axes) \
            else MatplotlibPlotter.settext_fig
        func(figorax, title, ur, ur2, ul, ul2, center)
        return

    @staticmethod
    def annotate_fromfield(figorax, field):
        '''Annotate with the label, infostring and infos of `field`.'''
        MatplotlibPlotter.annotate(figorax, title=field.label,
                                   infostring=field.infostring, infos=field.infos)
        return

    @staticmethod
    def annotate_fromreader(figorax, reader):
        '''
        Annotate with time, timestep and name of `reader`. If the reader does
        not provide them (e.g. `reader` is None), no text is added.
        '''
        # Only the reader access is guarded, so AttributeErrors raised while
        # plotting are not silently swallowed.
        try:
            time = reader.time()
            step = reader.timestep()
            dump = reader.name
        except AttributeError:
            MatplotlibPlotter.annotate(figorax)
        else:
            MatplotlibPlotter.annotate(figorax, time=time, step=step, dump=dump)
        return

    @staticmethod
    def symmetricclimaximage(aximage):
        """
        Symmetrize the clim of `aximage` (anything with get_clim/set_clim)
        around 0.
        """
        bound = max(abs(np.asarray(aximage.get_clim())))
        aximage.set_clim(-bound, bound)
        return

    @staticmethod
    def symmetricclim(ax):
        """
        Symmetrize the clim of the first image of `ax` around 0.
        """
        MatplotlibPlotter.symmetricclimaximage(ax.images[0])
        return

    @staticmethod
    def addaxislabels(ax, field):
        '''Label the x (and y) axis of `ax` from the axes of `field`.'''
        if len(field.axes) > 0:
            ax.set_xlabel(field.axes[0].label)
        if len(field.axes) > 1:
            ax.set_ylabel(field.axes[1].label)
        return

    @staticmethod
    def addField1d(ax, field, log10plot=True, xlim=None, ylim=None, scaletight=None):
        '''
        Plot the 1D `field` as a line on `ax`.

        The y axis is logarithmic if `log10plot` is set, the data has no
        negative values, and at least one value is positive.

        Returns:
            the Axes `ax`.
        '''
        field = field.squeeze()
        assert field.dimensions == 1, 'Field needs to be 1 dimensional'
        ax.plot(field.grid, field.matrix, label=field.label)
        ax.xaxis.set_major_formatter(MatplotlibPlotter._newformatter())
        ax.yaxis.set_major_formatter(MatplotlibPlotter._newformatter())
        if log10plot and not np.any(field.matrix < 0) and np.any(field.matrix > 0):
            # sets the axis to log scale AND overrides the formatter set above
            # with matplotlib's default LogFormatterMathtext.
            ax.set_yscale('log')
        MatplotlibPlotter.addaxislabels(ax, field)
        ax.autoscale(tight=scaletight)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        return ax

    @staticmethod
    def addFields1d(ax, *fields, **kwargs):
        '''
        Plot several 1D fields on `ax` and add a legend. Fields with
        ``dimensions <= 0`` are skipped. `kwargs` are passed to `addField1d`.
        '''
        # only write infos to the image if the infos of all fields are equal.
        clearinfos = not all(str(f.infos) == str(fields[0].infos) for f in fields)
        infostrings = []
        for field in fields:
            if field.dimensions <= 0:
                continue
            infostrings.append(field.infostring)
            MatplotlibPlotter.addField1d(ax, field, **kwargs)
            # Annotate directly instead of setting `field.infos = []`, so the
            # caller's Field objects are left unmodified.
            MatplotlibPlotter.annotate(ax, title=field.label,
                                       infostring=field.infostring,
                                       infos=[] if clearinfos else field.infos)
        MatplotlibPlotter.annotate(ax, infostring=str(infostrings))
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels)
        return

    # add lineouts to 2D-plots
    @staticmethod
    def _addxlineout(ax0, m, extent, log10=False):
        '''Overlay the mean of `m` over axis 0 on a twin y axis of `ax0`.'''
        ax = ax0.twinx()
        lout = m.mean(axis=0)
        if log10:
            lout = np.log10(lout)
        x = np.linspace(extent[0], extent[1], len(lout))
        ax.plot(x, lout, 'k', lw=1)
        ax.autoscale(tight=True)
        return ax

    @staticmethod
    def _addylineout(ax0, m, extent, log10=False):
        '''Overlay the mean of `m` over axis 1 on a twin x axis of `ax0`.'''
        ax = ax0.twiny()
        lout = m.mean(axis=1)
        if log10:
            lout = np.log10(lout)
        x = np.linspace(extent[2], extent[3], len(lout))
        ax.plot(lout, x, 'k', lw=1)
        ax.autoscale(tight=True)
        return ax

    @staticmethod
    def _addimage(ax, field, data, extent, interpolation, color_image, **kwargs):
        '''
        Draw `data` (the field matrix with its first two axes swapped) on `ax`.

        Uses imshow if all axes of `field` are linear, otherwise pcolormesh on
        the grid nodes.

        Returns:
            the created artist (AxesImage or QuadMesh).

        Raises:
            ValueError: for color images on non-linear axes.
        '''
        kwargs.setdefault('aspect', 'auto')
        if all(field.islinear()):
            return ax.imshow(data, origin='lower', extent=extent,
                             interpolation=interpolation, **kwargs)
        if color_image:
            raise ValueError("color images with non-linear axes not supported by this "
                             "function.")
        # 'aspect' is an imshow option that pcolormesh does not accept
        kwargs.pop('aspect', None)
        # Loop variable is `axis`, not `ax`: under Python 2 the comprehension
        # variable leaks and would overwrite the matplotlib Axes.
        x, y = [axis.grid_node for axis in field.axes]
        return ax.pcolormesh(x, y, data, **kwargs)

    @staticmethod
    def addField2d(figax, field, log10plot=True, interpolation='none',
                   contourlevels=np.array([]), saveandclose=True, xlim=None, ylim=None,
                   clim=None, savecsv=False, lineoutx=False, lineouty=False, **kwargs):
        '''
        Plot a 2D field, or a 3D field with 3 or 4 color channels, on `figax`.

        Args:
            figax: tuple ``(fig, ax)``.
            log10plot: plot log10 of the data. Only applied if the data has no
                negative values, at least one positive value, and is not a
                color image.
            contourlevels: contour levels (array-like or None); empty draws no
                contours.
            clim: color limits. If not given for linear-scale plots, the
                limits are symmetrized around 0.
            saveandclose, savecsv: accepted for compatibility; unused here.
            kwargs: passed to imshow / pcolormesh.
        '''
        field = field.squeeze()
        (fig, ax) = figax
        assert field.dimensions == 2 or (field.dimensions == 3 and field.shape[2] in [3, 4]), \
            'Field needs to be 2 dimensional'
        color_image = field.dimensions == 3
        if color_image:
            # scale the color data so that its maximum becomes 1
            field = field / np.max(field.matrix)
        ax.xaxis.set_major_formatter(MatplotlibPlotter._newformatter())
        ax.yaxis.set_major_formatter(MatplotlibPlotter._newformatter())
        uselog = log10plot and not color_image \
            and not np.any(field.matrix < 0) and np.any(field.matrix > 0)
        if uselog:
            kwargs.setdefault('cmap', 'jet')
            im = MatplotlibPlotter._addimage(ax, field, np.log10(field.matrix.T),
                                             field.extent, interpolation, color_image,
                                             **kwargs)
            fig.colorbar(im, format='%3.1f')
            if clim is not None:
                im.set_clim(clim)
        else:
            kwargs.setdefault('cmap', MatplotlibPlotter.symmap)
            im = MatplotlibPlotter._addimage(ax, field, np.swapaxes(field.matrix, 0, 1),
                                             field.extent[:4], interpolation, color_image,
                                             **kwargs)
            if clim is not None:
                im.set_clim(clim)
            else:
                # symmetrize the artist just created; ax.images[0] would fail for
                # pcolormesh and could pick an older image on a reused Axes.
                MatplotlibPlotter.symmetricclimaximage(im)
            if not color_image:
                fig.colorbar(im, format='%6.0e')
        levels = np.asarray([] if contourlevels is None else contourlevels)
        if levels.size != 0:
            # Draw contour lines on top of the image. The former hold='on'
            # keyword is not passed: matplotlib removed it, and overplotting
            # is the default.
            ax.contour(field.matrix.T, levels, extent=field.extent)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        if lineoutx:
            MatplotlibPlotter._addxlineout(ax, field.matrix.T, field.extent, log10=uselog)
        if lineouty:
            MatplotlibPlotter._addylineout(ax, field.matrix.T, field.extent, log10=uselog)
        MatplotlibPlotter.addaxislabels(ax, field)
        MatplotlibPlotter.annotate_fromfield(ax, field)
        return

    def _plotfinalize(self, fig):
        '''Annotate `fig` with the information of this plotter's reader.'''
        self.annotate_fromreader(fig, self.reader)
        return

    def plotFields1d(self, *fields, **kwargs):
        '''
        Plot 1D fields into a new figure.

        Keyword Args:
            name: savename key (default: name of the first field).
            savecsv: if True, also export every field with dimensions > 0 as
                CSV, named after the last savename.
            others: passed to `addField1d`.

        Returns:
            the figure, or None if it was autosaved and closed.
        '''
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        name = kwargs.pop('name', fields[0].name)
        # handled here; addField1d does not accept a `savecsv` keyword.
        savecsv = kwargs.pop('savecsv', False)
        MatplotlibPlotter.addFields1d(ax, *fields, **kwargs)
        self._plotfinalize(fig)
        self.annotate(fig, project=self.project)
        self.annotate(ax, infostring=str([f.infostring for f in fields]))
        fig.set_size_inches(*self.size_inches)
        if self.autosave:
            self.savefig(fig, name)
            plt.close(fig)
            fig = None
        if savecsv:
            for field in fields:
                if field.dimensions <= 0:
                    continue
                field.exporttocsv(self.lastsavename() + field.label + '.csv')
        return fig

    def plotField2d(self, field, name=None, **kwargs):
        '''
        Plot a 2D (or color) field into a new figure. `kwargs` are passed to
        `addField2d`.

        Returns:
            the figure, or None if it was autosaved and closed.
        '''
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        MatplotlibPlotter.addField2d((fig, ax), field, **kwargs)
        self._plotfinalize(fig)
        self.annotate(fig, project=self.project)
        self.annotate(ax, infostring=field.infostring)
        fig.set_size_inches(*self.size_inches)
        if self.autosave:
            self.savefig(fig, name if name else field.name)
            plt.close(fig)
            fig = None
        return fig

    def plotField(self, field, autoreduce=True, maxlen=6000, name=None, **kwargs):
        '''
        This is the main method that should be used for plotting.

        Dispatches on the field's dimensions: a placeholder plot for None or
        empty fields, `plotFields1d` for 1D, and `plotField2d` for 2D or color
        fields.

        Raises:
            NotImplementedError: for other 3D fields (a subclass of Exception,
                which was raised before).
        '''
        if field is None:
            return self._skipplot('none')
        field = field.squeeze()
        if autoreduce:
            reduced = field.autoreduce(maxlen=maxlen)
            # NOTE(review): whether Field.autoreduce works in place or returns a
            # reduced copy is not visible here; use the return value when there
            # is one -- confirm against Field.
            if reduced is not None:
                field = reduced
        if field.dimensions <= 0:
            return self._skipplot(name if name else field.name)
        if field.dimensions == 1:
            if name:
                kwargs['name'] = name
            return self.plotFields1d(field, **kwargs)
        if field.dimensions == 2 or (field.dimensions == 3 and field.shape[2] in [3, 4]):
            return self.plotField2d(field, name, **kwargs)
        raise NotImplementedError('3D not implemented')

    def _skipplot(self, key):
        '''
        Create a placeholder figure saying 'No data available.'

        Returns:
            the figure, or None if it was autosaved and closed.
        '''
        import matplotlib.pyplot as plt
        fig = plt.figure()
        fig.text(0.5, 0.5, 'No data available.', ha='center')
        fig.set_size_inches(*self.size_inches)
        if self.autosave:
            self.savefig(fig, key)
            plt.close(fig)
            fig = None
        print('Skipped Plot.')
        return fig

    def plotFields(self, *fields, **kwargs):
        '''Call `plotField` on every field and return the list of results.'''
        ret = [self.plotField(field, **kwargs) for field in fields]
        return ret

    def plotallderived(self, dumpreader):
        '''
        Plot all derived fields of `dumpreader`. Does nothing if the reader
        has no ``getderived`` method.
        '''
        try:
            derived = dumpreader.getderived()
        except AttributeError:
            return
        fields = dumpreader.createfieldsfromkeys(*derived)
        for f in fields:
            self.plotField(f)
        return