faceswap/scripts/train.py
torzdf d8557c1970
Faceswap 2.0 (#1045)
* Core Updates
    - Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
    - Document lib.gpu_stats and lib.sys_info
    - Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
    - lib.gui.menu - typo fix

* Update Dependencies
Bump Tensorflow Version Check

* Port extraction to tf2

* Add custom import finder for loading Keras or tf.keras depending on backend

* Add `tensorflow` to KerasFinder search path

* Basic TF2 training running

* model.initializers - docstring fix

* Fix and pass tests for tf2

* Replace Keras backend tests with faceswap backend tests

* Initial optimizers update

* Monkey patch tf.keras optimizer

* Remove custom Adam Optimizers and Memory Saving Gradients

* Remove multi-gpu option. Add Distribution to cli

* plugins.train.model._base: Add Mirror, Central and Default distribution strategies

* Update tensorboard kwargs for tf2

* Penalized Loss - Fix for TF2 and AMD

* Fix syntax for tf2.1

* requirements typo fix

* Explicit None for clipnorm if using a distribution strategy

* Fix penalized loss for distribution strategies

* Update Dlight

* typo fix

* Pin to TF2.2

* setup.py - Install tensorflow from pip if not available in Conda

* Add reduction options and set default for mirrored distribution strategy

* Explicitly use default strategy rather than nullcontext

* lib.model.backup_restore documentation

* Remove mirrored strategy reduction method and default based on OS

* Initial restructure - training

* Remove PingPong
Start model.base refactor

* Model saving and resuming enabled

* More tidying up of model.base

* Enable backup and snapshotting

* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly

* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers

* Apply custom Conv2D layer

* Finalize NNBlock restructure
Update Dfaker blocks

* Fix reloading model under a different distribution strategy

* Pass command line arguments through to trainer

* Remove training_opts from model and reference params directly

* Tidy up model __init__

* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning

* Fix timelapse

* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original

* dfl-h128 ported

* DFL SAE ported

* IAE Ported

* dlight ported

* port lightweight

* realface ported

* unbalanced ported

* villain ported

* lib.cli.args - Update Batchsize + move allow_growth to config

* Remove output shape definition
Get image sizes per side rather than globally

* Strip mask input from encoder

* Fix learn mask and output learned mask to preview

* Trigger Allow Growth prior to setting strategy

* Fix GUI Graphing

* GUI - Display batchsize correctly + fix training graphs

* Fix penalized loss

* Enable mixed precision training

* Update analysis displayed batch to match input

* Penalized Loss - Multi-GPU Fix

* Fix all losses for TF2

* Fix Reflect Padding

* Allow different input size for each side of the model

* Fix conv-aware initialization on reload

* Switch allow_growth order

* Move mixed_precision to cli

* Remove distribution strategies

* Compile penalized loss sub-function into LossContainer

* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels

* Add ability to refresh preview on demand on pop-up window

* Enable refresh of training preview from GUI

* Fix Convert
Debug logging in Initializers

* Fix Preview Tool

* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs

* lib.gui.popup_configure - Make more responsive + document

* Multiple Outputs supported in trainer
Original Model - Mask output bugfix

* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)

* Fix inference model to work properly with all models

* Fix multi-scale output for convert

* Fix clipnorm issue with distribution strategies
Edit error message on OOM

* Update plaidml losses

* Add missing file

* Disable gmsd loss for plaidml

* PlaidML - Basic training working

* clipnorm rewriting for mixed-precision

* Inference model creation bugfixes

* Remove debug code

* Bugfix: Default clipnorm to 1.0

* Remove all mask inputs from training code

* Remove mask inputs from convert

* GUI - Analysis Tab - Docstrings

* Fix rate in totals row

* lib.gui - Only update display pages if they have focus

* Save the model on first iteration

* plaidml - Fix SSIM loss with penalized loss

* tools.alignments - Remove manual and fix jobs

* GUI - Remove case formatting on help text

* gui MultiSelect custom widget - Set default values on init

* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exclude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded

* Cascade excluded devices to GPU Stats

* Explicit GPU selection for Train and Convert

* Reduce Tensorflow Min GPU Multiprocessor Count to 4

* remove compat.v1 code from extract

* Force TF to skip mixed precision compatibility check if GPUs have been filtered

* Add notes to config for non-working AMD losses

* Raise error if forcing extract to CPU mode

* Fix loading of legacy dfl-sae weights + dfl-sae typo fix

* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations

* docs: lib.gui.display

* clipnorm amd condition check

* documentation - gui.display_analysis

* Documentation - gui.popup_configure

* Documentation - lib.logger

* Documentation - lib.model.initializers

* Documentation - lib.model.layers

* Documentation - lib.model.losses

* Documentation - lib.model.nn_blocks

* Documentation - lib.model.normalization

* Documentation - lib.model.session

* Documentation - lib.plaidml_stats

* Documentation: lib.training_data

* Documentation: lib.utils

* Documentation: plugins.train.model._base

* GUI Stats: prevent stats from using GPU

* Documentation - Original Model

* Documentation: plugins.model.trainer._base

* linting

* unit tests: initializers + losses

* unit tests: nn_blocks

* bugfix - Exclude gpu devices in train, not include

* Enable Exclude-Gpus in Extract

* Enable exclude gpus in tools

* Disallow multiple plugin types in a single model folder

* Automatically add the exclude_gpus argument for cpu backends

* Cpu backend fixes

* Relax optimizer test threshold

* Default Train settings - Set mask to Extended

* Update Extractor cli help text
Update to Python 3.8

* Fix FAN to run on CPU

* lib.plaidml_tools - typo fix

* Linux installer - check for curl

* linux installer - typo fix
2020-08-12 10:36:41 +01:00
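
The notes above reference enabling mixed precision training and mirrored/central/default
distribution strategies. As rough orientation only, here is a minimal stand-alone TF2 sketch of
those two features (plain tf.keras API in its TF 2.4+ spelling; at the time of this PR the
mixed-precision API still lived under `experimental`, and the model below is just a placeholder,
not faceswap's plugin code):

    import tensorflow as tf

    # Compute in float16 while keeping variables in float32 (mixed precision).
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

    # Replicate the model across all visible GPUs; the default strategy is a no-op on one device.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
            # Keep the output layer in float32 so the loss stays numerically stable.
            tf.keras.layers.Dense(1, dtype="float32"),
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")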


#!/usr/bin/env python3
""" Main entry point to the training process of FaceSwap """
import logging
import os
import sys

from threading import Lock
from time import sleep

import cv2

from lib.image import read_image
from lib.keypress import KBHit
from lib.multithreading import MultiThread
from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions)
from plugins.plugin_loader import PluginLoader

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Train():  # pylint:disable=too-few-public-methods
    """ The Faceswap Training Process.

    The training process is responsible for training a model on a set of source faces and a set
    of destination faces.

    The training process is self-contained and should not be referenced by any other scripts, so
    it contains no public properties.

    Parameters
    ----------
    arguments: argparse.Namespace
        The arguments to be passed to the training process as generated from Faceswap's command
        line arguments
    """

    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self._args = arguments
        self._timelapse = self._set_timelapse()
        self._images = self._get_images()
        self._gui_preview_trigger = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])),
                                                 "lib", "gui", ".cache", ".preview_trigger")
        self._stop = False
        self._save_now = False
        self._refresh_preview = False
        self._preview_buffer = dict()
        self._lock = Lock()
        self.trainer_name = self._args.trainer
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def _image_size(self):
        """ int: The training image size. Reads the first image in the training folder and
        returns the size. """
        image = read_image(self._images["a"][0], raise_error=True)
        size = image.shape[0]
        logger.debug("Training image size: %s", size)
        return size

    def _set_timelapse(self):
        """ Set time-lapse paths if requested.

        Returns
        -------
        dict
            The time-lapse keyword arguments for passing to the trainer
        """
        if (not self._args.timelapse_input_a and
                not self._args.timelapse_input_b and
                not self._args.timelapse_output):
            return None

        if (not self._args.timelapse_input_a or
                not self._args.timelapse_input_b or
                not self._args.timelapse_output):
            raise FaceswapError("To enable the timelapse, you have to supply all the parameters "
                                "(--timelapse-input-A, --timelapse-input-B and "
                                "--timelapse-output).")

        timelapse_output = str(get_folder(self._args.timelapse_output))

        for folder in (self._args.timelapse_input_a, self._args.timelapse_input_b):
            if folder is not None and not os.path.isdir(folder):
                raise FaceswapError("The Timelapse path '{}' does not exist".format(folder))

            exts = [os.path.splitext(fname)[-1] for fname in os.listdir(folder)]
            if not any(ext in _image_extensions for ext in exts):
                raise FaceswapError("The Timelapse path '{}' does not contain any valid "
                                    "images".format(folder))

        kwargs = {"input_a": self._args.timelapse_input_a,
                  "input_b": self._args.timelapse_input_b,
                  "output": timelapse_output}
        logger.debug("Timelapse enabled: %s", kwargs)
        return kwargs

    def _get_images(self):
        """ Check that the image folders exist and contain images, and obtain the image paths.

        Returns
        -------
        dict
            The image paths for each side. The key is the side, the value is the list of paths
            for that side.
        """
        logger.debug("Getting image paths")
        images = dict()
        for side in ("a", "b"):
            image_dir = getattr(self._args, "input_{}".format(side))
            if not os.path.isdir(image_dir):
                logger.error("Error: '%s' does not exist", image_dir)
                sys.exit(1)

            images[side] = get_image_paths(image_dir)
            if not images[side]:
                logger.error("Error: '%s' contains no images", image_dir)
                sys.exit(1)

        logger.info("Model A Directory: %s", self._args.input_a)
        logger.info("Model B Directory: %s", self._args.input_b)
        logger.debug("Got image paths: %s", [(key, str(len(val)) + " images")
                                             for key, val in images.items()])
        return images

    def process(self):
        """ The entry point for triggering the Training Process.

        Should only be called from :class:`lib.cli.launcher.ScriptExecutor`
        """
        logger.debug("Starting Training Process")
        logger.info("Training data directory: %s", self._args.model_dir)
        thread = self._start_thread()
        # from lib.queue_manager import queue_manager; queue_manager.debug_monitor(1)

        err = self._monitor(thread)
        self._end_thread(thread, err)
        logger.debug("Completed Training Process")

    def _start_thread(self):
        """ Put the :func:`_training` into a background thread so we can keep control.

        Returns
        -------
        :class:`lib.multithreading.MultiThread`
            The background thread for running training
        """
        logger.debug("Launching Trainer thread")
        thread = MultiThread(target=self._training)
        thread.start()
        logger.debug("Launched Trainer thread")
        return thread

    def _end_thread(self, thread, err):
        """ Output message and join thread back to main on termination.

        Parameters
        ----------
        thread: :class:`lib.multithreading.MultiThread`
            The background training thread
        err: bool
            Whether an error has been detected in :func:`_monitor`
        """
        logger.debug("Ending Training thread")
        if err:
            msg = "Error caught! Exiting..."
            log = logger.critical
        else:
            msg = ("Exit requested! The trainer will complete its current cycle, "
                   "save the models and quit (This can take a couple of minutes "
                   "depending on your training speed).")
            if not self._args.redirect_gui:
                msg += " If you want to kill it now, press Ctrl + c"
            log = logger.info
        log(msg)
        self._stop = True
        thread.join()
        sys.stdout.flush()
        logger.debug("Ended training thread")

    def _training(self):
        """ The training process to be run inside a thread. """
        try:
            sleep(1)  # Let preview instructions flush out to logger
            logger.debug("Commencing Training")
            logger.info("Loading data, this may take a while...")
            model = self._load_model()
            trainer = self._load_trainer(model)
            self._run_training_cycle(model, trainer)
        except KeyboardInterrupt:
            try:
                logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting")
                model.save()
                trainer.clear_tensorboard()
            except KeyboardInterrupt:
                logger.info("Saving model weights has been cancelled!")
            sys.exit(0)
        except Exception as err:
            raise err

    def _load_model(self):
        """ Load the model requested for training.

        Returns
        -------
        :file:`plugins.train.model` plugin
            The requested model plugin
        """
        logger.debug("Loading Model")
        model_dir = str(get_folder(self._args.model_dir))
        model = PluginLoader.get_model(self.trainer_name)(
            model_dir,
            self._args,
            training_image_size=self._image_size,
            predict=False)
        model.build()
        logger.debug("Loaded Model")
        return model

    def _load_trainer(self, model):
        """ Load the trainer requested for training.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin

        Returns
        -------
        :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Loading Trainer")
        trainer = PluginLoader.get_trainer(model.trainer)
        trainer = trainer(model,
                          self._images,
                          self._args.batch_size,
                          self._args.configfile)
        logger.debug("Loaded Trainer")
        return trainer

    def _run_training_cycle(self, model, trainer):
        """ Perform the training cycle.

        Handles the background training, updating previews/time-lapse on each save interval,
        and saving the model.

        Parameters
        ----------
        model: :file:`plugins.train.model` plugin
            The requested model plugin
        trainer: :file:`plugins.train.trainer` plugin
            The requested model trainer plugin
        """
        logger.debug("Running Training Cycle")
        if self._args.write_image or self._args.redirect_gui or self._args.preview:
            display_func = self._show
        else:
            display_func = None
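
        # The preview callback is only handed to the trainer on save iterations or on an explicit
        # save/refresh request, so preview generation does not slow down normal iterations.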
        for iteration in range(1, self._args.iterations + 1):
            logger.trace("Training iteration: %s", iteration)
            save_iteration = iteration % self._args.save_interval == 0 or iteration == 1
            if save_iteration or self._save_now or self._refresh_preview:
                viewer = display_func
            else:
                viewer = None
            timelapse = self._timelapse if save_iteration else None
            trainer.train_one_step(viewer, timelapse)

            if self._stop:
                logger.debug("Stop received. Terminating")
                break

            if self._refresh_preview and viewer is not None:
                if self._args.redirect_gui:
                    print("\n")
                    logger.info("[Preview Updated]")
                    logger.debug("Removing gui trigger file: %s", self._gui_preview_trigger)
                    os.remove(self._gui_preview_trigger)
                self._refresh_preview = False

            if save_iteration:
                logger.debug("Save Iteration: (iteration: %s)", iteration)
                model.save()
            elif self._save_now:
                logger.debug("Save Requested: (iteration: %s)", iteration)
                model.save()
                self._save_now = False

        logger.debug("Training cycle complete")
        model.save()
        trainer.clear_tensorboard()
        self._stop = True

    def _monitor(self, thread):
        """ Monitor the background :func:`_training` thread for key presses and errors.

        Returns
        -------
        bool
            ``True`` if there has been an error in the background thread otherwise ``False``
        """
        is_preview = self._args.preview
        preview_trigger_set = False
        logger.debug("Launching Monitor")
        logger.info("===================================================")
        logger.info(" Starting")
        if is_preview:
            logger.info(" Using live preview")
        logger.info(" Press '%s' to save and quit",
                    "Stop" if self._args.redirect_gui or self._args.colab else "ENTER")
        if not self._args.redirect_gui and not self._args.colab:
            logger.info(" Press 'S' to save model weights immediately")
        logger.info("===================================================")

        keypress = KBHit(is_gui=self._args.redirect_gui)
        err = False
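        # Poll roughly once a second for: an error or stop flag from the training thread, key
        # presses in the preview window or console, and the GUI preview-refresh trigger file.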
        while True:
            try:
                if is_preview:
                    with self._lock:
                        for name, image in self._preview_buffer.items():
                            cv2.imshow(name, image)  # pylint: disable=no-member
                    cv_key = cv2.waitKey(1000)  # pylint: disable=no-member
                else:
                    cv_key = None

                if thread.has_error:
                    logger.debug("Thread error detected")
                    err = True
                    break
                if self._stop:
                    logger.debug("Stop received")
                    break

                # Preview Monitor
                if is_preview and (cv_key == ord("\n") or cv_key == ord("\r")):
                    logger.debug("Exit requested")
                    break
                if is_preview and cv_key == ord("s"):
                    print("\n")
                    logger.info("Save requested")
                    self._save_now = True
                if is_preview and cv_key == ord("r"):
                    print("\n")
                    logger.info("Refresh preview requested")
                    self._refresh_preview = True

                # Console Monitor
                if keypress.kbhit():
                    console_key = keypress.getch()
                    if console_key in ("\n", "\r"):
                        logger.debug("Exit requested")
                        break
                    if console_key in ("s", "S"):
                        logger.info("Save requested")
                        self._save_now = True

                # GUI Preview trigger update monitor
                if self._args.redirect_gui:
                    if not preview_trigger_set and os.path.isfile(self._gui_preview_trigger):
                        print("\n")
                        logger.info("Refresh preview requested")
                        self._refresh_preview = True
                        preview_trigger_set = True

                    if preview_trigger_set and not self._refresh_preview:
                        logger.debug("Resetting GUI preview trigger")
                        preview_trigger_set = False

                sleep(1)
            except KeyboardInterrupt:
                logger.debug("Keyboard Interrupt received")
                break

        keypress.set_normal_term()
        logger.debug("Closed Monitor")
        return err

    def _show(self, image, name=""):
        """ Generate the preview and write preview file output.

        Handles the output and display of preview images.

        Parameters
        ----------
        image: :class:`numpy.ndarray`
            The preview image to be displayed and/or written out
        name: str, optional
            The name of the image for saving or display purposes. If an empty string is passed
            then it will automatically be named. Default: ""
        """
        logger.debug("Updating preview: (name: %s)", name)
        try:
            scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
            if self._args.write_image:
                logger.debug("Saving preview to disk")
                img = "training_preview.jpg"
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Saved preview to: '%s'", img)
            if self._args.redirect_gui:
                logger.debug("Generating preview for GUI")
                img = ".gui_training_preview.jpg"
                imgfile = os.path.join(scriptpath, "lib", "gui",
                                       ".cache", "preview", img)
                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
                logger.debug("Generated preview for GUI: '%s'", img)
            if self._args.preview:
                logger.debug("Generating preview for display: '%s'", name)
                with self._lock:
                    self._preview_buffer[name] = image
                logger.debug("Generated preview for display: '%s'", name)
        except Exception as err:
            logger.error("Could not preview sample")
            raise err
        logger.debug("Updated preview: (name: %s)", name)