Source code for ztlearn.utils.plot_utils

# -*- coding: utf-8 -*-

import time
import matplotlib.pyplot as plt

# Font sizes and default figure size shared by the plotting helpers below.
SMALL_FONT = 10
LARGE_FONT = 14
FIG_SIZE   = (8, 6)

# Per-dataset image geometry for the image-grid helpers.
#   img_width / img_height : side lengths each sample is reshaped to.
#   pix_row / pix_col      : passed as the (x, y) data coordinates of the
#                            label text drawn on each thumbnail.
img_specs = dict(
    mnist  = dict(pix_row = 1, pix_col = 26, img_width = 28, img_height = 28),
    cifar  = dict(pix_row = 1, pix_col = 30, img_width = 32, img_height = 32),
    digits = dict(pix_row = 0, pix_col = 7,  img_width = 8,  img_height = 8),
)


def plotter(x,
            y = None,
            plot_dict = None,
            fig_dims = (7, 5),
            title = 'Model',
            title_dict = None,
            ylabel = 'y-axis',
            ylabel_dict = None,
            xlabel = 'x-axis',
            xlabel_dict = None,
            legend = None,  # e.g. ['train', 'valid']
            legend_dict = None,
            file_path = '',
            to_save = False,
            plot_type = 'line',
            cmap_name = None,
            cmap_number = 10,
            grid_on = True):
    """Draw a line or scatter plot on a new figure and return ``pyplot``.

    Parameters
    ----------
    x : array-like
        ``plot_type='line'``: x values shared by every series in ``y``.
        ``plot_type='scatter'``: 2-D array; columns 0 and 1 are plotted.
    y : sequence or array, optional
        Line mode: sequence of y-series, each drawn against ``x``.
        Scatter mode: optional 2-D array of extra points (e.g. k-means
        centroids) drawn in red. ``None`` or empty draws nothing extra.
    plot_dict : dict, optional
        Keyword arguments forwarded to ``ax.plot`` / ``ax.scatter``. The dict
        is copied, so the caller's object is never modified.
    fig_dims : tuple
        Figure size in inches.
    title, xlabel, ylabel : str
        Axis title and labels.
    title_dict, xlabel_dict, ylabel_dict : dict, optional
        Text keyword arguments for the title and axis labels.
    legend : sequence of str, optional
        Legend labels; no legend is drawn when ``None`` or empty.
    legend_dict : dict, optional
        Keyword arguments forwarded to ``ax.legend``.
    file_path : str
        Destination used when ``to_save`` is true.
    to_save : bool
        Save the figure to ``file_path``.
    plot_type : {'line', 'scatter'}
        Kind of plot to draw; any other value draws no data.
    cmap_name : str, optional
        Scatter mode only: colormap name; also adds a colorbar.
    cmap_number : int
        Number of discrete colours sampled from ``cmap_name``.
    grid_on : bool
        Draw major (solid grey) and minor (dotted red) grid lines.

    Returns
    -------
    module
        ``matplotlib.pyplot``, so callers can chain ``.show()``.
    """
    # Copy / normalise option containers: ``plot_dict`` gains a ``cmap``
    # entry below and must not leak that into the caller's dict.
    plot_dict   = dict(plot_dict) if plot_dict else {}
    title_dict  = title_dict or {}
    xlabel_dict = xlabel_dict or {}
    ylabel_dict = ylabel_dict or {}
    legend_dict = legend_dict or {}

    fig, ax = plt.subplots()
    fig.set_size_inches(fig_dims)
    ax.set_axisbelow(True)
    ax.minorticks_on()

    if grid_on:
        ax.grid(which = 'major', linestyle = '-', linewidth = 0.5, color = 'grey')
        ax.grid(which = 'minor', linestyle = ':', linewidth = 0.5, color = 'red')

    if plot_type == 'line':
        for series in (y if y is not None else []):
            ax.plot(x, series, **plot_dict)

    if plot_type == 'scatter':
        if cmap_name is not None:
            # pyplot.get_cmap replaces cm.get_cmap (removed in Matplotlib 3.9).
            plot_dict['cmap'] = plt.get_cmap(cmap_name, cmap_number)
            points = ax.scatter(x[:, 0], x[:, 1], **plot_dict)
            fig.colorbar(points, ax = ax)
        else:
            ax.scatter(x[:, 0], x[:, 1], **plot_dict)

        # Overlay extra points (centroids for k-means) only when some exist.
        if y is not None and len(y) > 0:
            ax.scatter(y[:, 0], y[:, 1], c = 'red')

    ax.set_title(title, **title_dict)
    ax.set_xlabel(xlabel, **xlabel_dict)
    ax.set_ylabel(ylabel, **ylabel_dict)
    if legend is not None and len(legend) > 0:
        ax.legend(legend, **legend_dict)

    if to_save:
        fig.savefig(file_path)

    return plt
