#!/usr/bin/env python3
""" Main entry point to the training process of FaceSwap """

import logging
import os
import sys

from threading import Lock
from time import sleep

import cv2

from lib.image import read_image
from lib.keypress import KBHit
from lib.multithreading import MultiThread
from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions)
from plugins.plugin_loader import PluginLoader

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Train():  # pylint:disable=too-few-public-methods
    """ The Faceswap Training Process.

    The training process is responsible for training a model on a set of source faces and a set
    of destination faces.

    The training process is self contained and should not be referenced by any other scripts, so
    it contains no public properties.

    Parameters
    ----------
    arguments: argparse.Namespace
        The arguments to be passed to the training process as generated from Faceswap's command
        line arguments
    """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self._args = arguments
        self._timelapse = self._set_timelapse()
        self._images = self._get_images()
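        # Sentinel file created by the GUI to request a preview refresh. :func:`_monitor` watches
        # for it and :func:`_run_training_cycle` removes it once the preview has been regenerated.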
        self._gui_preview_trigger = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])),
                                                 "lib", "gui", ".cache", ".preview_trigger")
        self._stop = False
        self._save_now = False
        self._refresh_preview = False
        self._preview_buffer = dict()
        self._lock = Lock()

        self.trainer_name = self._args.trainer
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def _image_size(self):
        """ int: The training image size. Reads the first image in the training folder and
        returns the size. """
        image = read_image(self._images["a"][0], raise_error=True)
        size = image.shape[0]
        logger.debug("Training image size: %s", size)
        return size

    def _set_timelapse(self):
        """ Set time-lapse paths if requested.

        Returns
        -------
        dict
            The time-lapse keyword arguments for passing to the trainer
        """
        if (not self._args.timelapse_input_a and
                not self._args.timelapse_input_b and
                not self._args.timelapse_output):
            return None
        if (not self._args.timelapse_input_a or
                not self._args.timelapse_input_b or
                not self._args.timelapse_output):
            raise FaceswapError("To enable the timelapse, you have to supply all the parameters "
                                "(--timelapse-input-A, --timelapse-input-B and "
                                "--timelapse-output).")

        timelapse_output = str(get_folder(self._args.timelapse_output))

        for folder in (self._args.timelapse_input_a, self._args.timelapse_input_b):
            if folder is not None and not os.path.isdir(folder):
                raise FaceswapError("The Timelapse path '{}' does not exist".format(folder))
            exts = [os.path.splitext(fname)[-1] for fname in os.listdir(folder)]
            if not any(ext in _image_extensions for ext in exts):
                raise FaceswapError("The Timelapse path '{}' does not contain any valid "
                                    "images".format(folder))
        kwargs = {"input_a": self._args.timelapse_input_a,
                  "input_b": self._args.timelapse_input_b,
                  "output": timelapse_output}
        logger.debug("Timelapse enabled: %s", kwargs)
        return kwargs

    def _get_images(self):
        """ Check that the image folders exist and contain images, then obtain the image paths.

        Returns
        -------
        dict
            The image paths for each side. The key is the side, the value is the list of paths
            for that side.
        """
        logger.debug("Getting image paths")
        images = dict()
        for side in ("a", "b"):
            image_dir = getattr(self._args, "input_{}".format(side))
            if not os.path.isdir(image_dir):
                logger.error("Error: '%s' does not exist", image_dir)
                sys.exit(1)

            images[side] = get_image_paths(image_dir)
            if not images[side]:
                logger.error("Error: '%s' contains no images", image_dir)
                sys.exit(1)

        logger.info("Model A Directory: %s", self._args.input_a)
        logger.info("Model B Directory: %s", self._args.input_b)
        logger.debug("Got image paths: %s", [(key, str(len(val)) + " images")
                                             for key, val in images.items()])
        return images

    def process(self):
        """ The entry point for triggering the Training Process.

        Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
        """
        logger.debug("Starting Training Process")
        logger.info("Training data directory: %s", self._args.model_dir)
        thread = self._start_thread()
        # from lib.queue_manager import queue_manager; queue_manager.debug_monitor(1)
        err = self._monitor(thread)
        self._end_thread(thread, err)
        logger.debug("Completed Training Process")

    def _start_thread(self):
        """ Put :func:`_training` into a background thread so that we can keep control of the
        main process.

        Returns
        -------
        :class:`lib.multithreading.MultiThread`
            The background thread for running training
        """
        logger.debug("Launching Trainer thread")
        thread = MultiThread(target=self._training)
        thread.start()
        logger.debug("Launched Trainer thread")
        return thread

    def _end_thread(self, thread, err):
        """ Output message and join thread back to main on termination.

        Parameters
        ----------
        thread: :class:`lib.multithreading.MultiThread`
            The background training thread
        err: bool
            Whether an error has been detected in :func:`_monitor`
        """
        logger.debug("Ending Training thread")
        if err:
            msg = "Error caught! Exiting..."
            log = logger.critical
        else:
            msg = ("Exit requested! The trainer will complete its current cycle, "
                   "save the models and quit (This can take a couple of minutes "
                   "depending on your training speed).")
            if not self._args.redirect_gui:
                msg += " If you want to kill it now, press Ctrl + c"
            log = logger.info
        log(msg)
        self._stop = True
        thread.join()
        sys.stdout.flush()
        logger.debug("Ended training thread")

    def _training(self):
        """ The training process to be run inside a thread. """
        try:
            sleep(1)  # Let preview instructions flush out to logger
            logger.debug("Commencing Training")
            logger.info("Loading data, this may take a while...")
            model = self._load_model()
            trainer = self._load_trainer(model)
            self._run_training_cycle(model, trainer)
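        # A first Ctrl+C triggers a save of the model weights before terminating training; a
        # second Ctrl+C issued while that save is in progress cancels it and exits immediately.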
        except KeyboardInterrupt:
            try:
                logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting")
                model.save()
                trainer.clear_tensorboard()
            except KeyboardInterrupt:
                logger.info("Saving model weights has been cancelled!")
                sys.exit(0)
        except Exception as err:
            raise err

    def _load_model(self):
        """ Load the model requested for training.

        Returns
        -------
        :file:`plugins.train.model` plugin
            The requested model plugin
        """
        logger.debug("Loading Model")
        model_dir = str(get_folder(self._args.model_dir))
        model = PluginLoader.get_model(self.trainer_name)(
            model_dir,
            self._args,
            training_image_size=self._image_size,
            predict=False)
        model.build()
        logger.debug("Loaded Model")
        return model

    def _load_trainer(self, model):
        """ Load the trainer requested for training.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin

        Returns
        -------
        :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Loading Trainer")
        trainer = PluginLoader.get_trainer(model.trainer)
        trainer = trainer(model,
                          self._images,
                          self._args.batch_size,
                          self._args.configfile)
        logger.debug("Loaded Trainer")
        return trainer

    def _run_training_cycle(self, model, trainer):
        """ Perform the training cycle.

        Handles the background training, updating previews/time-lapse on each save interval,
        and saving the model.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin
        trainer: :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Running Training Cycle")
        if self._args.write_image or self._args.redirect_gui or self._args.preview:
            display_func = self._show
        else:
            display_func = None

        for iteration in range(1, self._args.iterations + 1):
            logger.trace("Training iteration: %s", iteration)
            save_iteration = iteration % self._args.save_interval == 0 or iteration == 1

            if save_iteration or self._save_now or self._refresh_preview:
                viewer = display_func
            else:
                viewer = None
            timelapse = self._timelapse if save_iteration else None
            trainer.train_one_step(viewer, timelapse)
            if self._stop:
                logger.debug("Stop received. Terminating")
                break

            if self._refresh_preview and viewer is not None:
                if self._args.redirect_gui:
                    print("\n")
                    logger.info("[Preview Updated]")
                    logger.debug("Removing gui trigger file: %s", self._gui_preview_trigger)
                    os.remove(self._gui_preview_trigger)
                self._refresh_preview = False

            if save_iteration:
                logger.debug("Save Iteration: (iteration: %s)", iteration)
                model.save()
            elif self._save_now:
                logger.debug("Save Requested: (iteration: %s)", iteration)
                model.save()
                self._save_now = False
        logger.debug("Training cycle complete")
        model.save()
        trainer.clear_tensorboard()
        self._stop = True

    def _monitor(self, thread):
        """ Monitor the background :func:`_training` thread for key presses and errors.

        Parameters
        ----------
        thread: :class:`lib.multithreading.MultiThread`
            The background training thread

        Returns
        -------
        bool
            ``True`` if there has been an error in the background thread otherwise ``False``
        """
        is_preview = self._args.preview
        preview_trigger_set = False
        logger.debug("Launching Monitor")
        logger.info("===================================================")
        logger.info(" Starting")
        if is_preview:
            logger.info(" Using live preview")
        logger.info(" Press '%s' to save and quit",
                    "Stop" if self._args.redirect_gui or self._args.colab else "ENTER")
        if not self._args.redirect_gui and not self._args.colab:
            logger.info(" Press 'S' to save model weights immediately")
        logger.info("===================================================")

        keypress = KBHit(is_gui=self._args.redirect_gui)
        err = False
        while True:
            try:
                if is_preview:
                    with self._lock:
                        for name, image in self._preview_buffer.items():
                            cv2.imshow(name, image)  # pylint: disable=no-member
                    cv_key = cv2.waitKey(1000)  # pylint: disable=no-member
                else:
                    cv_key = None

                if thread.has_error:
                    logger.debug("Thread error detected")
                    err = True
                    break
                if self._stop:
                    logger.debug("Stop received")
                    break

                # Preview Monitor
                if is_preview and (cv_key == ord("\n") or cv_key == ord("\r")):
                    logger.debug("Exit requested")
                    break
                if is_preview and cv_key == ord("s"):
                    print("\n")
                    logger.info("Save requested")
                    self._save_now = True
                if is_preview and cv_key == ord("r"):
                    print("\n")
                    logger.info("Refresh preview requested")
                    self._refresh_preview = True

                # Console Monitor
                if keypress.kbhit():
                    console_key = keypress.getch()
                    if console_key in ("\n", "\r"):
                        logger.debug("Exit requested")
                        break
                    if console_key in ("s", "S"):
                        logger.info("Save requested")
                        self._save_now = True

                # GUI Preview trigger update monitor
                if self._args.redirect_gui:
                    if not preview_trigger_set and os.path.isfile(self._gui_preview_trigger):
                        print("\n")
                        logger.info("Refresh preview requested")
                        self._refresh_preview = True
                        preview_trigger_set = True

                    if preview_trigger_set and not self._refresh_preview:
                        logger.debug("Resetting GUI preview trigger")
                        preview_trigger_set = False

                sleep(1)
            except KeyboardInterrupt:
                logger.debug("Keyboard Interrupt received")
                break
        keypress.set_normal_term()
        logger.debug("Closed Monitor")
        return err

    def _show(self, image, name=""):
        """ Generate the preview and write preview file output.

        Handles the output and display of preview images.

        Parameters
        ----------
        image: :class:`numpy.ndarray`
            The preview image to be displayed and/or written out
        name: str, optional
            The name of the image for saving or display purposes. If an empty string is passed
            then it will automatically be named. Default: ""
        """
        logger.debug("Updating preview: (name: %s)", name)
        try:
            scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
            if self._args.write_image:
                logger.debug("Saving preview to disk")
                img = "training_preview.jpg"
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Saved preview to: '%s'", img)
            if self._args.redirect_gui:
                logger.debug("Generating preview for GUI")
                img = ".gui_training_preview.jpg"
                imgfile = os.path.join(scriptpath, "lib", "gui",
                                       ".cache", "preview", img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Generated preview for GUI: '%s'", img)
            if self._args.preview:
                logger.debug("Generating preview for display: '%s'", name)
                with self._lock:
                    self._preview_buffer[name] = image
                logger.debug("Generated preview for display: '%s'", name)
        except Exception as err:
            logger.error("Could not preview sample")
            raise err
        logger.debug("Updated preview: (name: %s)", name)