mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 03:26:47 -04:00
faceswap/lib/gui/display_graph.py
torzdf d8557c1970
Faceswap 2.0 (#1045)
* Core Updates
    - Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
    - Document lib.gpu_stats and lib.sys_info
    - Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
    - lib.gui.menu - typofix

* Update Dependencies
Bump Tensorflow Version Check

* Port extraction to tf2

* Add custom import finder for loading Keras or tf.keras depending on backend

* Add `tensorflow` to KerasFinder search path

* Basic TF2 training running

* model.initializers - docstring fix

* Fix and pass tests for tf2

* Replace Keras backend tests with faceswap backend tests

* Initial optimizers update

* Monkey patch tf.keras optimizer

* Remove custom Adam Optimizers and Memory Saving Gradients

* Remove multi-gpu option. Add Distribution to cli

* plugins.train.model._base: Add Mirror, Central and Default distribution strategies

* Update tensorboard kwargs for tf2

* Penalized Loss - Fix for TF2 and AMD

* Fix syntax for tf2.1

* requirements typo fix

* Explicit None for clipnorm if using a distribution strategy

* Fix penalized loss for distribution strategies

* Update Dlight

* typo fix

* Pin to TF2.2

* setup.py - Install tensorflow from pip if not available in Conda

* Add reduction options and set default for mirrored distribution strategy

* Explicitly use default strategy rather than nullcontext

* lib.model.backup_restore documentation

* Remove mirrored strategy reduction method and default based on OS

* Initial restructure - training

* Remove PingPong
Start model.base refactor

* Model saving and resuming enabled

* More tidying up of model.base

* Enable backup and snapshotting

* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly

* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers

* Apply custom Conv2D layer

* Finalize NNBlock restructure
Update Dfaker blocks

* Fix reloading model under a different distribution strategy

* Pass command line arguments through to trainer

* Remove training_opts from model and reference params directly

* Tidy up model __init__

* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning

* Fix timelapse

* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original

* dfl-h128 ported

* DFL SAE ported

* IAE Ported

* dlight ported

* port lightweight

* realface ported

* unbalanced ported

* villain ported

* lib.cli.args - Update Batchsize + move allow_growth to config

* Remove output shape definition
Get image sizes per side rather than globally

* Strip mask input from encoder

* Fix learn mask and output learned mask to preview

* Trigger Allow Growth prior to setting strategy

* Fix GUI Graphing

* GUI - Display batchsize correctly + fix training graphs

* Fix penalized loss

* Enable mixed precision training

* Update analysis displayed batch to match input

* Penalized Loss - Multi-GPU Fix

* Fix all losses for TF2

* Fix Reflect Padding

* Allow different input size for each side of the model

* Fix conv-aware initialization on reload

* Switch allow_growth order

* Move mixed_precision to cli

* Remove distribution strategies

* Compile penalized loss sub-function into LossContainer

* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels

* Add ability to refresh preview on demand on pop-up window

* Enable refresh of training preview from GUI

* Fix Convert
Debug logging in Initializers

* Fix Preview Tool

* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs

* lib.gui.popup_configure - Make more responsive + document

* Multiple Outputs supported in trainer
Original Model - Mask output bugfix

* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)

* Fix inference model to work properly with all models

* Fix multi-scale output for convert

* Fix clipnorm issue with distribution strategies
Edit error message on OOM

* Update plaidml losses

* Add missing file

* Disable gmsd loss for plaidml

* PlaidML - Basic training working

* clipnorm rewriting for mixed-precision

* Inference model creation bugfixes

* Remove debug code

* Bugfix: Default clipnorm to 1.0

* Remove all mask inputs from training code

* Remove mask inputs from convert

* GUI - Analysis Tab - Docstrings

* Fix rate in totals row

* lib.gui - Only update display pages if they have focus

* Save the model on first iteration

* plaidml - Fix SSIM loss with penalized loss

* tools.alignments - Remove manual and fix jobs

* GUI - Remove case formatting on help text

* gui MultiSelect custom widget - Set default values on init

* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exclude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded

* Cascade excluded devices to GPU Stats

* Explicit GPU selection for Train and Convert

* Reduce Tensorflow Min GPU Multiprocessor Count to 4

* remove compat.v1 code from extract

* Force TF to skip mixed precision compatibility check if GPUs have been filtered

* Add notes to config for non-working AMD losses

* Raise error if forcing extract to CPU mode

* Fix loading of legacy dfl-sae weights + dfl-sae typo fix

* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations

* docs: lib.gui.display

* clipnorm amd condition check

* documentation - gui.display_analysis

* Documentation - gui.popup_configure

* Documentation - lib.logger

* Documentation - lib.model.initializers

* Documentation - lib.model.layers

* Documentation - lib.model.losses

* Documentation - lib.model.nn_blocks

* Documentation - lib.model.normalization

* Documentation - lib.model.session

* Documentation - lib.plaidml_stats

* Documentation: lib.training_data

* Documentation: lib.utils

* Documentation: plugins.train.model._base

* GUI Stats: prevent stats from using GPU

* Documentation - Original Model

* Documentation: plugins.model.trainer._base

* linting

* unit tests: initializers + losses

* unit tests: nn_blocks

* bugfix - Exclude gpu devices in train, not include

* Enable Exclude-Gpus in Extract

* Enable exclude gpus in tools

* Disallow multiple plugin types in a single model folder

* Automatically add exclude_gpus argument in for cpu backends

* Cpu backend fixes

* Relax optimizer test threshold

* Default Train settings - Set mask to Extended

* Update Extractor cli help text
Update to Python 3.8

* Fix FAN to run on CPU

* lib.plaidml_tools - typofix

* Linux installer - check for curl

* linux installer - typo fix
2020-08-12 10:36:41 +01:00


#!/usr/bin/env python3
""" Graph functions for Display Frame of the Faceswap GUI """
import datetime
import logging
import os
import tkinter as tk

from tkinter import ttk
from math import ceil, floor

import matplotlib
# The Tk backend must be selected before any other matplotlib import
# pylint: disable=wrong-import-position
matplotlib.use("TkAgg")

from matplotlib import style  # noqa
from matplotlib.figure import Figure  # noqa
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,  # noqa
                                               NavigationToolbar2Tk)

from .custom_widgets import Tooltip  # noqa
from .utils import get_config, get_images, LongRunningTask  # noqa

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ancestors
    """ Same as default, but only including buttons we need
    with custom icons and layout
    From: https://stackoverflow.com/questions/12695678 """
    toolitems = [t for t in NavigationToolbar2Tk.toolitems if
                 t[0] in ("Home", "Pan", "Zoom", "Save")]

    @staticmethod
    def _Button(frame, text, file, command, extension=".gif"):  # pylint: disable=arguments-differ
        """ Map Buttons to their own frame.
        Use custom button icons, Use ttk buttons pack to the right """
        iconmapping = {"home": "reload",
                       "filesave": "save",
                       "zoom_to_rect": "zoom"}
        # Fall back to the original file name when no custom icon is mapped
        icon = iconmapping.get(file, file)
        img = get_images().icons[icon]
        btn = ttk.Button(frame, text=text, image=img, command=command)
        btn.pack(side=tk.RIGHT, padx=2)
        return btn

    def _init_toolbar(self):
        """ Same as original but ttk widgets and standard tool-tips used. Separator added and
        message label packed to the left """
        xmin, xmax = self.canvas.figure.bbox.intervalx
        height, width = 50, xmax-xmin
        ttk.Frame.__init__(self, master=self.window, width=int(width), height=int(height))

        sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
        sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP)

        self.update()  # Make axes menu

        btnframe = ttk.Frame(self)
        btnframe.pack(fill=tk.X, padx=5, pady=5, side=tk.RIGHT)

        for text, tooltip_text, image_file, callback in self.toolitems:
            if text is None:
                # Add a spacer; return value is unused.
                self._Spacer()
            else:
                button = self._Button(btnframe, text=text, file=image_file,
                                      command=getattr(self, callback))
                if tooltip_text is not None:
                    Tooltip(button, text=tooltip_text, wraplength=200)

        self.message = tk.StringVar(master=self)
        self._message_label = ttk.Label(master=self, textvariable=self.message)
        self._message_label.pack(side=tk.LEFT, padx=5)
        self.pack(side=tk.BOTTOM, fill=tk.X)


class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
    """ Base class for matplotlib line graphs """
    def __init__(self, parent, data, ylabel):
        logger.debug("Initializing %s", self.__class__.__name__)
        super().__init__(parent)
        style.use("ggplot")

        self.calcs = data
        self.ylabel = ylabel
        self.colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
                           "summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
        self.lines = list()
        self.toolbar = None
        self.fig = Figure(figsize=(4, 4), dpi=75)
        self.ax1 = self.fig.add_subplot(1, 1, 1)
        self.plotcanvas = FigureCanvasTkAgg(self.fig, self)

        self.initiate_graph()
        self.update_plot(initiate=True)
        logger.debug("Initialized %s", self.__class__.__name__)

    def initiate_graph(self):
        """ Place the graph canvas """
        logger.debug("Setting plotcanvas")
        self.plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True)
        self.fig.subplots_adjust(left=0.100,
                                 bottom=0.100,
                                 right=0.95,
                                 top=0.95,
                                 wspace=0.2,
                                 hspace=0.2)
        logger.debug("Set plotcanvas")

    def update_plot(self, initiate=True):
        """ Update the plot with incoming data """
        logger.trace("Updating plot")
        if initiate:
            logger.debug("Initializing plot")
            self.lines = list()
            self.ax1.clear()
            self.axes_labels_set()
            logger.debug("Initialized plot")

        fulldata = list(self.calcs.stats.values())
        self.axes_limits_set(fulldata)

        xrng = list(range(self.calcs.iterations))
        keys = list(self.calcs.stats.keys())
        for idx, item in enumerate(self.lines_sort(keys)):
            if initiate:
                # First draw: create one line per stat and keep a handle to it
                self.lines.extend(self.ax1.plot(xrng, self.calcs.stats[item[0]],
                                                label=item[1], linewidth=item[2], color=item[3]))
            else:
                # Subsequent draws: update the existing line in place
                self.lines[idx].set_data(xrng, self.calcs.stats[item[0]])

        if initiate:
            self.legend_place()
        logger.trace("Updated plot")

    def axes_labels_set(self):
        """ Set the axes label and range """
        logger.debug("Setting axes labels. y-label: '%s'", self.ylabel)
        self.ax1.set_xlabel("Iterations")
        self.ax1.set_ylabel(self.ylabel)

    def axes_limits_set_default(self):
        """ Set default axes limits """
        logger.debug("Setting default axes ranges")
        self.ax1.set_ylim(0.00, 100.0)
        self.ax1.set_xlim(0, 1)

    def axes_limits_set(self, data):
        """ Set the axes limits """
        xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1

        if data:
            ymin, ymax = self.axes_data_get_min_max(data)
            self.ax1.set_ylim(ymin, ymax)
            self.ax1.set_xlim(0, xmax)
            logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax)
        else:
            self.axes_limits_set_default()

    @staticmethod
    def axes_data_get_min_max(data):
        """ Return the minimum and maximum values from list of lists """
        ymin, ymax = list(), list()
        for item in data:
            dataset = list(filter(lambda x: x is not None, item))
            if not dataset:
                continue
            # Scale by 1000 so floor/ceil below round to 3 decimal places
            ymin.append(min(dataset) * 1000)
            ymax.append(max(dataset) * 1000)
        ymin = floor(min(ymin)) / 1000
        ymax = ceil(max(ymax)) / 1000
        logger.trace("ymin: %s, ymax: %s", ymin, ymax)
        return ymin, ymax

    def axes_set_yscale(self, scale):
        """ Set the Y-Scale to log or linear """
        logger.debug("yscale: '%s'", scale)
        self.ax1.set_yscale(scale)

    def lines_sort(self, keys):
        """ Sort the data keys into consistent order
        and set line color map and line width """
        logger.trace("Sorting lines")
        raw_lines = list()
        sorted_lines = list()
        for key in sorted(keys):
            title = key.replace("_", " ").title()
            if key.startswith("raw"):
                raw_lines.append([key, title])
            else:
                sorted_lines.append([key, title])

        groupsize = self.lines_groupsize(raw_lines, sorted_lines)
        sorted_lines = raw_lines + sorted_lines
        lines = self.lines_style(sorted_lines, groupsize)
        return lines

    @staticmethod
    def lines_groupsize(raw_lines, sorted_lines):
        """ Get the number of items in each group.
        If raw data isn't selected, then check the length of
        remaining groups until something is found """
        groupsize = 1
        if raw_lines:
            groupsize = len(raw_lines)
        elif sorted_lines:
            # Group by the key prefix that precedes the first underscore
            keys = [key[0][:key[0].find("_")] for key in sorted_lines]
            distinct_keys = set(keys)
            groupsize = len(keys) // len(distinct_keys)
        logger.trace(groupsize)
        return groupsize

    def lines_style(self, lines, groupsize):
        """ Set the color map and line width for each group """
        logger.trace("Setting lines style")
        groups = int(len(lines) / groupsize)
        colours = self.lines_create_colors(groupsize, groups)
        for idx, item in enumerate(lines):
            linewidth = ceil((idx + 1) / groupsize)
            item.extend((linewidth, colours[idx]))
        return lines

    def lines_create_colors(self, groupsize, groups):
        """ Create the colors """
        colours = list()
        for i in range(1, groups + 1):
            for colour in self.colourmaps[0:groupsize]:
                cmap = matplotlib.cm.get_cmap(colour)
                cpoint = 1 - (i / 5)
                colours.append(cmap(cpoint))
        logger.trace(colours)
        return colours

    def legend_place(self):
        """ Place and format legend """
        logger.debug("Placing legend")
        self.ax1.legend(loc="upper right", ncol=2)

    def toolbar_place(self, parent):
        """ Add Graph Navigation toolbar """
        logger.debug("Placing toolbar")
        self.toolbar = NavigationToolbar(self.plotcanvas, parent)
        self.toolbar.pack(side=tk.BOTTOM)
        self.toolbar.update()

    def clear(self):
        """ Clear the plots from RAM """
        logger.debug("Clearing graph from RAM: %s", self)
        self.fig.clf()
        del self.fig


class TrainingGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ Live graph to be displayed during training. """

    def __init__(self, parent, data, ylabel):
        GraphBase.__init__(self, parent, data, ylabel)
        self.thread = None  # Thread for LongRunningTask
        self.add_callback()

    def add_callback(self):
        """ Add the variable trace to update graph on recent button or save iteration """
        get_config().tk_vars["refreshgraph"].trace("w", self.refresh)

    def build(self):
        """ Update the plot area with loss values """
        logger.debug("Building training graph")
        self.plotcanvas.draw()
        logger.debug("Built training graph")

    def refresh(self, *args):  # pylint: disable=unused-argument
        """ Read loss data and apply to graph """
        refresh_var = get_config().tk_vars["refreshgraph"]
        if not refresh_var.get() and self.thread is None:
            return

        if self.thread is None:
            # Start loading the data in the background and poll for the result
            logger.debug("Updating plot data")
            self.thread = LongRunningTask(target=self.calcs.refresh)
            self.thread.start()
            self.after(1000, self.refresh)
        elif not self.thread.complete.is_set():
            logger.debug("Graph Data not yet available")
            self.after(1000, self.refresh)
        else:
            logger.debug("Updating plot with data from background thread")
            self.calcs = self.thread.get_result()  # Terminate the LongRunningTask object
            self.thread = None
            self.update_plot(initiate=False)
            self.plotcanvas.draw()
            refresh_var.set(False)

    def save_fig(self, location):
        """ Save the figure to file """
        logger.debug("Saving graph: '%s'", location)
        keys = sorted([key.replace("raw_", "") for key in self.calcs.stats.keys()
                       if key.startswith("raw_")])
        filename = " - ".join(keys)
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(location, "{}_{}.{}".format(filename, now, "png"))
        self.fig.set_size_inches(16, 9)
        self.fig.savefig(filename, bbox_inches="tight", dpi=120)
        print("Saved graph to {}".format(filename))
        logger.debug("Saved graph: '%s'", filename)
        self.resize_fig()

    def resize_fig(self):
        """ Resize the figure back to the canvas """
        class Event():  # pylint: disable=too-few-public-methods
            """ Event class that needs to be passed to plotcanvas.resize """
            pass  # pylint: disable=unnecessary-pass
        Event.width = self.winfo_width()
        Event.height = self.winfo_height()
        self.plotcanvas.resize(Event)  # pylint: disable=no-value-for-parameter


class SessionGraph(GraphBase):  # pylint: disable=too-many-ancestors
    """ Session Graph for session pop-up """
    def __init__(self, parent, data, ylabel, scale):
        GraphBase.__init__(self, parent, data, ylabel)
        self.scale = scale

    def build(self):
        """ Build the session graph """
        logger.debug("Building session graph")
        self.toolbar_place(self)
        self.plotcanvas.draw()
        logger.debug("Built session graph")

    def refresh(self, data, ylabel, scale):
        """ Refresh graph data """
        logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale)
        self.calcs = data
        self.ylabel = ylabel
        self.set_yscale_type(scale)
        logger.debug("Refreshed session graph")

    def set_yscale_type(self, scale):
        """ switch the y-scale and redraw """
        logger.debug("Updating scale type: '%s'", scale)
        self.scale = scale
        self.update_plot(initiate=True)
        self.axes_set_yscale(self.scale)
        self.plotcanvas.draw()
        logger.debug("Updated scale type")
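
The y-axis limit calculation in `GraphBase.axes_data_get_min_max` can be tried outside the GUI. The block below is a minimal standalone sketch of the same idea, not part of the module: the name `get_min_max` and the `decimals` parameter are hypothetical, and returning `None` when every series is empty is an assumption added here (the method above does not guard that case).

```python
from math import ceil, floor


def get_min_max(data, decimals=3):
    """Return (ymin, ymax) across a list of series, ignoring None entries.

    The minimum is rounded down and the maximum rounded up to ``decimals``
    places so that plotted lines never touch the axes limits.
    """
    scale = 10 ** decimals
    lows, highs = [], []
    for series in data:
        values = [val for val in series if val is not None]
        if not values:
            continue  # Nothing to measure in this series
        lows.append(min(values))
        highs.append(max(values))
    if not lows:
        # Assumption for this sketch: signal "no data" rather than raise
        return None
    return floor(min(lows) * scale) / scale, ceil(max(highs) * scale) / scale


if __name__ == "__main__":
    print(get_min_max([[0.1234, None, 0.5], [0.2, 0.7561]]))
```

A caller could fall back to default limits when the result is `None`, in the same way `axes_limits_set` falls back to `axes_limits_set_default` when no data is passed.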