def plot_pca(components,
             n_components = 2,
             colour_array = None,
             model_name = 'PCA',
             to_save = False,
             fig_dims = FIG_SIZE,  # (10, 8)
             title_dict = {'size' : SMALL_FONT}):
    """Scatter the first two columns of ``components`` and show the figure.

    Parameters
    ----------
    components : array-like
        2-D array of projected samples; columns 0 and 1 are plotted.
    n_components : int
        Used only in the saved file name.
    colour_array : array-like, optional
        Per-point colour values passed as scatter ``c``, mapped through the
        ``'tab10'`` colormap (a colorbar is drawn).
    model_name : str
        Used in the title (upper-cased) and the saved file name.
    to_save : bool
        Save to ``../plots/decompositions/<model>_<n>_Components_<time>.png``.
    fig_dims : tuple
        Figure size in inches.
    title_dict : dict
        Text keyword arguments for the title.
    """
    # One placeholder per field and an explicit '.png', so the saved format
    # does not depend on rcParams['savefig.format'] or dots in model_name.
    file_path = '../plots/decompositions/{}_{}_Components_{}.png'.format(
        model_name, n_components, time.strftime("%Y-%m-%d_%H-%M-%S"))

    plt_dict = {
        'c'         : colour_array,
        'edgecolor' : 'none',
        'alpha'     : 0.5,
        's'         : 50
    }

    plotter(components,
            y = None,
            plot_dict = plt_dict,
            fig_dims = fig_dims,
            title = 'Model {}'.format(model_name.upper()),
            title_dict = title_dict,
            xlabel = 'PC 1',
            ylabel = 'PC 2',
            file_path = file_path,
            to_save = to_save,
            plot_type = 'scatter',
            cmap_name = 'tab10',
            cmap_number = 10,
            grid_on = False)
    plt.show()
def plot_kmeans(data,
                labels = None,
                centroids = None,
                model_name = 'K-Means',
                model_clusters = 1,
                to_save = False,
                fig_dims = FIG_SIZE,
                title_dict = {'size' : SMALL_FONT}):
    """Scatter clustered data (coloured by ``labels``) with optional centroids.

    Parameters
    ----------
    data : array-like
        2-D array of samples; columns 0 and 1 are plotted.
    labels : array-like, optional
        Per-sample colour values passed as scatter ``c``.
    centroids : array-like, optional
        2-D array of centroid coordinates, drawn in red when given.
    model_name : str
        Used in the title (title-cased) and the saved file name.
    model_clusters : int
        Used only in the saved file name.
    to_save : bool
        Save to ``../plots/clusters/<model>_<k>_Clusters_<time>.png``.
    fig_dims : tuple
        Figure size in inches.
    title_dict : dict
        Text keyword arguments for the title.
    """
    # One placeholder per field; the explicit '.png' is actually emitted.
    file_path = '../plots/clusters/{}_{}_Clusters_{}.png'.format(
        model_name, model_clusters, time.strftime("%Y-%m-%d_%H-%M-%S"))

    plotter(data,
            y = centroids,
            plot_dict = {'c' : labels},
            fig_dims = fig_dims,
            title = 'Model {}'.format(model_name.title()),
            title_dict = title_dict,
            file_path = file_path,
            to_save = to_save,
            plot_type = 'scatter')
    plt.show()
