Mirror of https://github.com/deepfakes/faceswap
* Add Improved AutoEncoder (IAE) model.
* Refactor Model_IAE to match the new model folder structure.
* Add Model_IAE to plugins.
* Add multi-GPU support. I added multi-GPU support to the new model layout (sketched below). Currently, Original is untested (due to OOM on my 2x 4GB 970s). LowMem is untested with this commit because the new plugin loader does not yet expose it.
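The multi-GPU path boils down to forwarding the -g/--gpus count to the selected model plugin. A minimal sketch of that call chain, simplified from the code in this file (`arguments` is the parsed argparse namespace):

    from lib.utils import get_folder
    from plugins.PluginLoader import PluginLoader

    model_dir = get_folder(arguments.model_dir)   # -m/--model-dir
    # each model plugin receives the GPU count and decides how to parallelise
    model = PluginLoader.get_model(arguments.trainer)(model_dir, arguments.gpus)
    model.load(swapped=False)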
import cv2
import numpy
import time

from threading import Lock
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader


class TrainingProcessor(object):
    arguments = None  # populated by process_arguments() once the CLI args are parsed

    def __init__(self, subparser, command, description='default'):
        self.parse_arguments(description, subparser, command)
        self.lock = Lock()  # guards preview_buffer, which is shared with the training thread

    def process_arguments(self, arguments):
        self.arguments = arguments
        print("Face A training images: {}".format(self.arguments.input_A))
        print("Face B training images: {}".format(self.arguments.input_B))
        print("Model directory: {}".format(self.arguments.model_dir))

        self.process()

    def parse_arguments(self, description, subparser, command):
        parser = subparser.add_parser(
            command,
            help="This command trains the model for the two faces A and B.",
            description=description,
            epilog="Questions and feedback: \
            https://github.com/deepfakes/faceswap-playground"
        )

        parser.add_argument('-A', '--input-A',
                            action=FullPaths,
                            dest="input_A",
                            default="input_A",
                            help="Input directory. A directory containing training images for face A.\
                            Defaults to 'input_A'")
        parser.add_argument('-B', '--input-B',
                            action=FullPaths,
                            dest="input_B",
                            default="input_B",
                            help="Input directory. A directory containing training images for face B.\
                            Defaults to 'input_B'")
        parser.add_argument('-m', '--model-dir',
                            action=FullPaths,
                            dest="model_dir",
                            default="models",
                            help="Model directory. This is where the trained model weights will \
                            be stored. Defaults to 'models'")
        parser.add_argument('-p', '--preview',
                            action="store_true",
                            dest="preview",
                            default=False,
                            help="Show preview output. If not specified, write progress \
                            to file.")
        parser.add_argument('-v', '--verbose',
                            action="store_true",
                            dest="verbose",
                            default=False,
                            help="Show verbose output")
        parser.add_argument('-s', '--save-interval',
                            type=int,
                            dest="save_interval",
                            default=100,
                            help="Sets the number of iterations between each model save.")
        parser.add_argument('-w', '--write-image',
                            action="store_true",
                            dest="write_image",
                            default=False,
                            help="Writes the training result to a file even in preview mode.")
        parser.add_argument('-t', '--trainer',
                            type=str,
                            choices=PluginLoader.get_available_models(),
                            default=PluginLoader.get_default_model(),
                            help="Select which trainer to use; LowMem is intended for cards with less than 2GB of VRAM.")
        parser.add_argument('-pl', '--use-perceptual-loss',
                            action="store_true",
                            dest="perceptual_loss",
                            default=False,
                            help="Use perceptual loss while training")
        parser.add_argument('-bs', '--batch-size',
                            type=int,
                            default=64,
                            help="Batch size, as a power of 2 (64, 128, 256, etc)")
        parser.add_argument('-ag', '--allow-growth',
                            action="store_true",
                            dest="allow_growth",
                            default=False,
                            help="Sets the allow_growth option of TensorFlow to spare memory on some configurations")
        parser.add_argument('-ep', '--epochs',
                            type=int,
                            default=1000000,
                            help="Length of training in epochs.")
        parser.add_argument('-g', '--gpus',
                            type=int,
                            default=1,
                            help="Number of GPUs to use for training")
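        # A hypothetical invocation of this sub-command (the top-level script name
        # is an assumption, not defined in this file); every flag maps onto one of
        # the add_argument() calls above:
        #
        #   python faceswap.py train -A faces/A -B faces/B -m models -bs 128 -g 2 -p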
        parser = self.add_optional_arguments(parser)
        parser.set_defaults(func=self.process_arguments)

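    # Sketch of how this class is assumed to be wired into the top-level CLI
    # (the names below are illustrative, not taken from this file):
    #
    #   parser = argparse.ArgumentParser()
    #   subparsers = parser.add_subparsers()
    #   TrainingProcessor(subparsers, "train", description="Train a faceswap model")
    #   args = parser.parse_args()
    #   args.func(args)  # set_defaults() above routes this call to process_arguments()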
    def add_optional_arguments(self, parser):
        # Override this for custom arguments
        return parser

    def process(self):
        import threading
        self.stop = False
        self.save_now = False

        # Run training in a background thread so the main thread can handle
        # the preview window and keyboard input.
        thr = threading.Thread(target=self.processThread, args=(), kwargs={})
        thr.start()

        if self.arguments.preview:
            print('Using live preview')
            while True:
                try:
                    with self.lock:
                        for name, image in self.preview_buffer.items():
                            cv2.imshow(name, image)

                    key = cv2.waitKey(1000)
                    if key == ord('\n') or key == ord('\r'):
                        break
                    if key == ord('s'):
                        self.save_now = True
                except KeyboardInterrupt:
                    break
        else:
            input()  # TODO how to catch a specific key instead of Enter?
            # there isn't a good multiplatform solution: https://stackoverflow.com/questions/3523174/raw-input-in-python-without-pressing-enter
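            # One POSIX-only option (illustrative sketch, not part of the original
            # code) is to switch the terminal to cbreak mode so a single keypress
            # can be read without waiting for Enter; Windows could use msvcrt.getch():
            #
            #   import sys, termios, tty
            #   fd = sys.stdin.fileno()
            #   old_settings = termios.tcgetattr(fd)
            #   try:
            #       tty.setcbreak(fd)
            #       key = sys.stdin.read(1)  # returns as soon as one key is pressed
            #   finally:
            #       termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)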
print("Exit requested! The trainer will complete its current cycle, save the models and quit (it can take up a couple of seconds depending on your training speed). If you want to kill it now, press Ctrl + c")
|
|
self.stop = True
|
|
thr.join() # waits until thread finishes
|
|
|
|
    def processThread(self):
        try:
            if self.arguments.allow_growth:
                self.set_tf_allow_growth()

            print('Loading data, this may take a while...')
            # normalise the trainer name so it can be entered case-insensitively
            trainer = self.arguments.trainer
            trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
            model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir), self.arguments.gpus)
            model.load(swapped=False)

            images_A = get_image_paths(self.arguments.input_A)
            images_B = get_image_paths(self.arguments.input_B)
            trainer = PluginLoader.get_trainer(trainer)
            trainer = trainer(model, images_A, images_B, self.arguments.batch_size, self.arguments.perceptual_loss)

            print('Starting. Press "Enter" to stop training and save model')

            for epoch in range(self.arguments.epochs):

                save_iteration = epoch % self.arguments.save_interval == 0

                # only render preview samples on save iterations or an explicit save request
                trainer.train_one_step(epoch, self.show if (save_iteration or self.save_now) else None)

                if save_iteration:
                    model.save_weights()

                if self.stop:
                    model.save_weights()
                    exit()  # SystemExit ends only this training thread; process() joins it

                if self.save_now:
                    model.save_weights()
                    self.save_now = False

        except KeyboardInterrupt:
            try:
                model.save_weights()
            except KeyboardInterrupt:
                print('Saving model weights has been cancelled!')
            exit(0)
        except Exception as e:
            print(e)
            exit(1)

    def set_tf_allow_growth(self):
        import tensorflow as tf
        from keras.backend.tensorflow_backend import set_session
        config = tf.ConfigProto()
        # allocate GPU memory on demand instead of grabbing it all up front
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
        set_session(tf.Session(config=config))

    preview_buffer = {}  # latest preview images, keyed by window name; guarded by self.lock

    def show(self, image, name=''):
        try:
            if self.arguments.preview:
                with self.lock:
                    self.preview_buffer[name] = image
            elif self.arguments.write_image:
                cv2.imwrite('_sample_{}.jpg'.format(name), image)
        except Exception as e:
            print("Could not preview sample")
            print(e)