mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* model_refactor (#571) * original model to new structure * IAE model to new structure * OriginalHiRes to new structure * Fix trainer for different resolutions * Initial config implementation * Configparse library added * improved training data loader * dfaker model working * Add logging to training functions * Non blocking input for cli training * Add error handling to threads. Add non-mp queues to queue_handler * Improved Model Building and NNMeta * refactor lib/models * training refactor. DFL H128 model Implementation * Dfaker - use hashes * Move timelapse. Remove perceptual loss arg * Update INSTALL.md. Add logger formatting. Update Dfaker training * DFL h128 partially ported * Add mask to dfaker (#573) * Remove old models. Add mask to dfaker * dfl mask. Make masks selectable in config (#575) * DFL H128 Mask. Mask type selectable in config. * remove gan_v2_2 * Creating Input Size config for models Creating Input Size config for models Will be used downstream in converters. Also name change of image_shape to input_shape to clarify ( for future models with potentially different output_shapes) * Add mask loss options to config * MTCNN options to config.ini. Remove GAN config. Update USAGE.md * Add sliders for numerical values in GUI * Add config plugins menu to gui. Validate config * Only backup model if loss has dropped. Get training working again * bugfixes * Standardise loss printing * GUI idle cpu fixes. Graph loss fix. * mutli-gpu logging bugfix * Merge branch 'staging' into train_refactor * backup state file * Crash protection: Only backup if both total losses have dropped * Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes) * Load and save model structure with weights * Slight code update * Improve config loader. Add subpixel opt to all models. Config to state * Show samples... wrong input * Remove AE topology. 
Add input/output shapes to State * Port original_villain (birb/VillainGuy) model to faceswap * Add plugin info to GUI config pages * Load input shape from state. IAE Config options. * Fix transform_kwargs. Coverage to ratio. Bugfix mask detection * Suppress keras userwarnings. Automate zoom. Coverage_ratio to model def. * Consolidation of converters & refactor (#574) * Consolidation of converters & refactor Initial Upload of alpha Items - consolidate convert_mased & convert_adjust into one converter -add average color adjust to convert_masked -allow mask transition blur size to be a fixed integer of pixels and a fraction of the facial mask size -allow erosion/dilation size to be a fixed integer of pixels and a fraction of the facial mask size -eliminate redundant type conversions to avoid multiple round-off errors -refactor loops for vectorization/speed -reorganize for clarity & style changes TODO - bug/issues with warping the new face onto a transparent old image...use a cleanup mask for now - issues with mask border giving black ring at zero erosion .. investigate - remove GAN ?? - test enlargment factors of umeyama standard face .. match to coverage factor - make enlargment factor a model parameter - remove convert_adjusted and referencing code when finished * Update Convert_Masked.py default blur size of 2 to match original... description of enlargement tests breakout matrxi scaling into def * Enlargment scale as a cli parameter * Update cli.py * dynamic interpolation algorithm Compute x & y scale factors from the affine matrix on the fly by QR decomp. 
Choose interpolation alogrithm for the affine warp based on an upsample or downsample for each image * input size input size from config * fix issues with <1.0 erosion * Update convert.py * Update Convert_Adjust.py more work on the way to merginf * Clean up help note on sharpen * cleanup seamless * Delete Convert_Adjust.py * Update umeyama.py * Update training_data.py * swapping * segmentation stub * changes to convert.str * Update masked.py * Backwards compatibility fix for models Get converter running * Convert: Move masks to class. bugfix blur_size some linting * mask fix * convert fixes - missing facehull_rect re-added - coverage to % - corrected coverage logic - cleanup of gui option ordering * Update cli.py * default for blur * Update masked.py * added preliminary low_mem version of OriginalHighRes model plugin * Code cleanup, minor fixes * Update masked.py * Update masked.py * Add dfl mask to convert * histogram fix & seamless location * update * revert * bugfix: Load actual configuration in gui * Standardize nn_blocks * Update cli.py * Minor code amends * Fix Original HiRes model * Add masks to preview output for mask trainers refactor trainer.__base.py * Masked trainers converter support * convert bugfix * Bugfix: Converter for masked (dfl/dfaker) trainers * Additional Losses (#592) * initial upload * Delete blur.py * default initializer = He instead of Glorot (#588) * Allow kernel_initializer to be overridable * Add ICNR Initializer option for upscale on all models. * Hopefully fixes RSoDs with original-highres model plugin * remove debug line * Original-HighRes model plugin Red Screen of Death fix, take #2 * Move global options to _base. Rename Villain model * clipnorm and res block biases * scale the end of res block * res block * dfaker pre-activation res * OHRES pre-activation * villain pre-activation * tabs/space in nn_blocks * fix for histogram with mask all set to zero * fix to prevent two networks with same name * GUI: Wider tooltips. 
Improve TQDM capture * Fix regex bug * Convert padding=48 to ratio of image size * Add size option to alignments tool extract * Pass through training image size to convert from model * Convert: Pull training coverage from model * convert: coverage, blur and erode to percent * simplify matrix scaling * ordering of sliders in train * Add matrix scaling to utils. Use interpolation in lib.aligner transform * masked.py Import get_matrix_scaling from utils * fix circular import * Update masked.py * quick fix for matrix scaling * testing thus for now * tqdm regex capture bugfix * Minor ammends * blur size cleanup * Remove coverage option from convert (Now cascades from model) * Implement convert for all model types * Add mask option and coverage option to all existing models * bugfix for model loading on convert * debug print removal * Bugfix for masks in dfl_h128 and iae * Update preview display. Add preview scaling to cli * mask notes * Delete training_data_v2.py errant file * training data variables * Fix timelapse function * Add new config items to state file for legacy purposes * Slight GUI tweak * Raise exception if problem with loaded model * Add Tensorboard support (Logs stored in model directory) * ICNR fix * loss bugfix * convert bugfix * Move ini files to config folder. Make TensorBoard optional * Fix training data for unbalanced inputs/outputs * Fix config "none" test * Keep helptext in .ini files when saving config from GUI * Remove frame_dims from alignments * Add no-flip and warp-to-landmarks cli options * Revert OHR to RC4_fix version * Fix lowmem mode on OHR model * padding to variable * Save models in parallel threads * Speed-up of res_block stability * Automated Reflection Padding * Reflect Padding as a training option Includes auto-calculation of proper padding shapes, input_shapes, output_shapes Flag included in config now * rest of reflect padding * Move TB logging to cli. 
Session info to state file * Add session iterations to state file * Add recent files to menu. GUI code tidy up * [GUI] Fix recent file list update issue * Add correct loss names to TensorBoard logs * Update live graph to use TensorBoard and remove animation * Fix analysis tab. GUI optimizations * Analysis Graph popup to Tensorboard Logs * [GUI] Bug fix for graphing for models with hypens in name * [GUI] Correctly split loss to tabs during training * [GUI] Add loss type selection to analysis graph * Fix store command name in recent files. Switch to correct tab on open * [GUI] Disable training graph when 'no-logs' is selected * Fix graphing race condition * rename original_hires model to unbalanced
400 lines
15 KiB
Python
#!/usr/bin/env python3
""" The script to run the extract process of faceswap """

import logging
import os
import sys
from pathlib import Path

from tqdm import tqdm

from lib.faces_detect import DetectedFace
from lib.gpu_stats import GPUStats
from lib.multithreading import MultiThread, PoolProcess, SpawnProcess
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import get_folder, hash_encode_image
from plugins.plugin_loader import PluginLoader
from scripts.fsmedia import Alignments, Images, PostProcess, Utils

tqdm.monitor_interval = 0  # workaround for TqdmSynchronisationWarning
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Extract():
    """ The extract process. """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self.args = arguments
        self.output_dir = get_folder(self.args.output_dir)
        logger.info("Output Directory: %s", self.args.output_dir)
        self.images = Images(self.args)
        self.alignments = Alignments(self.args, True, self.images.is_video)
        self.plugins = Plugins(self.args)

        self.post_process = PostProcess(arguments)

        self.verify_output = False
        self.save_interval = None
        if hasattr(self.args, "save_interval"):
            self.save_interval = self.args.save_interval
        logger.debug("Initialized %s", self.__class__.__name__)

    def process(self):
        """ Perform the extraction process """
        logger.info('Starting, this may take a while...')
        Utils.set_verbosity(self.args.loglevel)
        # queue_manager.debug_monitor(1)
        self.threaded_io("load")
        save_thread = self.threaded_io("save")
        self.run_extraction()
        save_thread.join()
        self.alignments.save()
        Utils.finalize(self.images.images_found,
                       self.alignments.faces_count,
                       self.verify_output)

    def threaded_io(self, task, io_args=None):
        """ Perform I/O task in a background thread """
        logger.debug("Threading task: (Task: '%s')", task)
        io_args = tuple() if io_args is None else (io_args, )
        if task == "load":
            func = self.load_images
        elif task == "save":
            func = self.save_faces
        elif task == "reload":
            func = self.reload_images
        io_thread = MultiThread(func, *io_args, thread_count=1)
        io_thread.start()
        return io_thread

    def load_images(self):
        """ Load the images """
        logger.debug("Load Images: Start")
        load_queue = queue_manager.get_queue("load")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Load Queue: Stop signal received. Terminating")
                break
            if image is None or not image.any():
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            imagename = os.path.basename(filename)
            if imagename in self.alignments.data.keys():
                logger.trace("Skipping image: '%s'", filename)
                continue
            item = {"filename": filename,
                    "image": image}
            load_queue.put(item)
        load_queue.put("EOF")
        logger.debug("Load Images: Complete")

    def reload_images(self, detected_faces):
        """ Reload the images and pair to detected face """
        logger.debug("Reload Images: Start. Detected Faces Count: %s", len(detected_faces))
        load_queue = queue_manager.get_queue("detect")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Reload Queue: Stop signal received. Terminating")
                break
            logger.trace("Reloading image: '%s'", filename)
            detect_item = detected_faces.pop(filename, None)
            if not detect_item:
                logger.warning("Couldn't find faces for: %s", filename)
                continue
            detect_item["image"] = image
            load_queue.put(detect_item)
        load_queue.put("EOF")
        logger.debug("Reload Images: Complete")

    @staticmethod
    def save_faces():
        """ Save the generated faces """
        logger.debug("Save Faces: Start")
        save_queue = queue_manager.get_queue("save")
        while True:
            if save_queue.shutdown.is_set():
                logger.debug("Save Queue: Stop signal received. Terminating")
                break
            item = save_queue.get()
            if item == "EOF":
                break
            filename, face = item

            logger.trace("Saving face: '%s'", filename)
            try:
                with open(filename, "wb") as out_file:
                    out_file.write(face)
            except Exception as err:  # pylint: disable=broad-except
                logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
                continue
        logger.debug("Save Faces: Complete")

    def run_extraction(self):
        """ Run Face Detection """
        save_queue = queue_manager.get_queue("save")
        to_process = self.process_item_count()
        frame_no = 0
        size = self.args.size if hasattr(self.args, "size") else 256
        align_eyes = self.args.align_eyes if hasattr(self.args, "align_eyes") else False

        if self.plugins.is_parallel:
            logger.debug("Using parallel processing")
            self.plugins.launch_aligner()
            self.plugins.launch_detector()
        else:
            logger.debug("Using serial processing")
            self.run_detection(to_process)
            self.plugins.launch_aligner()

        for faces in tqdm(self.plugins.detect_faces(extract_pass="align"),
                          total=to_process,
                          file=sys.stdout,
                          desc="Extracting faces"):

            filename = faces["filename"]

            self.align_face(faces, align_eyes, size, filename)
            self.post_process.do_actions(faces)

            faces_count = len(faces["detected_faces"])
            if faces_count == 0:
                logger.verbose("No faces were detected in image: %s",
                               os.path.basename(filename))

            if not self.verify_output and faces_count > 1:
                self.verify_output = True

            self.output_faces(filename, faces, save_queue)

            frame_no += 1
            if frame_no == self.save_interval:
                self.alignments.save()
                frame_no = 0

        save_queue.put("EOF")

    def process_item_count(self):
        """ Return the number of items to be processed """
        processed = sum(os.path.basename(frame) in self.alignments.data.keys()
                        for frame in self.images.input_images)
        logger.debug("Items already processed: %s", processed)

        if processed != 0 and self.args.skip_existing:
            logger.info("Skipping previously extracted frames: %s", processed)
        if processed != 0 and self.args.skip_faces:
            logger.info("Skipping frames with detected faces: %s", processed)

        to_process = self.images.images_found - processed
        logger.debug("Items to be Processed: %s", to_process)
        if to_process == 0:
            logger.error("No frames to process. Exiting")
            queue_manager.terminate_queues()
            # sys.exit is used rather than the interactive-only exit() helper
            sys.exit(0)
        return to_process

    def run_detection(self, to_process):
        """ Run detection only """
        self.plugins.launch_detector()
        detected_faces = dict()
        for detected in tqdm(self.plugins.detect_faces(extract_pass="detect"),
                             total=to_process,
                             file=sys.stdout,
                             desc="Detecting faces"):
            exception = detected.get("exception", False)
            if exception:
                break

            del detected["image"]
            filename = detected["filename"]

            detected_faces[filename] = detected

        self.threaded_io("reload", detected_faces)

    def align_face(self, faces, align_eyes, size, filename):
        """ Align the detected face and add the destination file path """
        final_faces = list()
        image = faces["image"]
        landmarks = faces["landmarks"]
        detected_faces = faces["detected_faces"]
        for idx, face in enumerate(detected_faces):
            detected_face = DetectedFace()
            detected_face.from_dlib_rect(face, image)
            detected_face.landmarksXY = landmarks[idx]
            detected_face.load_aligned(image, size=size, align_eyes=align_eyes)
            final_faces.append({"file_location": self.output_dir / Path(filename).stem,
                                "face": detected_face})
        faces["detected_faces"] = final_faces

    def output_faces(self, filename, faces, save_queue):
        """ Output faces to save thread """
        final_faces = list()
        for idx, detected_face in enumerate(faces["detected_faces"]):
            output_file = detected_face["file_location"]
            extension = Path(filename).suffix
            out_filename = "{}_{}{}".format(str(output_file), str(idx), extension)

            face = detected_face["face"]
            resized_face = face.aligned_face

            face.hash, img = hash_encode_image(resized_face, extension)
            save_queue.put((out_filename, img))
            final_faces.append(face.to_alignment())
        self.alignments.data[os.path.basename(filename)] = final_faces


class Plugins():
    """ Detector and Aligner Plugins and queues """
    def __init__(self, arguments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = arguments
        self.detector = self.load_detector()
        self.aligner = self.load_aligner()
        self.is_parallel = self.set_parallel_processing()

        self.process_detect = None
        self.process_align = None
        self.add_queues()
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_parallel_processing(self):
        """ Set whether to run detect and align together or separately """
        detector_vram = self.detector.vram
        aligner_vram = self.aligner.vram
        gpu_stats = GPUStats()
        if (detector_vram == 0
                or aligner_vram == 0
                or gpu_stats.device_count == 0):
            logger.debug("At least one of aligner or detector has no VRAM requirement, "
                         "or no GPU was found. Enabling parallel processing.")
            return True

        if hasattr(self.args, "multiprocess") and not self.args.multiprocess:
            logger.info("NB: Parallel processing disabled. You may get faster "
                        "extraction speeds by enabling it with the -mp switch")
            return False

        required_vram = detector_vram + aligner_vram + 320  # 320MB buffer
        stats = gpu_stats.get_card_most_free()
        free_vram = int(stats["free"])
        logger.verbose("%s - %sMB free of %sMB",
                       stats["device"],
                       free_vram,
                       int(stats["total"]))
        if free_vram <= required_vram:
            logger.warning("Not enough free VRAM for parallel processing. "
                           "Switching to serial")
            return False
        return True

    def add_queues(self):
        """ Add the required processing queues to Queue Manager """
        for task in ("load", "detect", "align", "save"):
            size = 0
            if task == "load" or (not self.is_parallel and task == "detect"):
                size = 100
            queue_manager.add_queue(task, maxsize=size)

    def load_detector(self):
        """ Set global arguments and load detector plugin """
        detector_name = self.args.detector.replace("-", "_").lower()
        logger.debug("Loading Detector: '%s'", detector_name)
        # Rotation
        rotation = None
        if hasattr(self.args, "rotate_images"):
            rotation = self.args.rotate_images

        detector = PluginLoader.get_detector(detector_name)(
            loglevel=self.args.loglevel,
            rotation=rotation)

        return detector

    def load_aligner(self):
        """ Set global arguments and load aligner plugin """
        aligner_name = self.args.aligner.replace("-", "_").lower()
        logger.debug("Loading Aligner: '%s'", aligner_name)

        aligner = PluginLoader.get_aligner(aligner_name)(
            loglevel=self.args.loglevel)

        return aligner

    def launch_aligner(self):
        """ Launch the face aligner """
        logger.debug("Launching Aligner")
        out_queue = queue_manager.get_queue("align")
        kwargs = {"in_queue": queue_manager.get_queue("detect"),
                  "out_queue": out_queue}

        self.process_align = SpawnProcess(self.aligner.run, **kwargs)
        event = self.process_align.event
        self.process_align.start()

        # Wait for Aligner to take its VRAM
        # The first ever load of the model for FAN has reportedly taken
        # up to 3-4 minutes, hence high timeout.
        # TODO investigate why this is and fix if possible
        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Aligner")
            logger.info("Waiting for Aligner... Time out in %s minutes", mins)

        logger.debug("Launched Aligner")

    def launch_detector(self):
        """ Launch the face detector """
        logger.debug("Launching Detector")
        out_queue = queue_manager.get_queue("detect")
        kwargs = {"in_queue": queue_manager.get_queue("load"),
                  "out_queue": out_queue}

        mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess
        self.process_detect = mp_func(self.detector.run, **kwargs)

        event = None
        if hasattr(self.process_detect, "event"):
            event = self.process_detect.event

        self.process_detect.start()

        if event is None:
            logger.debug("Launched Detector")
            return

        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Detector")
            logger.info("Waiting for Detector... Time out in %s minutes", mins)

        logger.debug("Launched Detector")

    def detect_faces(self, extract_pass="detect"):
        """ Detect faces in an image """
        logger.debug("Running Detection. Pass: '%s'", extract_pass)
        if self.is_parallel or extract_pass == "align":
            out_queue = queue_manager.get_queue("align")
        if not self.is_parallel and extract_pass == "detect":
            out_queue = queue_manager.get_queue("detect")

        while True:
            try:
                faces = out_queue.get(True, 1)
                if faces == "EOF":
                    break
                if isinstance(faces, dict) and faces.get("exception"):
                    pid = faces["exception"][0]
                    t_back = faces["exception"][1].getvalue()
                    err = "Error in child process {}. {}".format(pid, t_back)
                    raise Exception(err)
            except QueueEmpty:
                continue

            yield faces
        logger.debug("Detection Complete")
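
The listing above passes work between stages through queues and marks the end of each stream with the string "EOF"; `detect_faces` polls its queue with a one-second timeout and retries when the queue is empty. The following is a minimal, standalone sketch of that sentinel-and-poll pattern. It uses only the standard library `queue` and `threading` modules instead of the project's `queue_manager` and `MultiThread`, and the names `producer`, `drain` and `run_pipeline` are hypothetical helpers invented for illustration, not part of faceswap.

```python
import queue
import threading

EOF = "EOF"  # sentinel string, mirroring the one used in the listing


def producer(out_queue, items):
    """Put every item on the queue, then signal completion with the sentinel."""
    for item in items:
        out_queue.put(item)
    out_queue.put(EOF)


def drain(in_queue, poll_seconds=0.1):
    """Yield items until the sentinel arrives, retrying while the queue is empty."""
    while True:
        try:
            item = in_queue.get(True, poll_seconds)
        except queue.Empty:
            continue  # nothing available yet, so poll again
        if item == EOF:
            break
        yield item


def run_pipeline(items):
    """Run the producer in a background thread and collect what it sends."""
    work_queue = queue.Queue(maxsize=4)  # bounded, so the producer can block
    thread = threading.Thread(target=producer, args=(work_queue, items))
    thread.start()
    results = list(drain(work_queue))
    thread.join()
    return results
```

Calling `run_pipeline(["a.png", "b.png"])` returns the items in the order they were queued. Note that `drain` only terminates once the sentinel is received, so a producer that exits without sending it would leave the consumer polling indefinitely; the listing guards against that case separately through the queues' shutdown events.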