def plot_metric(metric,
                epoch,
                train,
                valid,
                model_name = '',
                to_save = False,
                plot_dict = {'linewidth' : 0.8},
                fig_dims = FIG_SIZE,
                title_dict = {'size' : SMALL_FONT},
                ylabel_dict = {'size' : SMALL_FONT},
                xlabel_dict = {'size' : SMALL_FONT},
                legend = ['train', 'valid'],
                legend_dict = {'loc' : 'upper right'}):
    """Plot training and validation curves of ``metric`` and show the figure.

    Parameters
    ----------
    metric : str
        Metric name; used in the title, y-label and saved file name.
    epoch : int
        Number of points; the x values are ``range(epoch)``.
    train, valid : sequence
        Per-iteration metric values, each plotted against ``range(epoch)``
        (so presumably of length ``epoch``).
    model_name : str
        Used only in the saved file name.
    to_save : bool
        Save to ``../plots/metrics/<model>_<metric>_<time>.png``.
    plot_dict, title_dict, ylabel_dict, xlabel_dict, legend, legend_dict
        Forwarded to ``plotter``.
    fig_dims : tuple
        Figure size in inches.
    """
    # One placeholder per field; the explicit '.png' is actually emitted.
    file_path = '../plots/metrics/{}_{}_{}.png'.format(
        model_name, metric, time.strftime("%Y-%m-%d_%H-%M-%S"))

    plotter(range(epoch),
            [train, valid],
            plot_dict = plot_dict,
            fig_dims = fig_dims,
            title = 'Model {}'.format(metric.title()),
            title_dict = title_dict,
            ylabel = metric.title(),
            ylabel_dict = ylabel_dict,
            xlabel = 'Iterations',
            xlabel_dict = xlabel_dict,
            legend = legend,
            legend_dict = legend_dict,
            file_path = file_path,
            to_save = to_save)
    plt.show()
def plot_opt_viz(dims,
                 x,
                 y,
                 z,
                 f_solution,
                 overlay = 'plot',
                 to_save = False,
                 title = 'Optimization',
                 title_dict = {'size' : LARGE_FONT},
                 fig_dims = FIG_SIZE,
                 xticks_dict = {'size' : LARGE_FONT},
                 yticks_dict = {'size' : LARGE_FONT},
                 xlabel = r'$\theta^1$',
                 xlabel_dict = {'size' : LARGE_FONT},
                 ylabel = r'$\theta^2$',
                 ylabel_dict = {'size' : LARGE_FONT},
                 legend = ['train', 'valid'],
                 legend_dict = {}):
    """Visualise an optimisation run over a 2-D curve or a 3-D surface.

    ``dims == 3``: draws surface ``z`` over the grid ``x[0]``, ``x[1]`` as a
    wireframe (``overlay='wireframe'``) or a 20-level contour
    (``overlay='contour'``), and scatters the points ``y[:, 0], y[:, 1]`` in
    red with marker sizes ``f_solution``.

    ``dims == 2``: draws the curve ``plot(x, y)`` and scatters the points
    ``(z, f_solution)`` in red.

    When ``to_save`` is true a "<dims>D Surfaces" super-title is added and the
    figure is saved to ``../plots/<overlay>_<dims>d.png``. The figure is then
    shown.

    ``title``, ``title_dict``, ``legend`` and ``legend_dict`` are accepted but
    not used by this function.

    Raises
    ------
    ValueError
        If ``dims == 3`` and ``overlay`` is not ``'wireframe'`` or ``'contour'``.
    """
    if dims == 3:
        # Fail early with a clear message; otherwise ``ax`` would be unbound
        # below (the default overlay 'plot' hit exactly that).
        if overlay not in ('wireframe', 'contour'):
            raise ValueError(
                "overlay must be 'wireframe' or 'contour' when dims == 3, "
                "got {!r}".format(overlay))

        fig = plt.figure(figsize = fig_dims)
        if overlay == 'wireframe':
            from mpl_toolkits.mplot3d import axes3d  # noqa: F401  for 3d projections
            ax = fig.add_subplot(111, projection = '3d')
            plt.scatter(y[:, 0], y[:, 1], s = f_solution, c = 'r')
            ax.plot_wireframe(x[0], x[1], z, rstride = 5, cstride = 5, linewidth = 0.5)
        else:  # overlay == 'contour'
            ax = fig.add_subplot(111)
            plt.scatter(y[:, 0], y[:, 1], s = f_solution, c = 'r')
            ax.contour(x[0], x[1], z, 20, cmap = plt.cm.jet)

        ax.set_xlabel(xlabel, **xlabel_dict)
        ax.set_ylabel(ylabel, **ylabel_dict)

    elif dims == 2:
        plt.figure(figsize = fig_dims)
        plt.xticks(**xticks_dict)
        plt.yticks(**yticks_dict)
        plt.plot(x, y)
        plt.scatter(z, f_solution, color = 'r')
        plt.xlabel(xlabel, **xlabel_dict)
        plt.ylabel(ylabel, **ylabel_dict)

    if to_save:
        plt.suptitle(('{}{}'.format(dims, 'D Surfaces')), fontsize = 14)
        plt.savefig('../plots/'+('{}{}{}{}'.format(overlay, '_', dims, 'd.png')))

    plt.show()
