1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00
faceswap/lib/gui/display_graph.py
2022-10-17 18:14:04 +01:00

581 lines
21 KiB
Python
Executable file

#!/usr/bin/env python3
""" Graph functions for Display Frame area of the Faceswap GUI """
import datetime
import logging
import os
import tkinter as tk
from tkinter import ttk
from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
from math import ceil, floor
import numpy as np
import matplotlib
from matplotlib import style
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
NavigationToolbar2Tk)
from matplotlib.backend_bases import NavigationToolbar2
from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask
if TYPE_CHECKING:
from matplotlib.lines import Line2D
# Select the Tk-based backend; the canvas and toolbar classes imported above are its TkAgg ones
matplotlib.use("TkAgg")
# Module-level logger named after this module.
# NOTE(review): code in this file also calls ``logger.trace``, which is not a standard
# ``logging.Logger`` method - presumably registered by the project's logging setup; confirm.
logger: logging.Logger = logging.getLogger(__name__)
class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
    """ Base class for matplotlib line graphs.

    Creates a single-axes figure embedded in a Tk canvas and plots one line for each entry in
    the statistics object's ``stats`` mapping.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The parent frame that holds the graph
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics class that holds the data to be displayed
    ylabel: str
        The data label for the y-axis
    """
    def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
        logger.debug("Initializing %s", self.__class__.__name__)
        super().__init__(parent)
        style.use("ggplot")

        self._calcs = data
        self._ylabel = ylabel
        # One colour map per member of a group. Groups are told apart by shade and line width
        self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
                            "summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
        self._lines: List["Line2D"] = []
        self._toolbar: Optional["NavigationToolbar"] = None

        self._fig = Figure(figsize=(4, 4), dpi=75)
        self._ax1 = self._fig.add_subplot(1, 1, 1)
        self._plotcanvas = FigureCanvasTkAgg(self._fig, self)

        self._initiate_graph()
        self._update_plot(initiate=True)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def calcs(self):
        """ :class:`lib.gui.analysis.stats.Calculations`. The calculated statistics associated with
        this graph. """
        return self._calcs

    def _initiate_graph(self) -> None:
        """ Place the graph canvas and set the figure margins. """
        logger.debug("Setting plotcanvas")
        self._plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True)
        self._fig.subplots_adjust(left=0.100,
                                  bottom=0.100,
                                  right=0.95,
                                  top=0.95,
                                  wspace=0.2,
                                  hspace=0.2)
        logger.debug("Set plotcanvas")

    def _update_plot(self, initiate: bool = True) -> None:
        """ Update the plot with incoming data

        Parameters
        ----------
        initiate: bool, Optional
            Whether the graph should be initialized for the first time (``True``) or data is being
            updated for an existing graph (``False``). Default: ``True``
        """
        logger.trace("Updating plot")  # type:ignore
        if initiate:
            logger.debug("Initializing plot")
            self._lines = []
            self._ax1.clear()
            self._axes_labels_set()
            logger.debug("Initialized plot")

        fulldata = list(self._calcs.stats.values())
        self._axes_limits_set(fulldata)

        if self._calcs.start_iteration > 0:
            end_iteration = self._calcs.start_iteration + self._calcs.iterations
            xrng = list(range(self._calcs.start_iteration, end_iteration))
        else:
            xrng = list(range(self._calcs.iterations))

        keys = list(self._calcs.stats.keys())
        for idx, item in enumerate(self._lines_sort(keys)):
            if initiate:
                # item is [stats key, legend label, line width, colour]
                self._lines.extend(self._ax1.plot(xrng, self._calcs.stats[item[0]],
                                                  label=item[1], linewidth=item[2], color=item[3]))
            else:
                # Existing lines are updated in place, relying on the sort order being stable
                self._lines[idx].set_data(xrng, self._calcs.stats[item[0]])

        if initiate:
            self._legend_place()
        logger.trace("Updated plot")  # type:ignore

    def _axes_labels_set(self) -> None:
        """ Set the X and Y axes labels. """
        logger.debug("Setting axes labels. y-label: '%s'", self._ylabel)
        self._ax1.set_xlabel("Iterations")
        self._ax1.set_ylabel(self._ylabel)

    def _axes_limits_set_default(self) -> None:
        """ Set the default axes limits for the X and Y axes. """
        logger.debug("Setting default axes ranges")
        self._ax1.set_ylim(0.00, 100.0)
        self._ax1.set_xlim(0, 1)

    def _axes_limits_set(self, data: List[float]) -> None:
        """ Set the axes limits.

        The x-axis runs from the start iteration to the last plotted iteration. The y-axis is
        fitted to the data. When no data is given the default limits are applied to both axes.

        Parameters
        ----------
        data: list
            The data points for the Y Axis. One collection of values for each plotted line
        """
        xmin = self._calcs.start_iteration
        if self._calcs.start_iteration > 0:
            xmax = self._calcs.iterations + self._calcs.start_iteration
        else:
            xmax = self._calcs.iterations
        # Final plotted x value is one less than the end of the range. Never collapse below 1
        xmax = max(1, xmax - 1)

        if not data:
            self._axes_limits_set_default()
            return

        ymin, ymax = self._axes_data_get_min_max(data)
        self._ax1.set_ylim(ymin, ymax)
        self._ax1.set_xlim(xmin, xmax)
        logger.trace("axes ranges: (y: (%s, %s), x: (%s, %s))",  # type:ignore
                     ymin, ymax, xmin, xmax)

    @staticmethod
    def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
        """ Obtain the minimum and maximum values for the y-axis from the given data points.

        NaN values are ignored. The minimum is rounded down and the maximum rounded up to 3
        decimal places.

        Parameters
        ----------
        data: list
            The data points for the Y Axis

        Returns
        -------
        tuple
            The minimum and maximum values for the y axis
        """
        # TODO Handle as array not loop
        # Scale by 1000 so that floor/ceil round at the third decimal place
        ymins = [np.nanmin(item) * 1000 for item in data]
        ymaxs = [np.nanmax(item) * 1000 for item in data]
        ymin = floor(min(ymins)) / 1000
        ymax = ceil(max(ymaxs)) / 1000
        logger.trace("ymin: %s, ymax: %s", ymin, ymax)  # type:ignore
        return ymin, ymax

    def _axes_set_yscale(self, scale: str) -> None:
        """ Set the Y-Scale to log or linear

        Parameters
        ----------
        scale: str
            Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("yscale: '%s'", scale)
        self._ax1.set_yscale(scale)

    def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float]]]]:
        """ Sort the data keys into consistent order and set line color map and line width.

        Keys are sorted alphabetically, with any keys starting with "raw" placed first.

        Parameters
        ----------
        keys: list
            The list of data point keys

        Returns
        -------
        list
            A list of loss keys with their corresponding line formatting and color information.
            Each entry is ``[key, legend label, line width, colour]``
        """
        logger.trace("Sorting lines")  # type:ignore
        raw_lines: List[List[str]] = []
        sorted_lines: List[List[str]] = []
        for key in sorted(keys):
            title = key.replace("_", " ").title()
            if key.startswith("raw"):
                raw_lines.append([key, title])
            else:
                sorted_lines.append([key, title])

        groupsize = self._lines_groupsize(raw_lines, sorted_lines)
        sorted_lines = raw_lines + sorted_lines
        lines = self._lines_style(sorted_lines, groupsize)
        return lines

    @staticmethod
    def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
        """ Get the number of items in each group.

        If raw data isn't selected, then the group size is derived from the remaining keys by
        dividing the number of keys by the number of distinct key prefixes (the text before the
        first underscore).

        Parameters
        ----------
        raw_lines: list
            The list of keys for the raw data points
        sorted_lines:
            The list of sorted line keys to display on the graph

        Returns
        -------
        int
            The size of each group that exist within the data set.
        """
        groupsize = 1
        if raw_lines:
            groupsize = len(raw_lines)
        elif sorted_lines:
            # split leaves a key with no underscore intact. Slicing at ``find``'s -1 would have
            # chopped off its final character
            keys = [key[0].split("_", 1)[0] for key in sorted_lines]
            distinct_keys = set(keys)
            groupsize = len(keys) // len(distinct_keys)
        logger.trace(groupsize)  # type:ignore
        return groupsize

    def _lines_style(self,
                     lines: List[List[str]],
                     groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
        """ Obtain the color map and line width for each group.

        The line width and colour are appended, in place, to each entry in :attr:`lines`.

        Parameters
        ----------
        lines: list
            The list of sorted line keys to display on the graph
        groupsize: int
            The size of each group to display in the graph

        Returns
        -------
        list
            A list of loss keys with their corresponding line formatting and color information
        """
        logger.trace("Setting lines style")  # type:ignore
        # Round up so that a trailing, partially filled group still receives a width and colours
        groups = ceil(len(lines) / groupsize)
        colours = self._lines_create_colors(groupsize, groups)
        widths = list(range(1, groups + 1))
        retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
        for idx, item in enumerate(retval):
            linewidth = widths[idx // groupsize]
            item.extend((linewidth, colours[idx]))
        return retval

    def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
        """ Create the color maps.

        Parameters
        ----------
        groupsize: int
            The size of each group to display in the graph
        groups: int
            The total number of groups to graph

        Returns
        -------
        list
            The colour for each line, ordered by group and then by position within the group
        """
        # Wrap around the available colour maps if a group holds more lines than there are maps
        names = [self._colourmaps[idx % len(self._colourmaps)] for idx in range(groupsize)]
        # ``matplotlib.cm.get_cmap`` is deprecated/removed in newer matplotlib releases. Prefer
        # the colormap registry when it exists and fall back for older releases
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None:
            cmaps = [registry[name] for name in names]
        else:
            cmaps = [matplotlib.cm.get_cmap(name) for name in names]

        colours = []
        for i in range(1, groups + 1):
            # Each successive group samples a lower point on the colour maps
            cpoint = 1 - (i / 5)
            colours.extend(cmap(cpoint) for cmap in cmaps)
        logger.trace(colours)  # type:ignore
        return colours

    def _legend_place(self) -> None:
        """ Place and format the graph legend """
        logger.debug("Placing legend")
        self._ax1.legend(loc="upper right", ncol=2)

    def _toolbar_place(self, parent: ttk.Frame) -> None:
        """ Add Graph Navigation toolbar.

        Parameters
        ----------
        parent: ttk.Frame
            The parent graph frame to place the toolbar onto
        """
        logger.debug("Placing toolbar")
        self._toolbar = NavigationToolbar(self._plotcanvas, parent)
        self._toolbar.pack(side=tk.BOTTOM)
        self._toolbar.update()

    def clear(self) -> None:
        """ Clear the graph plots from RAM.

        The figure is deleted, so the graph can no longer be updated after this call. """
        logger.debug("Clearing graph from RAM: %s", self)
        self._fig.clf()
        del self._fig
class TrainingGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ Live graph to be displayed during training.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The parent frame that holds the graph
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics class that holds the data to be displayed
    ylabel: str
        The data label for the y-axis
    """
    def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
        super().__init__(parent, data, ylabel)
        self._thread: Optional[LongRunningTask] = None  # Thread for LongRunningTask
        # Keys currently plotted. A change in keys forces the graph to be rebuilt on refresh
        self._displayed_keys: List[str] = []
        self._add_callback()

    def _add_callback(self) -> None:
        """ Add the variable trace to update graph on refresh button press or save iteration. """
        get_config().tk_vars.refresh_graph.trace("w", self.refresh)  # type:ignore

    def build(self) -> None:
        """ Build the Training graph. """
        logger.debug("Building training graph")
        self._plotcanvas.draw()
        logger.debug("Built training graph")

    def refresh(self, *args) -> None:  # pylint: disable=unused-argument
        """ Read the latest loss data and apply to current graph.

        The data is refreshed in a background task. Whilst that task is running this method
        re-schedules itself every second until the result is available, then redraws the graph
        and resets the refresh variable.

        Parameters
        ----------
        args: tuple
            Unused. Accepts the arguments passed in by the tkinter variable trace
        """
        refresh_var = cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
        if not refresh_var.get() and self._thread is None:
            return

        if self._thread is None:
            logger.debug("Updating plot data")
            self._thread = LongRunningTask(target=self._calcs.refresh)
            self._thread.start()
            self.after(1000, self.refresh)
        elif not self._thread.complete.is_set():
            logger.debug("Graph Data not yet available")
            self.after(1000, self.refresh)
        else:
            logger.debug("Updating plot with data from background thread")
            self._calcs = self._thread.get_result()  # Terminate the LongRunningTask object
            self._thread = None

            dsp_keys = list(sorted(self._calcs.stats))
            if dsp_keys != self._displayed_keys:
                logger.debug("Reinitializing graph for keys change. Old keys: %s New keys: %s",
                             self._displayed_keys, dsp_keys)
                initiate = True
                self._displayed_keys = dsp_keys
            else:
                initiate = False

            self._update_plot(initiate=initiate)
            self._plotcanvas.draw()
            refresh_var.set(False)

    def save_fig(self, location: str) -> None:
        """ Save the current graph to file

        The file name is built from the names of the raw data keys and the current date and time.

        Parameters
        ----------
        location: str
            The full path to the folder where the current graph should be saved
        """
        logger.debug("Saving graph: '%s'", location)
        prefix = "raw_"
        # Only strip the leading "raw_". A plain replace would also remove any later occurrence
        keys = sorted(key[len(prefix):] for key in self._calcs.stats.keys()
                      if key.startswith(prefix))
        filename = " - ".join(keys)
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(location, f"{filename}_{now}.png")

        self._fig.set_size_inches(16, 9)
        try:
            self._fig.savefig(filename, bbox_inches="tight", dpi=120)
        finally:
            # Return the on-screen figure to the canvas size even when the save fails
            self._resize_fig()
        print(f"Saved graph to {filename}")
        logger.debug("Saved graph: '%s'", filename)

    def _resize_fig(self) -> None:
        """ Resize the figure to the current canvas size. """
        class Event():  # pylint: disable=too-few-public-methods
            """ Event class that needs to be passed to plotcanvas.resize """
            def __init__(self, width: int, height: int) -> None:
                self.width = width
                self.height = height

        self._plotcanvas.resize(Event(self.winfo_width(), self.winfo_height()))
class SessionGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ The graph that is displayed inside the session pop-up.

    Parameters
    ----------
    parent: :class:`tkinter.ttk.Frame`
        The frame that the graph is placed into
    data: :class:`lib.gui.analysis.stats.Calculations`
        The statistics object holding the data to plot
    ylabel: str
        The label to display on the y-axis
    scale: str
        The y-axis scale. Should be one of ``"log"`` or ``"linear"``
    """
    def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None:
        super().__init__(parent, data, ylabel)
        self._scale = scale

    def build(self) -> None:
        """ Add the navigation toolbar to this frame and draw the graph. """
        logger.debug("Building session graph")
        self._toolbar_place(self)
        self._plotcanvas.draw()
        logger.debug("Built session graph")

    def refresh(self, data, ylabel: str, scale: str) -> None:
        """ Replace the data shown in the Session Graph and redraw it.

        Parameters
        ----------
        data: :class:`lib.gui.analysis.stats.Calculations`
            The statistics object holding the new data to plot
        ylabel: str
            The label to display on the y-axis
        scale: str
            The y-axis scale. Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale)
        self._calcs, self._ylabel = data, ylabel
        # Setting the scale rebuilds the plot from the newly stored data and redraws
        self.set_yscale_type(scale)
        logger.debug("Refreshed session graph")

    def set_yscale_type(self, scale: str) -> None:
        """ Store the requested y-axis scale, rebuild the plot with it and redraw.

        Parameters
        ----------
        scale: str
            The y-axis scale. Should be one of ``"log"`` or ``"linear"``
        """
        logger.debug("Updating scale type: '%s'", scale)
        self._scale = scale
        # The plot is rebuilt first as re-initializing clears the axes
        self._update_plot(initiate=True)
        self._axes_set_yscale(scale)
        self._plotcanvas.draw()
        logger.debug("Updated scale type")
