1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 20:13:52 -04:00
faceswap/scripts/train.py
torzdf be8e235eab
Gui v3.0b (#436)
* GUI version 3 (#411)

GUI version 3.0a

* Required for Shaonlu mode (#416)

Added two modes - Original and Shaonlu.
The latter requires this file to function.

* model update (#417)

New, functional Original 128 model

* OriginalHighRes 128 model update (#418)

Required for OriginalHighRes Model to function

* Add OriginalHighRes 128 update to gui branch (#421)

* Required for Shaonlu mode (#416)

Added two modes - Original and Shaonlu.
The latter requires this file to function.

* model update (#417)

New, functional Original 128 model

* OriginalHighRes 128 model update (#418)

Required for OriginalHighRes Model to function

* Dev gui (#420)

* reduce singletons

* Fix tooltips and screen boundaries on popup

* Remove dpi fix. Fix context filebrowsers

* fix tools.py execution and context filebrowser bugs

* Bugfixes (#422)

* Bump matplotlib requirement. Fix polyfit. Fix TQDM on sort

* Fixed memory usage at 6GB cards. (#423)

- Switched default encoder to ORIGINAL
- Fixed memory consumption. Tested with geforce gtx 9800 ti with 6Gb; batch_size 8 no OOM or memory warnings now.

* Staging (#426)

* altered trainer (#425)

Altered trainer to accommodate the model change

* Update Model.py (#424)

- Added saving state (currently only saved epoch number, to be extended in future)
- Changed saving to ThreadPoolExecutor

* Add DPI Scaling (#428)

* Add dpi scaling

* Hotfix for effmpeg. (#429)

effmpeg fixed so it works both in cli and gui.
Initial work done to add previewing feature to effmpeg (currently does nothing).
Some small spacing changes in other files to improve PEP8 conformity.

* PEP8 Linting (#430)

* pep8 linting

* Requirements version bump (#432)

* altered trainer (#425)

Altered trainer to accommodate the model change

* Update Model.py (#424)

- Added saving state (currently only saved epoch number, to be extended in future)
- Changed saving to ThreadPoolExecutor

* Requirements version bump (#431)

This bumps the versions of:

    scandir
    h5py
    Keras
    opencv-python

to their latest versions.

The virtual environment will need to be set up again to make use of these.

* High DPI Fixes (#433)

* dpi scaling

* DPI Fixes

* Fix and improve context manager. (#434)

effmpeg tool:
Context manager for GUI fixed.

Context manager in general:
Functionality extended to allow configuring the context with both:
  command -> action
  command -> variable (cli argument) -> action

* Change epoch option to iterations

* Change epochs to iterations
2018-06-20 19:25:31 +02:00

198 lines
7.1 KiB
Python

#!/usr/bin/env python3
""" The script to run the training process of faceswap """
import os
import sys
import threading
import cv2
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from lib.utils import get_folder, get_image_paths, set_system_verbosity
from plugins.PluginLoader import PluginLoader
class Train(object):
    """ The training process.

    The trainer runs in a background thread while the calling thread waits
    for the user to request a stop, either from the OpenCV preview window or
    from the console.

    ``arguments`` is the parsed command line namespace. The attributes read
    by this class are: ``trainer``, ``model_dir``, ``verbose``, ``preview``,
    ``input_A``, ``input_B``, ``allow_growth``, ``gpus``, ``batch_size``,
    ``perceptual_loss``, ``iterations``, ``save_interval``, ``write_image``
    and ``redirect_gui``.
    """

    def __init__(self, arguments):
        self.args = arguments
        self.images = self.get_images()
        # Set to True to ask the training loop to finish; also set by the
        # training thread itself once it has ended.
        self.stop = False
        # Set to True by the preview window to request an immediate save.
        self.save_now = False
        # Latest preview image per window name, shared between the training
        # thread (writer) and the preview loop (reader) under self.lock.
        self.preview_buffer = dict()
        self.lock = threading.Lock()

        # this is so that you can enter case insensitive values for trainer
        trainer_name = self.args.trainer
        self.trainer_name = "LowMem" if trainer_name.lower() == "lowmem" else trainer_name

    def process(self):
        """ Call the training process object.

        Starts the training thread, blocks in the preview or console monitor
        until a stop is requested, then asks the thread to finish and joins
        it. """
        print("Training data directory: {}".format(self.args.model_dir))
        lvl = '0' if self.args.verbose else '2'
        set_system_verbosity(lvl)
        thread = self.start_thread()

        if self.args.preview:
            self.monitor_preview()
        else:
            self.monitor_console()

        self.end_thread(thread)

    def get_images(self):
        """ Check the image dirs exist and are not empty, and return the
        image objects.

        Returns a two item list: the result of ``get_image_paths`` for
        ``input_A`` followed by the result for ``input_B``.

        Raises SystemExit (exit code 1) if either directory is missing or
        has no entries at all. """
        images = []
        for image_dir in [self.args.input_A, self.args.input_B]:
            if not os.path.isdir(image_dir):
                print('Error: {} does not exist'.format(image_dir))
                # sys.exit rather than the site-provided exit(), which is not
                # guaranteed to be defined in every interpreter setup.
                sys.exit(1)

            # NOTE: this only checks that the folder has entries, not that
            # those entries are images.
            if not os.listdir(image_dir):
                print('Error: {} contains no images'.format(image_dir))
                sys.exit(1)

            images.append(get_image_paths(image_dir))
        print("Model A Directory: {}".format(self.args.input_A))
        print("Model B Directory: {}".format(self.args.input_B))
        return images

    def start_thread(self):
        """ Put the training process in a thread so we can keep control.

        Returns the started ``threading.Thread``. """
        thread = threading.Thread(target=self.process_thread)
        thread.start()
        return thread

    def end_thread(self, thread):
        """ On termination output message and join thread back to main """
        print("Exit requested! The trainer will complete its current cycle, "
              "save the models and quit (it can take up a couple of seconds "
              "depending on your training speed). If you want to kill it now, "
              "press Ctrl + c")
        self.stop = True
        thread.join()
        sys.stdout.flush()

    def process_thread(self):
        """ The training process to be run inside a thread.

        Loads the model and trainer and runs the training cycle. Any
        exception other than KeyboardInterrupt propagates to the caller.
        ``self.stop`` is always set on the way out so that the preview
        monitor does not keep waiting on a thread that has already ended. """
        # Bound up front so the interrupt handler can tell whether a model
        # was actually loaded before the interrupt arrived.
        model = None
        try:
            print("Loading data, this may take a while...")

            if self.args.allow_growth:
                self.set_tf_allow_growth()

            model = self.load_model()
            trainer = self.load_trainer(model)
            self.run_training_cycle(model, trainer)
        except KeyboardInterrupt:
            try:
                if model is not None:
                    model.save_weights()
            except KeyboardInterrupt:
                print("Saving model weights has been cancelled!")
            sys.exit(0)
        finally:
            self.stop = True

    def load_model(self):
        """ Load the model requested for training.

        Returns the model object built by the plugin registered under
        ``self.trainer_name``, after its ``load(swapped=False)`` has been
        called. """
        model_dir = get_folder(self.args.model_dir)
        model = PluginLoader.get_model(self.trainer_name)(model_dir, self.args.gpus)

        model.load(swapped=False)
        return model

    def load_trainer(self, model):
        """ Load the trainer requested for training.

        Returns the trainer object built by the plugin registered under
        ``self.trainer_name`` for ``model`` and the two image sets. """
        images_a, images_b = self.images

        trainer = PluginLoader.get_trainer(self.trainer_name)
        trainer = trainer(model,
                          images_a,
                          images_b,
                          self.args.batch_size,
                          self.args.perceptual_loss)
        return trainer

    def run_training_cycle(self, model, trainer):
        """ Perform the training cycle.

        Runs up to ``args.iterations`` training steps. Weights are saved on
        every ``args.save_interval``-th iteration (including iteration 0)
        and whenever ``self.save_now`` has been set. On those iterations
        ``self.show`` is passed to the trainer as the preview callback.
        The loop ends early when ``self.stop`` is set. Weights are saved
        once more after the loop, and ``self.stop`` is set to signal that
        training has finished. """
        for iteration in range(0, self.args.iterations):
            save_iteration = iteration % self.args.save_interval == 0
            viewer = self.show if save_iteration or self.save_now else None
            trainer.train_one_step(iteration, viewer)
            if self.stop:
                break
            if save_iteration or self.save_now:
                model.save_weights()
                # Always clear the manual request once a save has happened,
                # including when it coincided with a scheduled save, so it
                # cannot trigger a second save on the next iteration.
                self.save_now = False
        model.save_weights()
        self.stop = True

    def monitor_preview(self):
        """ Generate the preview window and wait for keyboard input.

        Returns when ENTER is pressed in the preview window, when
        ``self.stop`` is set, or on Ctrl+C. """
        print("Using live preview.\n"
              "Press 'ENTER' on the preview window to save and quit.\n"
              "Press 'S' on the preview window to save model weights "
              "immediately")
        while True:
            try:
                # Hold the lock only while reading the shared buffer, not
                # while waiting for a key press.
                with self.lock:
                    for name, image in self.preview_buffer.items():
                        cv2.imshow(name, image)

                key = cv2.waitKey(1000)
                if key == ord("\n") or key == ord("\r"):
                    break
                # The prompt says 'S', so accept either case.
                if key in (ord("s"), ord("S")):
                    self.save_now = True
                if self.stop:
                    break
            except KeyboardInterrupt:
                break

    @staticmethod
    def monitor_console():
        """ Monitor the console for any input followed by enter or ctrl+c """
        # TODO: how to catch a specific key instead of Enter?
        # there isn't a good multiplatform solution:
        # https://stackoverflow.com/questions/3523174
        # TODO: Find a way to interrupt input() if the target iterations are reached.
        # At the moment, setting a target iteration and using the -p flag is
        # the only guaranteed way to exit the training loop on hitting target
        # iterations.
        print("Starting. Press 'ENTER' to stop training and save model")
        try:
            input()
        except KeyboardInterrupt:
            pass

    @staticmethod
    def set_tf_allow_growth():
        """ Allow TensorFlow to manage VRAM growth.

        Installs a Keras session configured with ``allow_growth`` enabled
        and the visible device list set to "0". """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
        set_session(tf.Session(config=config))

    def show(self, image, name=""):
        """ Generate the preview and write preview file output.

        Depending on the arguments, ``image`` is written next to the
        launching script (``write_image``), written to the GUI preview cache
        folder (``redirect_gui``) and/or stored in the preview buffer under
        ``name`` (``preview``). Any exception is reported and re-raised. """
        try:
            scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
            if self.args.write_image:
                img = "_sample_{}.jpg".format(name)
                imgfile = os.path.join(scriptpath, img)
                cv2.imwrite(imgfile, image)
            if self.args.redirect_gui:
                img = ".gui_preview_{}.jpg".format(name)
                imgfile = os.path.join(scriptpath, "lib", "gui", ".cache", "preview", img)
                cv2.imwrite(imgfile, image)
            if self.args.preview:
                with self.lock:
                    self.preview_buffer[name] = image
        except Exception:
            print("could not preview sample")
            raise