def plot_img_samples(train_data, train_target = None, fig_dims = (6, 6), dataset = 'digits', channels = 1):
    """Show up to 36 images from ``train_data`` on a 6x6 grid.

    Parameters
    ----------
    train_data : array-like
        Indexable collection of samples; each sample is reshaped to
        ``(img_height, img_width)`` or, when ``channels == 3``, to
        ``(channels, img_height, img_width)`` and transposed to HWC.
    train_target : numpy array, optional
        Labels; drawn as text on each image after ``astype('int')``.
    fig_dims : tuple
        Figure size in inches.
    dataset : str
        Key into ``img_specs`` selecting image size and label position.
    channels : int
        3 for colour images; anything else is drawn with a binary colormap.
    """
    spec   = img_specs[dataset]
    height = spec['img_height']
    width  = spec['img_width']

    # Convert the labels once rather than once per grid cell.
    labels = train_target.astype('int') if train_target is not None else None

    fig = plt.figure(figsize = fig_dims)
    fig.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1, hspace = 0.05, wspace = 0.05)

    # The grid holds 36 cells; draw fewer when fewer samples are available
    # instead of failing with an IndexError.
    for i in range(min(36, len(train_data))):
        digit = fig.add_subplot(6, 6, i + 1, xticks = [], yticks = [])
        if channels == 3:
            color_img = train_data[i].reshape(channels, height, width).transpose([1, 2, 0])
            digit.imshow(color_img, interpolation = 'nearest')
        else:
            digit.imshow(train_data[i].reshape(height, width), cmap = plt.cm.binary, interpolation = 'nearest')

        if labels is not None:
            digit.text(spec['pix_row'], spec['pix_col'], str(labels[i]))

    plt.show()
