mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00
faceswap/scripts/train.py
Lev Velykoivanenko 80cde77a6d Adding new tool effmpeg ("easy"-ffmpeg) with gui support. Extend gui functionality to support filetypes. Re-opening PR. (#373)
* Pre-push commit.
Add filetype support to the gui through new classes in lib/cli.py
Add various new functions to tools/effmpeg.py

* Finish developing basic effmpeg functionality.
Ready for public alpha test.

* Add ffmpy to requirements.
Fix gen-vid to allow specifying a new file in GUI.
Fix extract throwing an error when supplied with a valid directory.

Add two new gui user pop-up interactions: save (allows you to create new
files/directories) and nothing (disables the prompt button when it's not
needed).
Improve logic and argument processing in effmpeg.

* Fix post merge bugs.
Reformat tools.py to match the new style of faceswap.py
Fix some whitespace issues.

* Fix matplotlib.use() being called after pyplot was imported.

* Fix various effmpeg bugs and add ability to terminate nested subprocesses
to GUI.

effmpeg changes:
Fix get-fps not printing to terminal.
Fix mux-audio not working.
Add verbosity option. If verbose is not specified, then ffmpeg output is
reduced with the -hide_banner flag.

scripts/gui.py changes:
Add ability to terminate nested subprocesses, i.e. the following type of
process tree should now be terminated safely:
gui -> command -> command-subprocess
               -> command-subprocess -> command-sub-subprocess
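A minimal sketch of one way to achieve this, assuming a POSIX platform (`launch_in_group` and `terminate_tree` are hypothetical names, not the actual scripts/gui.py API): launching the command in its own session gives it a process group separate from the GUI's, so one signal can reach every nested subprocess.

```python
import os
import signal
import subprocess

def launch_in_group(args):
    """ Start the command in a new session so it (and any subprocesses
    it spawns) share a process group separate from the GUI's. """
    return subprocess.Popen(args, start_new_session=True)

def terminate_tree(proc):
    """ Signal the whole process group, terminating nested
    subprocesses along with the direct child. """
    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
```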

* Add functionality to tools/effmpeg.py, fix some docstring and print statement issues in some files.

tools/effmpeg.py:
Transpose choices now display detailed name in GUI, while in cli they can
still be entered as a number or the full command name.
Add quiet option to effmpeg that only shows critical ffmpeg errors.
Improve user input handling.
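As a rough illustration of how these verbosity levels might map onto ffmpeg flags (the function name and the exact `-loglevel` value are assumptions, not effmpeg's actual implementation):

```python
def ffmpeg_global_args(verbose=False, quiet=False):
    """ Build ffmpeg's global arguments for the requested verbosity. """
    args = []
    if not verbose:
        # Trim ffmpeg's startup banner when full output is not wanted
        args.append("-hide_banner")
    if quiet:
        # Keep only critical errors (assumed mapping for 'quiet')
        args.extend(["-loglevel", "fatal"])
    return args
```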

lib/cli.py; scripts/convert.py; scripts/extract.py; scripts/train.py:
Fix some line length issues and typos in docstrings, help text and print statements.
Fix some whitespace issues.

lib/cli.py:
Add filetypes to '--alignments' argument.
Change argument action to DirFullPaths where appropriate.

* Bug fixes and improvements to tools/effmpeg.py

Fix bug where duration would not be used even when end time was not set.
Add option to specify output filetype for extraction.
Enhance gen-vid to generate a video from images that were zero-padded to any arbitrary width, not just 5.
Enhance gen-vid to use any of the image formats that a video can be extracted into.
Improve gen-vid output video quality.
Minor code quality improvements and ffmpeg argument formatting improvements.
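The generalised zero-padding could work along these lines (`sequence_pattern` is a hypothetical helper for illustration, not the actual gen-vid code): infer the padding width from the first frame's filename and build the matching ffmpeg image-sequence input pattern.

```python
import re

def sequence_pattern(filenames):
    """ Infer the ffmpeg input pattern (e.g. 'img%05d.png') from a list
    of numbered frame filenames, for any zero-padding width. """
    match = re.match(r"(.*?)(\d+)\.(\w+)$", sorted(filenames)[0])
    if match is None:
        raise ValueError("No numbered image sequence found")
    prefix, digits, ext = match.groups()
    # len(digits) recovers the padding width actually used on disk
    return "{}%0{}d.{}".format(prefix, len(digits), ext)
```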

* Remove dependency on psutil in scripts/gui.py and various small improvements.

lib/utils.py:
Add _image_extensions and _video_extensions as global variables to make them easily portable across all of faceswap.
Fix lack of new lines between function and class declarations to conform to PEP8.
Fix some typos and line length issues in docstrings and comments.
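For illustration, such module-level lists might be consumed like this (the extension values shown are assumptions, and `is_media_file` is a hypothetical caller, not part of lib/utils.py):

```python
import os

# Assumed contents; the real lists live in lib/utils.py
_image_extensions = [".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
_video_extensions = [".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".webm"]

def is_media_file(filename):
    """ True if the filename has a known image or video extension. """
    ext = os.path.splitext(filename)[1].lower()
    return ext in _image_extensions or ext in _video_extensions
```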

scripts/convert.py:
Make tqdm print to stdout.

scripts/extract.py:
Make tqdm print to stdout.
Apply workaround for occasional TqdmSynchronisationWarning being thrown.
Fix some typos and line length issues in docstrings and comments.

scripts/fsmedia.py:
Complete the TODO in Faces.load_extractor() (pass extractor_name as an argument).
Fix lack of new lines between function and class declarations to conform to PEP8.
Fix some typos and line length issues in docstrings and comments.
Change 2 print statements to use format() for string formatting instead of the old '%'.

scripts/gui.py:
Refactor subprocess generation and termination to remove dependency on psutil.
Fix some typos and line length issues in comments.

tools/effmpeg.py:
Refactor DataItem class to use new lib/utils.py global media file extensions.
Improve ffmpeg subprocess termination handling.
2018-05-09 18:47:17 +01:00

#!/usr/bin/env python3
""" The script to run the training process of faceswap """
import os
import sys
import threading

import cv2
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

from lib.utils import get_folder, get_image_paths, set_system_verbosity
from plugins.PluginLoader import PluginLoader
class Train(object):
    """ The training process. """
    def __init__(self, arguments):
        self.args = arguments
        self.images = self.get_images()
        self.stop = False
        self.save_now = False
        self.preview_buffer = dict()
        self.lock = threading.Lock()

        # This is so that you can enter case-insensitive values for trainer
        trainer_name = self.args.trainer
        self.trainer_name = ("LowMem" if trainer_name.lower() == "lowmem"
                             else trainer_name)
    def process(self):
        """ Call the training process object """
        print("Training data directory: {}".format(self.args.model_dir))
        lvl = '0' if self.args.verbose else '2'
        set_system_verbosity(lvl)
        thread = self.start_thread()

        if self.args.preview:
            self.monitor_preview()
        else:
            self.monitor_console()
        self.end_thread(thread)
    def get_images(self):
        """ Check that the image dirs exist and contain images, and
        return the image paths """
        images = []
        for image_dir in [self.args.input_A, self.args.input_B]:
            if not os.path.isdir(image_dir):
                print('Error: {} does not exist'.format(image_dir))
                exit(1)
            if not os.listdir(image_dir):
                print('Error: {} contains no images'.format(image_dir))
                exit(1)
            images.append(get_image_paths(image_dir))
        print("Model A Directory: {}".format(self.args.input_A))
        print("Model B Directory: {}".format(self.args.input_B))
        return images
    def start_thread(self):
        """ Put the training process in a thread so we can keep control """
        thread = threading.Thread(target=self.process_thread)
        thread.start()
        return thread

    def end_thread(self, thread):
        """ On termination output message and join thread back to main """
        print("Exit requested! The trainer will complete its current cycle, "
              "save the models and quit (this can take up to a couple of "
              "seconds depending on your training speed). If you want to "
              "kill it now, press Ctrl + c")
        self.stop = True
        thread.join()
        sys.stdout.flush()
    def process_thread(self):
        """ The training process to be run inside a thread """
        try:
            print("Loading data, this may take a while...")
            if self.args.allow_growth:
                self.set_tf_allow_growth()
            model = self.load_model()
            trainer = self.load_trainer(model)
            self.run_training_cycle(model, trainer)
        except KeyboardInterrupt:
            try:
                model.save_weights()
            except KeyboardInterrupt:
                print("Saving model weights has been cancelled!")
            exit(0)
        except Exception as err:
            raise err
    def load_model(self):
        """ Load the model requested for training """
        model_dir = get_folder(self.args.model_dir)
        model = PluginLoader.get_model(self.trainer_name)(model_dir,
                                                          self.args.gpus)
        model.load(swapped=False)
        return model

    def load_trainer(self, model):
        """ Load the trainer requested for training """
        images_a, images_b = self.images
        trainer = PluginLoader.get_trainer(self.trainer_name)
        trainer = trainer(model,
                          images_a,
                          images_b,
                          self.args.batch_size,
                          self.args.perceptual_loss)
        return trainer
    def run_training_cycle(self, model, trainer):
        """ Perform the training cycle """
        for epoch in range(0, self.args.epochs):
            save_iteration = epoch % self.args.save_interval == 0
            viewer = self.show if save_iteration or self.save_now else None
            trainer.train_one_step(epoch, viewer)
            if self.stop:
                break
            elif save_iteration:
                model.save_weights()
            elif self.save_now:
                model.save_weights()
                self.save_now = False
        model.save_weights()
        self.stop = True
    def monitor_preview(self):
        """ Generate the preview window and wait for keyboard input """
        print("Using live preview.\n"
              "Press 'ENTER' on the preview window to save and quit.\n"
              "Press 'S' on the preview window to save model weights "
              "immediately")
        while True:
            try:
                with self.lock:
                    for name, image in self.preview_buffer.items():
                        cv2.imshow(name, image)

                key = cv2.waitKey(1000)
                if key == ord("\n") or key == ord("\r"):
                    break
                if key == ord("s"):
                    self.save_now = True
                if self.stop:
                    break
            except KeyboardInterrupt:
                break
    @staticmethod
    def monitor_console():
        """ Monitor the console for any input followed by enter or ctrl+c """
        # TODO: how to catch a specific key instead of Enter?
        # there isn't a good multiplatform solution:
        # https://stackoverflow.com/questions/3523174
        # TODO: Find a way to interrupt input() if the target iterations
        # are reached. At the moment, setting a target iteration and using
        # the -p flag is the only guaranteed way to exit the training loop
        # on hitting target iterations.
        print("Starting. Press 'ENTER' to stop training and save model")
        try:
            input()
        except KeyboardInterrupt:
            pass
    @staticmethod
    def set_tf_allow_growth():
        """ Allow TensorFlow to manage VRAM growth """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
        set_session(tf.Session(config=config))
    def show(self, image, name=""):
        """ Generate the preview and write preview file output """
        try:
            scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
            if self.args.write_image:
                img = "_sample_{}.jpg".format(name)
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)

            if self.args.redirect_gui:
                img = ".gui_preview.png"
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)
            elif self.args.preview:
                with self.lock:
                    self.preview_buffer[name] = image
        except Exception as err:
            print("Could not preview sample")
            raise err