Source code for seirmo.plots._plot_from_numpy

#
# This file is part of seirmo (https://github.com/SABS-R3-Epidemiology/seirmo/)
# which is released under the BSD 3-clause license. See accompanying LICENSE.md
# for copyright notice and full license details.
#

from typing import Sequence

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors


class ConfigurablePlotter:
    """
    A figure class that visualises the population of each compartment over
    time.

    Configurable to plot multiple subplots in one figure, with customised
    labels or colours. Implements an add_fill() method to plot a shaded
    region between two datasets (e.g. when plotting confidence intervals).

    Call begin() before any plotting method: it creates the matplotlib
    figure and the 2D grid of subplot axes that this object draws on.
    """

    def __init__(self):
        pass

    def begin(self, subplots_rows: int = 1, subplots_columns: int = 1):
        """
        Begin creating a figure with the given grid of subplots.

        Replaces the init method so the object can be reused. The object
        owns one figure at a time: calling begin() again closes the figure
        created by the previous call.

        :params:: subplots_rows: int, number of rows of subplots (positive)
        :params:: subplots_columns: int, number of columns of subplots
            (positive)

        Raises TypeError if either argument is not an int, and ValueError
        if either is not positive. Validation happens before any existing
        figure is closed or a new one is created.
        """
        if not isinstance(subplots_rows, int):
            raise TypeError("Number of rows of subplots must be an integer")
        if not isinstance(subplots_columns, int):
            raise TypeError(
                "Number of columns of subplots must be an integer")
        if subplots_rows <= 0:
            raise ValueError("Number of rows of subplots must be positive")
        if subplots_columns <= 0:
            raise ValueError("Number of columns of subplots must be positive")

        # Release the figure from any earlier begin() call. Otherwise pyplot
        # keeps it open until this whole object is deleted (see __del__).
        if hasattr(self, "_fig"):
            plt.close(self._fig)

        # squeeze=False makes matplotlib always return a 2D (rows, columns)
        # array of axes, so self._axes[row, column] works for every shape.
        self._fig, self._axes = plt.subplots(
            subplots_rows, subplots_columns, squeeze=False)
        self._size = subplots_columns * subplots_rows
        self._nrows = subplots_rows
        self._ncolumns = subplots_columns

    def __getitem__(self, index):
        """
        Return the figure or the axes created by begin().

        If figure = ConfigurablePlotter() and figure.begin() has been
        called, figure[0] returns the matplotlib figure and figure[1]
        returns the 2D array of subplot axis objects. Any other index
        raises ValueError.
        """
        if index == 0:
            return self._fig
        elif index == 1:
            return self._axes
        else:
            raise ValueError("Index must be 0 (for figure) or 1 (for axes)")

    def __iter__(self):
        """
        Iterate over (figure, axes), so that ``fig, axes = plotter`` works.

        Without this method, Python's fallback iteration would call
        __getitem__(2), which raises ValueError rather than the IndexError
        that ends fallback iteration.
        """
        return iter((self._fig, self._axes))

    def add_data_to_plot(
        self,
        times: np.ndarray,
        data_array: np.ndarray,
        position: Sequence[int] = (0, 0),
        xlabel: str = "time",
        ylabels: Sequence[str] = (),
        colours: Sequence = (),
        new_axis=False,
    ):
        """Add one or more data series to a subplot.

        :params:: times: np.ndarray, independent x- variable. Array-likes
            (e.g. lists) and a single scalar timestep are also accepted.
        :params:: data_array: np.ndarray, multiple dependent y- variables.
            Data should have one row per timestep, and one column for each
            dependent variable. A 1D array is read as one variable over all
            timesteps. If times holds a single value, a 1D array is instead
            read as one value per variable at that timestep.
        :params:: position: sequence of two integers, (row, column) index
            of the subplot to use
        :params:: xlabel: str, label for the x axis of that subplot
        :params:: ylabels: list of strings (a single string is also
            accepted). When given, there must be one per variable, and a
            legend is drawn.
        :params:: colours: list of valid colour specifiers (ie strings or
            rgb tuples), one per variable. Defaults to evenly spaced
            viridis colours.
        :params:: new_axis: boolean, set to true if data should be plotted
            against a second y axis (sharing the x axis) on the subplot

        Returns the (figure, axes) pair.

        Raises AssertionError if the number of timesteps, colours or
        ylabels does not match the data, or if position lies outside the
        subplot grid. Raises TypeError if ylabels is not iterable. All
        checks run before anything is drawn.
        """
        # Accept scalars and array-likes as well as numpy arrays
        times = np.atleast_1d(times)
        data_array = np.atleast_1d(data_array)
        if data_array.ndim == 1:  # Turn any 1D input into 2D
            if times.size == 1:
                # A single timestep: each entry is a separate variable
                times = times.reshape(1)
                data_array = data_array[np.newaxis, :]
            else:
                # Several timesteps: the entries form one variable
                data_array = data_array[:, np.newaxis]

        # AssertionError is kept for compatibility with existing callers,
        # but raised explicitly so the checks survive `python -O`.
        if times.shape[0] != data_array.shape[0]:
            raise AssertionError("data and times are not the same length")
        data_width = data_array.shape[1]  # number of y-variables

        row, column = position[0], position[1]
        if not (row < self._nrows and column < self._ncolumns):
            raise AssertionError("position and shape are not compatible")

        # Format user inputs
        if len(colours) == 0:  # Default value, if no colours specified
            colours = plt.cm.viridis(np.linspace(0, 1, data_width))
        else:
            colours = matplotlib.colors.to_rgba_array(colours)
        if data_width != np.shape(colours)[0]:
            raise AssertionError('Unexpected number of colours')

        if isinstance(ylabels, str):
            ylabels = [ylabels]  # Converts string input to list
        try:
            # list() also accepts one-shot iterables such as generators
            ylabels = list(ylabels)
        except TypeError as err:
            raise TypeError('Unexpected type of ylabels') from err
        if ylabels and data_width != len(ylabels):
            raise AssertionError('Unexpected number of ylabels')

        # Only touch the figure once every input has been validated, so a
        # rejected call does not leave an empty twin axis behind.
        base_axis = self._axes[row, column]
        axis = base_axis.twinx() if new_axis else base_axis

        if ylabels:  # Labelled series, listed in a legend
            for i, (colour, label) in enumerate(zip(colours, ylabels)):
                axis.plot(times, data_array[:, i], color=colour, label=label)
            axis.legend()
        else:  # Plot without a figure legend
            for i, colour in enumerate(colours):
                axis.plot(times, data_array[:, i], color=colour)

        # Label this subplot's own axis. plt.xlabel() would label pyplot's
        # *current* axes, which need not be this subplot. The axis made by
        # twinx() hides its x axis, so the label goes on the base axis.
        base_axis.set_xlabel(xlabel)
        self._fig.tight_layout()
        return self._fig, self._axes

    def add_fill(
        self,
        times: np.ndarray,
        ymin: np.ndarray,
        ymax: np.ndarray,
        position: Sequence[int] = (0, 0),
        xlabel: str = "time",
        ylabel: str = "number of people",
        colour: str = "b",
        alpha: float = 0.2,
    ):
        """Plot a shaded region between two datasets on a subplot.

        :params:: times: np.ndarray, independent x- variable
        :params:: ymin: np.ndarray, dependent y- variable (one edge of the
            shaded region)
        :params:: ymax: np.ndarray, comparison y- variable (the other edge)
        :params:: position: sequence of two integers, (row, column) index
            of the subplot to use
        :params:: xlabel: str, label for the x axis of that subplot
        :params:: ylabel: str, legend label for the shaded region
        :params:: colour: any valid colour specifier
        :params:: alpha: float, indicate transparency of filled region

        N.B. It is recommended that ymin be the (generally) smaller dataset
        for readability. This is not required, and the datasets may cross
        (i.e. ymin may be larger in sections).

        Returns the (figure, axes) pair. Raises AssertionError if position
        lies outside the subplot grid.
        """
        row, column = position[0], position[1]
        # Raised explicitly (not `assert`) so the check survives python -O
        if not (row < self._nrows and column < self._ncolumns):
            raise AssertionError("position and shape are not compatible")
        axis = self._axes[row, column]

        # squeeze() drops singleton dimensions, e.g. (n, 1) column data
        axis.fill_between(
            times,
            np.squeeze(ymin),
            np.squeeze(ymax),
            color=colour,
            alpha=alpha,
            label=ylabel,
        )
        axis.legend()
        # Label this subplot, not whatever axes pyplot considers current
        axis.set_xlabel(xlabel)
        self._fig.tight_layout()
        return self._fig, self._axes

    def show(self):
        """Display the open pyplot figures by calling plt.show()."""
        plt.show()

    def write_to_file(self, filename: str = "SEIR_stochastic_simulation.pdf"):
        """Save this object's figure to `filename` via Figure.savefig."""
        self._fig.savefig(filename)

    def __del__(self):
        """Close the owned figure, if begin() was ever called."""
        if hasattr(self, "_fig"):
            plt.close(self._fig)  # Close figure upon deletion