def plot_tiled_img_samples(train_data, train_target = None, fig_dims = (6, 6), dataset = 'digits', channels = 1):
    """Show up to 36 images on a 6x6 grid with dotted blue grid lines over each.

    Parameters
    ----------
    train_data : array-like
        Indexable collection of samples; each sample is reshaped to
        ``(img_height, img_width)`` or, when ``channels == 3``, to
        ``(channels, img_height, img_width)`` and transposed to HWC, with the
        sizes taken from ``img_specs[dataset]``.
    train_target : numpy array, optional
        Labels; drawn as text on each image after ``astype('int')``.
    fig_dims : tuple
        Figure size in inches.
    dataset : str
        Key into ``img_specs`` selecting image size and label position.
    channels : int
        3 for colour images; anything else is drawn with a binary colormap.
    """
    spec   = img_specs[dataset]
    height = spec['img_height']
    width  = spec['img_width']

    # Convert the labels once rather than once per grid cell.
    labels = train_target.astype('int') if train_target is not None else None

    fig = plt.figure(figsize = fig_dims)
    fig.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1, hspace = 0.05, wspace = 0.05)

    # At most 36 cells; fewer samples draw fewer cells instead of raising.
    for i in range(min(36, len(train_data))):
        digit = fig.add_subplot(6, 6, i + 1)
        digit.grid(which = 'major', linestyle = ':', linewidth = 0.5, color = 'blue')
        digit.grid(which = 'minor', linestyle = ':', linewidth = 0.5, color = 'blue')
        digit.xaxis.set_ticklabels([])
        digit.yaxis.set_ticklabels([])
        digit.minorticks_on()

        if channels == 3:
            # Size comes from img_specs like the sibling helpers (was a
            # hard-coded 32x32).
            color_img = train_data[i].reshape(channels, height, width).transpose([1, 2, 0])
            digit.imshow(color_img, interpolation = 'nearest')
        else:
            digit.imshow(train_data[i].reshape(height, width), cmap = plt.cm.binary, interpolation = 'nearest')

        if labels is not None:
            digit.text(spec['pix_row'], spec['pix_col'], str(labels[i]))

    plt.show()
def plot_img_results(test_data, test_label, predictions, fig_dims = (6, 6), dataset = 'digits', channels = 1):
    """Show up to 36 test images labelled with their predictions.

    Each prediction is drawn in green when it equals the matching entry of
    ``test_label`` and in red otherwise.

    Parameters
    ----------
    test_data : numpy array
        Colour mode (``channels == 3``): indexable samples, each reshaped to
        ``(channels, img_height, img_width)`` and transposed to HWC.
        Otherwise: any array reshapeable to ``(-1, img_height, img_width)``.
    test_label : sequence
        Ground-truth labels, indexed in step with ``predictions``.
    predictions : sequence
        Predicted labels, drawn as text on each image.
    fig_dims : tuple
        Figure size in inches.
    dataset : str
        Key into ``img_specs`` selecting image size and label position.
    channels : int
        3 for colour images; anything else is drawn with a binary colormap.
    """
    spec   = img_specs[dataset]
    height = spec['img_height']
    width  = spec['img_width']

    # Reshape the grayscale stack once rather than once per grid cell.
    images = test_data if channels == 3 else test_data.reshape(-1, height, width)

    fig = plt.figure(figsize = fig_dims)
    fig.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1, hspace = 0.05, wspace = 0.05)

    # At most 36 cells; fewer samples draw fewer cells instead of raising.
    for i in range(min(36, len(images))):
        digit = fig.add_subplot(6, 6, i + 1, xticks = [], yticks = [])
        if channels == 3:
            color_img = images[i].reshape(channels, height, width).transpose([1, 2, 0])
            digit.imshow(color_img, interpolation = 'nearest')
        else:
            digit.imshow(images[i], cmap = plt.cm.binary, interpolation = 'nearest')

        colour = 'green' if predictions[i] == test_label[i] else 'red'
        digit.text(spec['pix_row'], spec['pix_col'], str(predictions[i]), color = colour)

    plt.show()
