#!/usr/bin/env python3
""" Graph functions for Display Frame area of the Faceswap GUI """
import datetime
import logging
import os
import tkinter as tk

from tkinter import ttk
from typing import Union, List, Tuple
from math import ceil, floor

import numpy as np
import matplotlib
from matplotlib import style
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
                                               NavigationToolbar2Tk)
from matplotlib.backend_bases import NavigationToolbar2

from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask

matplotlib.use("TkAgg")

logger: logging.Logger = logging.getLogger(__name__)


class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
    """ Base class for matplotlib line graphs.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The parent frame that holds the graph
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics class that holds the data to be displayed
    ylabel: str
        The data label for the y-axis
    """
    def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
        logger.debug("Initializing %s", self.__class__.__name__)
        super().__init__(parent)
        style.use("ggplot")

        self._calcs = data
        self._ylabel = ylabel
        # One colour map is assigned to each position within a group of lines
        self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
                            "summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
        self._lines = []
        self._toolbar = None
        self._fig = Figure(figsize=(4, 4), dpi=75)

        self._ax1 = self._fig.add_subplot(1, 1, 1)
        self._plotcanvas = FigureCanvasTkAgg(self._fig, self)

        self._initiate_graph()
        self._update_plot(initiate=True)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def calcs(self):
        """ :class:`lib.gui.analysis.stats.Calculations`. The calculated statistics associated with
        this graph. """
        return self._calcs

    def _initiate_graph(self) -> None:
        """ Place the graph canvas """
        logger.debug("Setting plotcanvas")
        self._plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True)
        self._fig.subplots_adjust(left=0.100,
                                  bottom=0.100,
                                  right=0.95,
                                  top=0.95,
                                  wspace=0.2,
                                  hspace=0.2)
        logger.debug("Set plotcanvas")

    def _update_plot(self, initiate: bool = True) -> None:
        """ Update the plot with incoming data

        Parameters
        ----------
        initiate: bool, Optional
            Whether the graph should be initialized for the first time (``True``) or data is being
            updated for an existing graph (``False``). Default: ``True``
        """
        logger.trace("Updating plot")
        if initiate:
            logger.debug("Initializing plot")
            self._lines = []
            self._ax1.clear()
            self._axes_labels_set()
            logger.debug("Initialized plot")

        fulldata = list(self._calcs.stats.values())
        self._axes_limits_set(fulldata)

        if self._calcs.start_iteration > 0:
            end_iteration = self._calcs.start_iteration + self._calcs.iterations
            xrng = list(range(self._calcs.start_iteration, end_iteration))
        else:
            xrng = list(range(self._calcs.iterations))

        keys = list(self._calcs.stats.keys())

        for idx, item in enumerate(self._lines_sort(keys)):
            # item is [key, label, linewidth, colour]
            if initiate:
                self._lines.extend(self._ax1.plot(xrng, self._calcs.stats[item[0]],
                                                  label=item[1],
                                                  linewidth=item[2],
                                                  color=item[3]))
            else:
                # Existing line objects are re-used and just receive the new data
                self._lines[idx].set_data(xrng, self._calcs.stats[item[0]])

        if initiate:
            self._legend_place()
        logger.trace("Updated plot")

    def _axes_labels_set(self) -> None:
        """ Set the X and Y axes labels. """
        logger.debug("Setting axes labels. y-label: '%s'", self._ylabel)
        self._ax1.set_xlabel("Iterations")
        self._ax1.set_ylabel(self._ylabel)

    def _axes_limits_set_default(self) -> None:
        """ Set the default axes limits for the X and Y axes. """
        logger.debug("Setting default axes ranges")
        self._ax1.set_ylim(0.00, 100.0)
        self._ax1.set_xlim(0, 1)

    def _axes_limits_set(self, data: List[float]) -> None:
        """ Set the axes limits.

        The default limits are applied when no data is provided.

        Parameters
        ----------
        data: list
            The data points for the Y Axis
        """
        xmin = self._calcs.start_iteration
        if self._calcs.start_iteration > 0:
            xmax = self._calcs.iterations + self._calcs.start_iteration
        else:
            xmax = self._calcs.iterations
        xmax = max(1, xmax - 1)

        if data:
            ymin, ymax = self._axes_data_get_min_max(data)
            self._ax1.set_ylim(ymin, ymax)
            self._ax1.set_xlim(xmin, xmax)
            logger.trace("axes ranges: (y: (%s, %s), x: (%s, %s))", ymin, ymax, xmin, xmax)
        else:
            self._axes_limits_set_default()

    @staticmethod
    def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
        """ Obtain the minimum and maximum values for the y-axis from the given data points.

        The minimum is rounded down and the maximum is rounded up to 3 decimal places. NaN values
        are ignored when finding the minimum and maximum of each item.

        Parameters
        ----------
        data: list
            The data points for the Y Axis

        Returns
        -------
        tuple
            The minimum and maximum values for the y axis
        """
        ymins, ymaxs = [], []

        for item in data:  # TODO Handle as array not loop
            # Scale by 1000 so that floor/ceil round at the 3rd decimal place
            ymins.append(np.nanmin(item) * 1000)
            ymaxs.append(np.nanmax(item) * 1000)
        ymin = floor(min(ymins)) / 1000
        ymax = ceil(max(ymaxs)) / 1000
        logger.trace("ymin: %s, ymax: %s", ymin, ymax)
        return ymin, ymax

    def _axes_set_yscale(self, scale: str) -> None:
        """ Set the Y-Scale to log or linear

        Parameters
        ----------
        scale: str
            Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("yscale: '%s'", scale)
        self._ax1.set_yscale(scale)

    def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float, ...]]]]:
        """ Sort the data keys into consistent order and set line color map and line width.

        Keys starting with "raw" are placed ahead of all other keys.

        Parameters
        ----------
        keys: list
            The list of data point keys

        Returns
        -------
        list
            A list of loss keys with their corresponding line formatting and color information
        """
        logger.trace("Sorting lines")
        raw_lines = []
        sorted_lines = []
        for key in sorted(keys):
            title = key.replace("_", " ").title()
            if key.startswith("raw"):
                raw_lines.append([key, title])
            else:
                sorted_lines.append([key, title])

        groupsize = self._lines_groupsize(raw_lines, sorted_lines)
        sorted_lines = raw_lines + sorted_lines
        lines = self._lines_style(sorted_lines, groupsize)
        return lines

    @staticmethod
    def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
        """ Get the number of items in each group.

        If raw data isn't selected, then check the length of remaining groups until something is
        found.

        Parameters
        ----------
        raw_lines: list
            The list of ``[key, title]`` items for the raw data points
        sorted_lines:
            The list of ``[key, title]`` items for the remaining lines to display on the graph

        Returns
        -------
        int
            The size of each group that exist within the data set.
        """
        groupsize = 1
        if raw_lines:
            groupsize = len(raw_lines)
        elif sorted_lines:
            # The group name is the text before the first underscore. Split rather than slice on
            # str.find, which returns -1 (dropping the final character) when there is no underscore
            keys = [key[0].split("_", 1)[0] for key in sorted_lines]
            distinct_keys = set(keys)
            groupsize = len(keys) // len(distinct_keys)
        logger.trace(groupsize)
        return groupsize

    def _lines_style(self,
                     lines: List[List[str]],
                     groupsize: int) -> List[List[Union[str, int, Tuple[float, ...]]]]:
        """ Obtain the color map and line width for each group.

        Each item in :attr:`lines` is extended in place with its line width and color.

        Parameters
        ----------
        lines: list
            The list of sorted line keys to display on the graph
        groupsize: int
            The size of each group to display in the graph

        Returns
        -------
        list
            A list of loss keys with their corresponding line formatting and color information
        """
        logger.trace("Setting lines style")
        # Round up so that a final, partially filled group still receives a width and colours
        groups = ceil(len(lines) / groupsize)
        colours = self._lines_create_colors(groupsize, groups)
        for idx, item in enumerate(lines):
            linewidth = (idx // groupsize) + 1  # Line width increases by 1 for each group
            item.extend((linewidth, colours[idx]))
        return lines

    @staticmethod
    def _get_colourmap(name: str):
        """ Obtain a matplotlib color map by name.

        The ``matplotlib.colormaps`` registry is used when the installed matplotlib provides it,
        otherwise falls back to ``matplotlib.cm.get_cmap``.

        Parameters
        ----------
        name: str
            The name of the color map to obtain

        Returns
        -------
        :class:`matplotlib.colors.Colormap`
            The requested color map
        """
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None:
            return registry[name]
        return matplotlib.cm.get_cmap(name)

    def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float, ...]]:
        """ Create the color maps.

        Parameters
        ----------
        groupsize: int
            The size of each group to display in the graph
        groups: int
            The total number of groups to graph

        Returns
        -------
        list
            The colour for each line, ordered group by group. Contains ``groups * groupsize``
            items
        """
        # Cycle through the available colour maps if a group holds more lines than there are maps
        num_maps = len(self._colourmaps)
        cmaps = [self._get_colourmap(self._colourmaps[idx % num_maps])
                 for idx in range(groupsize)]
        colours = []
        for i in range(1, groups + 1):
            cpoint = 1 - (i / 5)  # Sample further down the colour map for each successive group
            colours.extend(cmap(cpoint) for cmap in cmaps)
        logger.trace(colours)
        return colours

    def _legend_place(self) -> None:
        """ Place and format the graph legend """
        logger.debug("Placing legend")
        self._ax1.legend(loc="upper right", ncol=2)

    def _toolbar_place(self, parent: ttk.Frame) -> None:
        """ Add Graph Navigation toolbar.

        Parameters
        ----------
        parent: ttk.Frame
            The parent graph frame to place the toolbar onto
        """
        logger.debug("Placing toolbar")
        self._toolbar = NavigationToolbar(self._plotcanvas, parent)
        self._toolbar.pack(side=tk.BOTTOM)
        self._toolbar.update()

    def clear(self) -> None:
        """ Clear the graph plots from RAM """
        logger.debug("Clearing graph from RAM: %s", self)
        self._fig.clf()
        del self._fig


class TrainingGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ Live graph to be displayed during training.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The parent frame that holds the graph
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics class that holds the data to be displayed
    ylabel: str
        The data label for the y-axis
    """

    def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
        super().__init__(parent, data, ylabel)
        self._thread = None  # Thread for LongRunningTask
        self._displayed_keys = []
        self._add_callback()

    def _add_callback(self) -> None:
        """ Add the variable trace to update graph on refresh button press or save iteration. """
        get_config().tk_vars["refreshgraph"].trace("w", self.refresh)

    def build(self) -> None:
        """ Build the Training graph. """
        logger.debug("Building training graph")
        self._plotcanvas.draw()
        logger.debug("Built training graph")

    def refresh(self, *args) -> None:  # pylint: disable=unused-argument
        """ Read the latest loss data and apply to current graph """
        refresh_var = get_config().tk_vars["refreshgraph"]
        if not refresh_var.get() and self._thread is None:
            return

        if self._thread is None:
            logger.debug("Updating plot data")
            self._thread = LongRunningTask(target=self._calcs.refresh)
            self._thread.start()
            self.after(1000, self.refresh)  # Poll again in 1 second for the result
        elif not self._thread.complete.is_set():
            logger.debug("Graph Data not yet available")
            self.after(1000, self.refresh)
        else:
            logger.debug("Updating plot with data from background thread")
            self._calcs = self._thread.get_result()  # Terminate the LongRunningTask object
            self._thread = None

            dsp_keys = sorted(self._calcs.stats)
            if dsp_keys != self._displayed_keys:
                # The plotted lines no longer match the data, so the graph must be rebuilt
                logger.debug("Reinitializing graph for keys change. Old keys: %s New keys: %s",
                             self._displayed_keys, dsp_keys)
                initiate = True
                self._displayed_keys = dsp_keys
            else:
                initiate = False

            self._update_plot(initiate=initiate)
            self._plotcanvas.draw()
            refresh_var.set(False)

    def save_fig(self, location: str) -> None:
        """ Save the current graph to file

        Parameters
        ----------
        location: str
            The full path to the folder where the current graph should be saved
        """
        logger.debug("Saving graph: '%s'", location)
        keys = sorted(key.replace("raw_", "")
                      for key in self._calcs.stats
                      if key.startswith("raw_"))
        filename = " - ".join(keys)
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(location, f"{filename}_{now}.png")
        self._fig.set_size_inches(16, 9)
        self._fig.savefig(filename, bbox_inches="tight", dpi=120)
        print(f"Saved graph to {filename}")
        logger.debug("Saved graph: '%s'", filename)
        # The figure size was changed for saving, so restore it to the displayed size
        self._resize_fig()

    def _resize_fig(self) -> None:
        """ Resize the figure to the current canvas size. """
        class Event():  # pylint: disable=too-few-public-methods
            """ Event class that needs to be passed to plotcanvas.resize """
            pass  # pylint: disable=unnecessary-pass
        Event.width = self.winfo_width()
        Event.height = self.winfo_height()
        self._plotcanvas.resize(Event)  # pylint: disable=no-value-for-parameter


class SessionGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ Session Graph for session pop-up.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The parent frame that holds the graph
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics class that holds the data to be displayed
    ylabel: str
        The data label for the y-axis
    scale: str
        Should be one of ``"log"`` or ``"linear"``
    """
    def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None:
        super().__init__(parent, data, ylabel)
        self._scale = scale

    def build(self):
        """ Build the session graph """
        logger.debug("Building session graph")
        self._toolbar_place(self)
        self._plotcanvas.draw()
        logger.debug("Built session graph")

    def refresh(self, data, ylabel: str, scale: str) -> None:
        """ Refresh the Session Graph's data.

        Parameters
        ----------
        data: :class:`lib.gui.analysis.stats.Calculations`
            The statistics class that holds the data to be displayed
        ylabel: str
            The data label for the y-axis
        scale: str
            Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale)
        self._calcs = data
        self._ylabel = ylabel
        self.set_yscale_type(scale)
        logger.debug("Refreshed session graph")

    def set_yscale_type(self, scale: str) -> None:
        """ Set the scale type for the y-axis and redraw.

        Parameters
        ----------
        scale: str
            Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("Updating scale type: '%s'", scale)
        self._scale = scale
        self._update_plot(initiate=True)
        self._axes_set_yscale(self._scale)
        self._plotcanvas.draw()
        logger.debug("Updated scale type")


class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ancestors
    """ Overrides the default Navigation Toolbar to provide only the buttons we require
    and to layout the items in a consistent manner with the rest of the GUI for the Analysis
    Session Graph pop up Window.

    Parameters
    ----------
    canvas: :class:`matplotlib.backends.backend_tkagg.FigureCanvasTkAgg`
        The canvas that holds the displayed graph and will hold the toolbar
    window: :class:`~lib.gui.display_graph.SessionGraph`
        The Session Graph canvas
    pack_toolbar: bool, Optional
        Whether to pack the Tool bar or not. Default: ``True``
    """
    toolitems = [t for t in NavigationToolbar2Tk.toolitems if
                 t[0] in ("Home", "Pan", "Zoom", "Save")]

    def __init__(self,  # pylint: disable=super-init-not-called
                 canvas: FigureCanvasTkAgg,
                 window: SessionGraph,
                 *,
                 pack_toolbar: bool = True) -> None:

        # Avoid using self.window (prefer self.canvas.get_tk_widget().master),
        # so that Tool implementations can reuse the methods.

        ttk.Frame.__init__(self,  # pylint:disable=non-parent-init-called
                           master=window,
                           width=int(canvas.figure.bbox.width),
                           height=50)

        sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
        sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP)

        btnframe = ttk.Frame(self)  # Add a button frame to consistently line up GUI
        btnframe.pack(fill=tk.X, padx=5, pady=5, side=tk.RIGHT)

        self._buttons = {}
        for text, tooltip_text, image_file, callback in self.toolitems:
            self._buttons[text] = button = self._Button(
                btnframe,
                text,
                image_file,
                toggle=callback in ["zoom", "pan"],
                command=getattr(self, callback),
            )
            if tooltip_text is not None:
                Tooltip(button, text=tooltip_text, wrap_length=200)

        self.message = tk.StringVar(master=self)
        self._message_label = ttk.Label(master=self, textvariable=self.message)
        self._message_label.pack(side=tk.LEFT, padx=5)  # Additional left padding

        NavigationToolbar2.__init__(self, canvas)  # pylint:disable=non-parent-init-called
        if pack_toolbar:
            self.pack(side=tk.BOTTOM, fill=tk.X)

    @staticmethod
    def _Button(frame: ttk.Frame,  # pylint:disable=arguments-differ
                text: str,
                image_file: str,
                toggle: bool,
                command) -> Union[ttk.Button, ttk.Checkbutton]:
        """ Override the default button method to use our icons and ttk widgets for
        consistent GUI layout.

        Parameters
        ----------
        frame: :class:`tkinter.ttk.Frame`
            The frame that holds the buttons
        text: str
            The display text for the button
        image_file: str
            The name of the image file to use
        toggle: bool
            Whether to use a checkbutton (``True``) or a regular button (``False``)
        command: method
            The Navigation Toolbar callback method

        Returns
        -------
        :class:`tkinter.ttk.Button` or :class:`tkinter.ttk.Checkbutton`
            The widget to use. A button if the option is not toggleable, a checkbutton if the
            option is toggleable.
        """
        iconmapping = {"home": "reload",
                       "filesave": "save",
                       "zoom_to_rect": "zoom"}
        # Image names without a mapping are used as the icon name unchanged
        icon = iconmapping.get(image_file, image_file)
        img = get_images().icons[icon]

        if not toggle:
            btn = ttk.Button(frame, text=text, image=img, command=command)
        else:
            var = tk.IntVar(master=frame)
            btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)

            # Original implementation uses tk Checkbuttons which have a select and deselect
            # method. These aren't available in ttk Checkbuttons, so we monkey patch the methods
            # to update the underlying variable.
            setattr(btn, "select", lambda i=1: var.set(i))
            setattr(btn, "deselect", lambda i=0: var.set(i))

        btn.pack(side=tk.RIGHT, padx=2)
        return btn