Source code for epynn.commons.plot

# EpyNN/epynn/commons/plot.py
# Standard library imports
import os

# Related third party imports
from matplotlib import pyplot as plt

# Local application/library specific imports
from epynn.commons.logs import process_logs
 

[docs]def pyplot_metrics(model, path): """Plot metrics/costs from training with matplotlib. :param model: An instance of EpyNN network object. :type model: :class:`epynn.meta.models.EpyNN` :param path: Write matplotlib plot. :type path: bool or NoneType """ plt.figure() metrics = model.metrics # Contains metrics and cost # Iterate over metrics/costs for s in metrics.keys(): # Iterate over active datasets for k, dset in enumerate(model.embedding.dsets): dname = dset.name x = [x for x in range(len(metrics[s][k]))] # X range y = metrics[s][k] # Y values from metrics[idx_dataset][idx_metrics] plt.plot(x, y, label=dname + ' ' + s) plt.legend() plt.xlabel('Epoch') plt.ylabel('Value') plt.title(model.uname) plt.show() # If path sets to None, set to defaults - Note path can be set to False, which makes it print only if path == None: path = 'plots' plot_path = os.path.join(os.getcwd(), path, model.uname) + '.png' if path: plt.savefig(plot_path) process_logs('Make: ' + plot_path, level=1) plt.close() return None