def plot_generated_img_samples(test_label, predictions, fig_dims = (6, 6), dataset = 'digits', channels = 1, to_save = False, iteration = 0, model_name = ''):
    """Show up to 36 generated images, optionally saving the figure.

    Parameters
    ----------
    test_label : sequence, optional
        Labels drawn in blue on each image; skipped when ``None``.
    predictions : numpy array
        Generated samples. Colour mode (``channels == 3``): each sample is
        reshaped to ``(channels, img_height, img_width)`` and transposed to
        HWC. Otherwise the array is reshaped to ``(-1, img_height, img_width)``.
    fig_dims : tuple
        Figure size in inches.
    dataset : str
        Key into ``img_specs`` selecting image size and label position.
    channels : int
        3 for colour images; anything else is drawn with a binary colormap.
    to_save : bool
        Add a "Generator Epoch: <iteration>" super-title, save to
        ``../plots/generated/<model_name>_<iteration>.png`` and show the
        figure without blocking.
    iteration : int
        Used in the super-title and file name.
    model_name : str
        Used in the file name.
    """
    spec   = img_specs[dataset]
    height = spec['img_height']
    width  = spec['img_width']

    # Reshape the grayscale stack once rather than once per grid cell.
    images = predictions if channels == 3 else predictions.reshape(-1, height, width)

    fig = plt.figure(figsize = fig_dims)
    fig.subplots_adjust(left = 0, right = 1, bottom = 0, top = 1, hspace = 0.05, wspace = 0.05)

    # At most 36 cells; fewer samples draw fewer cells instead of raising.
    for i in range(min(36, len(images))):
        digit = fig.add_subplot(6, 6, i + 1, xticks = [], yticks = [])
        if channels == 3:
            color_img = images[i].reshape(channels, height, width).transpose([1, 2, 0])
            digit.imshow(color_img, interpolation = 'nearest')
        else:
            digit.imshow(images[i], cmap = plt.cm.binary, interpolation = 'nearest')

        if test_label is not None:
            digit.text(spec['pix_row'], spec['pix_col'], str(test_label[i]), color = 'blue')

    if to_save:
        plt.suptitle('Generator Epoch: {}'.format(iteration), y = 1.05, fontsize = 12).set_color('blue')
        # One placeholder per field; the '.png' extension is actually emitted.
        plt.savefig('../plots/generated/{}_{}.png'.format(model_name, iteration))

    # Non-blocking when saving (presumably so a per-iteration caller, such as a
    # training loop, is not paused); blocking otherwise.
    plt.show(block = not to_save)
def plot_regression_results(train_data, train_label, test_data, test_label, input_data, pred_line, mse, super_title, y_label, x_label, model_name = '', to_save = False, fig_dims = FIG_SIZE, font_size = 10):
    """Scatter train/test points, overlay the model's predictions, and show.

    Parameters
    ----------
    train_data, train_label : array-like
        Training inputs and targets, drawn as small scatter points.
    test_data, test_label : array-like
        Test inputs and targets, drawn as small scatter points.
    input_data, pred_line : array-like
        Inputs and predicted outputs, drawn as green ``*`` markers.
    mse : float or None
        Shown as the axes title ("MSE: x.xx") when not ``None``; also used in
        the saved file name (``na`` when ``None``).
    super_title, y_label, x_label : str
        Figure super-title and axis labels.
    model_name : str
        Used in the saved file name.
    to_save : bool
        Save to ``../plots/metrics/<model>_mse_<mse>_<time>.png``.
    fig_dims : tuple
        Figure size in inches.
    font_size : int
        Font size of the MSE title.
    """
    plt.figure(figsize = fig_dims)
    cmap = plt.get_cmap('summer')

    train = plt.scatter(train_data, train_label, color = cmap(0.8), s = 12)
    test  = plt.scatter(test_data, test_label, color = cmap(0.4), s = 12)
    plt.plot(input_data, pred_line, marker = '*', color = 'green', markersize = 4, linestyle = 'none')

    plt.suptitle(super_title)
    if mse is not None:
        plt.title("MSE: {:4.2f}".format(mse), size = font_size)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend((train, test), ("Train", "Test"), loc = 'upper left')

    if to_save:
        # '{:4.2f}' raises TypeError on None; mse=None is allowed above.
        mse_tag = '{:4.2f}'.format(mse) if mse is not None else 'na'
        plt.savefig('../plots/metrics/{}_mse_{}_{}.png'.format(
            model_name, mse_tag, time.strftime("%Y-%m-%d_%H-%M-%S")))

    plt.show()