class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ancestors
    """ Replacement for the default Navigation Toolbar that exposes only the buttons we need and
    lays the items out consistently with the rest of the GUI for the Analysis Session Graph
    pop up Window.

    Parameters
    ----------
    canvas: :class:`matplotlib.backends.backend_tkagg.FigureCanvasTkAgg`
        The canvas that holds the displayed graph and will hold the toolbar
    window: :class:`~lib.gui.display_graph.SessionGraph`
        The Session Graph canvas
    pack_toolbar: bool, Optional
        Whether to pack the Tool bar or not. Default: ``True``
    """
    toolitems = [item for item in NavigationToolbar2Tk.toolitems
                 if item[0] in ("Home", "Pan", "Zoom", "Save")]

    def __init__(self,  # pylint: disable=super-init-not-called
                 canvas: FigureCanvasTkAgg,
                 window: ttk.Frame,
                 *,
                 pack_toolbar: bool = True) -> None:
        # Avoid using self.window (prefer self.canvas.get_tk_widget().master),
        # so that Tool implementations can reuse the methods.
        ttk.Frame.__init__(self,  # pylint:disable=non-parent-init-called
                           master=window,
                           width=int(canvas.figure.bbox.width),
                           height=50)

        # Thin ridge separating the toolbar from the graph above it
        separator = ttk.Frame(self, height=2, relief=tk.RIDGE)
        separator.pack(fill=tk.X, pady=(5, 0), side=tk.TOP)

        # Dedicated frame for the buttons so that they line up consistently within the GUI
        button_frame = ttk.Frame(self)
        button_frame.pack(fill=tk.X, padx=5, pady=5, side=tk.RIGHT)

        self._buttons = {}
        for text, tooltip_text, image_file, callback in self.toolitems:
            button = self._Button(button_frame,
                                  text,
                                  image_file,
                                  toggle=callback in ("zoom", "pan"),
                                  command=getattr(self, callback))
            self._buttons[text] = button
            if tooltip_text is not None:
                Tooltip(button, text=tooltip_text, wrap_length=200)

        self.message = tk.StringVar(master=self)
        self._message_label = ttk.Label(master=self, textvariable=self.message)
        self._message_label.pack(side=tk.LEFT, padx=5)  # Additional left padding

        NavigationToolbar2.__init__(self, canvas)  # pylint:disable=non-parent-init-called
        if pack_toolbar:
            self.pack(side=tk.BOTTOM, fill=tk.X)

    @staticmethod
    def _Button(frame: ttk.Frame,  # pylint:disable=arguments-differ,arguments-renamed
                text: str,
                image_file: str,
                toggle: bool,
                command) -> Union[ttk.Button, ttk.Checkbutton]:
        """ Replacement for the default button method, so that our own icons and ttk widgets are
        used for a consistent GUI layout.

        Parameters
        ----------
        frame: :class:`tkinter.ttk.Frame`
            The frame that holds the buttons
        text: str
            The display text for the button
        image_file: str
            The name of the image file to use
        toggle: bool
            Whether to use a checkbutton (``True``) or a regular button (``False``)
        command: method
            The Navigation Toolbar callback method

        Returns
        -------
        :class:`tkinter.ttk.Button` or :class:`tkinter.ttk.Checkbutton`
            The widget to use. A button if the option is not toggleable, a checkbutton if the
            option is toggleable.
        """
        # Translate the default image names to our icon names. Unmapped names are used as is
        iconmapping = {"home": "reload",
                       "filesave": "save",
                       "zoom_to_rect": "zoom"}
        icon = iconmapping.get(image_file, image_file)
        img = get_images().icons[icon]

        btn: Union[ttk.Button, ttk.Checkbutton]
        if toggle:
            var = tk.IntVar(master=frame)
            btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)
            # Original implementation uses tk Checkbuttons which have a select and deselect
            # method. These aren't available in ttk Checkbuttons, so we monkey patch the methods
            # to update the underlying variable.
            setattr(btn, "select", lambda i=1: var.set(i))
            setattr(btn, "deselect", lambda i=0: var.set(i))
        else:
            btn = ttk.Button(frame, text=text, image=img, command=command)

        btn.pack(side=tk.RIGHT, padx=2)
        return btn