mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/scripts/extract.py
torzdf cd00859c40
model_refactor (#571) (#572)
* model_refactor (#571)

* original model to new structure

* IAE model to new structure

* OriginalHiRes to new structure

* Fix trainer for different resolutions

* Initial config implementation

* Configparse library added

* improved training data loader

* dfaker model working

* Add logging to training functions

* Non blocking input for cli training

* Add error handling to threads. Add non-mp queues to queue_handler

* Improved Model Building and NNMeta

* refactor lib/models

* training refactor. DFL H128 model Implementation

* Dfaker - use hashes

* Move timelapse. Remove perceptual loss arg

* Update INSTALL.md. Add logger formatting. Update Dfaker training

* DFL h128 partially ported

* Add mask to dfaker (#573)

* Remove old models. Add mask to dfaker

* dfl mask. Make masks selectable in config (#575)

* DFL H128 Mask. Mask type selectable in config.

* remove gan_v2_2

* Creating Input Size config for models

Creating Input Size config for models

Will be used downstream in converters.

Also name change of image_shape to input_shape to clarify ( for future models with potentially different output_shapes)

* Add mask loss options to config

* MTCNN options to config.ini. Remove GAN config. Update USAGE.md

* Add sliders for numerical values in GUI

* Add config plugins menu to gui. Validate config

* Only backup model if loss has dropped. Get training working again

* bugfixes

* Standardise loss printing

* GUI idle cpu fixes. Graph loss fix.

* multi-gpu logging bugfix

* Merge branch 'staging' into train_refactor

* backup state file

* Crash protection: Only backup if both total losses have dropped

* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)

* Load and save model structure with weights

* Slight code update

* Improve config loader. Add subpixel opt to all models. Config to state

* Show samples... wrong input

* Remove AE topology. Add input/output shapes to State

* Port original_villain (birb/VillainGuy) model to faceswap

* Add plugin info to GUI config pages

* Load input shape from state. IAE Config options.

* Fix transform_kwargs.
Coverage to ratio.
Bugfix mask detection

* Suppress keras userwarnings.
Automate zoom.
Coverage_ratio to model def.

* Consolidation of converters & refactor (#574)

* Consolidation of converters & refactor

Initial Upload of alpha

Items
- consolidate convert_masked & convert_adjust into one converter
- add average color adjust to convert_masked
- allow mask transition blur size to be a fixed integer of pixels and a fraction of the facial mask size
- allow erosion/dilation size to be a fixed integer of pixels and a fraction of the facial mask size
- eliminate redundant type conversions to avoid multiple round-off errors
- refactor loops for vectorization/speed
- reorganize for clarity & style changes

TODO
- bug/issues with warping the new face onto a transparent old image...use a cleanup mask for now
- issues with mask border giving black ring at zero erosion .. investigate
- remove GAN ??
- test enlargement factors of umeyama standard face .. match to coverage factor
- make enlargement factor a model parameter
- remove convert_adjusted and referencing code when finished

* Update Convert_Masked.py

default blur size of 2 to match original...
description of enlargement tests
breakout matrix scaling into def

* Enlargement scale as a cli parameter

* Update cli.py

* dynamic interpolation algorithm

Compute x & y scale factors from the affine matrix on the fly by QR decomp.
Choose interpolation algorithm for the affine warp based on an upsample or downsample for each image

* input size
input size from config

* fix issues with <1.0 erosion

* Update convert.py

* Update Convert_Adjust.py

more work on the way to merging

* Clean up help note on sharpen

* cleanup seamless

* Delete Convert_Adjust.py

* Update umeyama.py

* Update training_data.py

* swapping

* segmentation stub

* changes to convert.str

* Update masked.py

* Backwards compatibility fix for models
Get converter running

* Convert:
Move masks to class.
bugfix blur_size
some linting

* mask fix

* convert fixes

- missing facehull_rect re-added
- coverage to %
- corrected coverage logic
- cleanup of gui option ordering

* Update cli.py

* default for blur

* Update masked.py

* added preliminary low_mem version of OriginalHighRes model plugin

* Code cleanup, minor fixes

* Update masked.py

* Update masked.py

* Add dfl mask to convert

* histogram fix & seamless location

* update

* revert

* bugfix: Load actual configuration in gui

* Standardize nn_blocks

* Update cli.py

* Minor code amends

* Fix Original HiRes model

* Add masks to preview output for mask trainers
refactor trainer.__base.py

* Masked trainers converter support

* convert bugfix

* Bugfix: Converter for masked (dfl/dfaker) trainers

* Additional Losses (#592)

* initial upload

* Delete blur.py

* default initializer = He instead of Glorot (#588)

* Allow kernel_initializer to be overridable

* Add ICNR Initializer option for upscale on all models.

* Hopefully fixes RSoDs with original-highres model plugin

* remove debug line

* Original-HighRes model plugin Red Screen of Death fix, take #2

* Move global options to _base. Rename Villain model

* clipnorm and res block biases

* scale the end of res block

* res block

* dfaker pre-activation res

* OHRES pre-activation

* villain pre-activation

* tabs/space in nn_blocks

* fix for histogram with mask all set to zero

* fix to prevent two networks with same name

* GUI: Wider tooltips. Improve TQDM capture

* Fix regex bug

* Convert padding=48 to ratio of image size

* Add size option to alignments tool extract

* Pass through training image size to convert from model

* Convert: Pull training coverage from model

* convert: coverage, blur and erode to percent

* simplify matrix scaling

* ordering of sliders in train

* Add matrix scaling to utils. Use interpolation in lib.aligner transform

* masked.py Import get_matrix_scaling from utils

* fix circular import

* Update masked.py

* quick fix for matrix scaling

* testing thus for now

* tqdm regex capture bugfix

* Minor amends

* blur size cleanup

* Remove coverage option from convert (Now cascades from model)

* Implement convert for all model types

* Add mask option and coverage option to all existing models

* bugfix for model loading on convert

* debug print removal

* Bugfix for masks in dfl_h128 and iae

* Update preview display. Add preview scaling to cli

* mask notes

* Delete training_data_v2.py

errant file

* training data variables

* Fix timelapse function

* Add new config items to state file for legacy purposes

* Slight GUI tweak

* Raise exception if problem with loaded model

* Add Tensorboard support (Logs stored in model directory)

* ICNR fix

* loss bugfix

* convert bugfix

* Move ini files to config folder. Make TensorBoard optional

* Fix training data for unbalanced inputs/outputs

* Fix config "none" test

* Keep helptext in .ini files when saving config from GUI

* Remove frame_dims from alignments

* Add no-flip and warp-to-landmarks cli options

* Revert OHR to RC4_fix version

* Fix lowmem mode on OHR model

* padding to variable

* Save models in parallel threads

* Speed-up of res_block stability

* Automated Reflection Padding

* Reflect Padding as a training option

Includes auto-calculation of proper padding shapes, input_shapes, output_shapes

Flag included in config now

* rest of reflect padding

* Move TB logging to cli. Session info to state file

* Add session iterations to state file

* Add recent files to menu. GUI code tidy up

* [GUI] Fix recent file list update issue

* Add correct loss names to TensorBoard logs

* Update live graph to use TensorBoard and remove animation

* Fix analysis tab. GUI optimizations

* Analysis Graph popup to Tensorboard Logs

* [GUI] Bug fix for graphing for models with hyphens in name

* [GUI] Correctly split loss to tabs during training

* [GUI] Add loss type selection to analysis graph

* Fix store command name in recent files. Switch to correct tab on open

* [GUI] Disable training graph when 'no-logs' is selected

* Fix graphing race condition

* rename original_hires model to unbalanced
2019-02-09 18:35:12 +00:00

400 lines
15 KiB
Python

#!/usr/bin/env python3
""" The script to run the extract process of faceswap """
import logging
import os
import sys
from pathlib import Path
from tqdm import tqdm
from lib.faces_detect import DetectedFace
from lib.gpu_stats import GPUStats
from lib.multithreading import MultiThread, PoolProcess, SpawnProcess
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import get_folder, hash_encode_image
from plugins.plugin_loader import PluginLoader
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class Extract():
    """ The extract process. """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self.args = arguments
        self.output_dir = get_folder(self.args.output_dir)
        logger.info("Output Directory: %s", self.args.output_dir)
        self.images = Images(self.args)
        self.alignments = Alignments(self.args, True, self.images.is_video)
        self.plugins = Plugins(self.args)

        self.post_process = PostProcess(arguments)

        self.verify_output = False
        self.save_interval = None
        if hasattr(self.args, "save_interval"):
            self.save_interval = self.args.save_interval
        logger.debug("Initialized %s", self.__class__.__name__)

    def process(self):
        """ Perform the extraction process """
        logger.info('Starting, this may take a while...')
        Utils.set_verbosity(self.args.loglevel)
        # queue_manager.debug_monitor(1)
        self.threaded_io("load")
        save_thread = self.threaded_io("save")
        self.run_extraction()
        save_thread.join()
        self.alignments.save()
        Utils.finalize(self.images.images_found,
                       self.alignments.faces_count,
                       self.verify_output)

    def threaded_io(self, task, io_args=None):
        """ Perform I/O task in a background thread """
        logger.debug("Threading task: (Task: '%s')", task)
        io_args = tuple() if io_args is None else (io_args, )
        if task == "load":
            func = self.load_images
        elif task == "save":
            func = self.save_faces
        elif task == "reload":
            func = self.reload_images
        else:
            # Fail with a clear message rather than an unbound local below
            raise ValueError("Unknown I/O task: '{}'".format(task))
        io_thread = MultiThread(func, *io_args, thread_count=1)
        io_thread.start()
        return io_thread

    def load_images(self):
        """ Load the images """
        logger.debug("Load Images: Start")
        load_queue = queue_manager.get_queue("load")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Load Queue: Stop signal received. Terminating")
                break
            if image is None or not image.any():
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            imagename = os.path.basename(filename)
            if imagename in self.alignments.data.keys():
                logger.trace("Skipping image: '%s'", filename)
                continue
            item = {"filename": filename,
                    "image": image}
            load_queue.put(item)
        load_queue.put("EOF")
        logger.debug("Load Images: Complete")

    def reload_images(self, detected_faces):
        """ Reload the images and pair to detected face """
        logger.debug("Reload Images: Start. Detected Faces Count: %s", len(detected_faces))
        load_queue = queue_manager.get_queue("detect")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Reload Queue: Stop signal received. Terminating")
                break
            logger.trace("Reloading image: '%s'", filename)
            detect_item = detected_faces.pop(filename, None)
            if not detect_item:
                logger.warning("Couldn't find faces for: %s", filename)
                continue
            detect_item["image"] = image
            load_queue.put(detect_item)
        load_queue.put("EOF")
        logger.debug("Reload Images: Complete")

    @staticmethod
    def save_faces():
        """ Save the generated faces """
        logger.debug("Save Faces: Start")
        save_queue = queue_manager.get_queue("save")
        while True:
            if save_queue.shutdown.is_set():
                logger.debug("Save Queue: Stop signal received. Terminating")
                break
            item = save_queue.get()
            if item == "EOF":
                break
            filename, face = item

            logger.trace("Saving face: '%s'", filename)
            try:
                with open(filename, "wb") as out_file:
                    out_file.write(face)
            except Exception as err:  # pylint: disable=broad-except
                logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
                continue
        logger.debug("Save Faces: Complete")

    def run_extraction(self):
        """ Run Face Detection """
        save_queue = queue_manager.get_queue("save")
        to_process = self.process_item_count()
        frame_no = 0
        size = self.args.size if hasattr(self.args, "size") else 256
        align_eyes = self.args.align_eyes if hasattr(self.args, "align_eyes") else False

        if self.plugins.is_parallel:
            logger.debug("Using parallel processing")
            self.plugins.launch_aligner()
            self.plugins.launch_detector()
        if not self.plugins.is_parallel:
            logger.debug("Using serial processing")
            self.run_detection(to_process)
            self.plugins.launch_aligner()

        for faces in tqdm(self.plugins.detect_faces(extract_pass="align"),
                          total=to_process,
                          file=sys.stdout,
                          desc="Extracting faces"):
            filename = faces["filename"]

            self.align_face(faces, align_eyes, size, filename)
            self.post_process.do_actions(faces)

            faces_count = len(faces["detected_faces"])
            if faces_count == 0:
                logger.verbose("No faces were detected in image: %s",
                               os.path.basename(filename))

            if not self.verify_output and faces_count > 1:
                self.verify_output = True

            self.output_faces(filename, faces, save_queue)

            frame_no += 1
            if frame_no == self.save_interval:
                self.alignments.save()
                frame_no = 0

        save_queue.put("EOF")

    def process_item_count(self):
        """ Return the number of items to be processed """
        processed = sum(os.path.basename(frame) in self.alignments.data.keys()
                        for frame in self.images.input_images)
        logger.debug("Items already processed: %s", processed)

        if processed != 0 and self.args.skip_existing:
            logger.info("Skipping previously extracted frames: %s", processed)
        if processed != 0 and self.args.skip_faces:
            logger.info("Skipping frames with detected faces: %s", processed)

        to_process = self.images.images_found - processed
        logger.debug("Items to be Processed: %s", to_process)
        if to_process == 0:
            logger.error("No frames to process. Exiting")
            queue_manager.terminate_queues()
            # sys.exit is always available; the bare exit() helper comes from the site module
            sys.exit(0)
        return to_process

    def run_detection(self, to_process):
        """ Run detection only """
        self.plugins.launch_detector()
        detected_faces = dict()
        for detected in tqdm(self.plugins.detect_faces(extract_pass="detect"),
                             total=to_process,
                             file=sys.stdout,
                             desc="Detecting faces"):
            exception = detected.get("exception", False)
            if exception:
                break

            del detected["image"]
            filename = detected["filename"]

            detected_faces[filename] = detected

        self.threaded_io("reload", detected_faces)

    def align_face(self, faces, align_eyes, size, filename):
        """ Align the detected face and add the destination file path """
        final_faces = list()
        image = faces["image"]
        landmarks = faces["landmarks"]
        detected_faces = faces["detected_faces"]
        for idx, face in enumerate(detected_faces):
            detected_face = DetectedFace()
            detected_face.from_dlib_rect(face, image)
            detected_face.landmarksXY = landmarks[idx]
            detected_face.load_aligned(image, size=size, align_eyes=align_eyes)
            final_faces.append({"file_location": self.output_dir / Path(filename).stem,
                                "face": detected_face})
        faces["detected_faces"] = final_faces

    def output_faces(self, filename, faces, save_queue):
        """ Output faces to save thread """
        final_faces = list()
        for idx, detected_face in enumerate(faces["detected_faces"]):
            output_file = detected_face["file_location"]
            extension = Path(filename).suffix
            out_filename = "{}_{}{}".format(str(output_file), str(idx), extension)

            face = detected_face["face"]
            resized_face = face.aligned_face

            face.hash, img = hash_encode_image(resized_face, extension)
            save_queue.put((out_filename, img))
            final_faces.append(face.to_alignment())
        self.alignments.data[os.path.basename(filename)] = final_faces


class Plugins():
    """ Detector and Aligner Plugins and queues """
    def __init__(self, arguments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = arguments
        self.detector = self.load_detector()
        self.aligner = self.load_aligner()
        self.is_parallel = self.set_parallel_processing()

        self.process_detect = None
        self.process_align = None
        self.add_queues()
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_parallel_processing(self):
        """ Set whether to run detect and align together or separately """
        detector_vram = self.detector.vram
        aligner_vram = self.aligner.vram
        gpu_stats = GPUStats()
        if (detector_vram == 0
                or aligner_vram == 0
                or gpu_stats.device_count == 0):
            logger.debug("At least one of aligner or detector has no VRAM requirement. "
                         "Enabling parallel processing.")
            return True

        if hasattr(self.args, "multiprocess") and not self.args.multiprocess:
            logger.info("NB: Parallel processing disabled. You may get faster "
                        "extraction speeds by enabling it with the -mp switch")
            return False

        required_vram = detector_vram + aligner_vram + 320  # 320MB buffer
        stats = gpu_stats.get_card_most_free()
        free_vram = int(stats["free"])
        logger.verbose("%s - %sMB free of %sMB",
                       stats["device"],
                       free_vram,
                       int(stats["total"]))
        if free_vram <= required_vram:
            logger.warning("Not enough free VRAM for parallel processing. "
                           "Switching to serial")
            return False
        return True

    def add_queues(self):
        """ Add the required processing queues to Queue Manager """
        for task in ("load", "detect", "align", "save"):
            size = 0
            if task == "load" or (not self.is_parallel and task == "detect"):
                size = 100
            queue_manager.add_queue(task, maxsize=size)

    def load_detector(self):
        """ Set global arguments and load detector plugin """
        detector_name = self.args.detector.replace("-", "_").lower()
        logger.debug("Loading Detector: '%s'", detector_name)

        # Rotation
        rotation = None
        if hasattr(self.args, "rotate_images"):
            rotation = self.args.rotate_images

        detector = PluginLoader.get_detector(detector_name)(
            loglevel=self.args.loglevel,
            rotation=rotation)

        return detector

    def load_aligner(self):
        """ Set global arguments and load aligner plugin """
        aligner_name = self.args.aligner.replace("-", "_").lower()
        logger.debug("Loading Aligner: '%s'", aligner_name)

        aligner = PluginLoader.get_aligner(aligner_name)(
            loglevel=self.args.loglevel)

        return aligner

    def launch_aligner(self):
        """ Launch the face aligner """
        logger.debug("Launching Aligner")
        out_queue = queue_manager.get_queue("align")
        kwargs = {"in_queue": queue_manager.get_queue("detect"),
                  "out_queue": out_queue}

        self.process_align = SpawnProcess(self.aligner.run, **kwargs)
        event = self.process_align.event
        self.process_align.start()

        # Wait for Aligner to take its VRAM
        # The first ever load of the model for FAN has reportedly taken
        # up to 3-4 minutes, hence high timeout.
        # TODO investigate why this is and fix if possible
        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Aligner")
            logger.info("Waiting for Aligner... Time out in %s minutes", mins)

        logger.debug("Launched Aligner")

    def launch_detector(self):
        """ Launch the face detector """
        logger.debug("Launching Detector")
        out_queue = queue_manager.get_queue("detect")
        kwargs = {"in_queue": queue_manager.get_queue("load"),
                  "out_queue": out_queue}

        mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess
        self.process_detect = mp_func(self.detector.run, **kwargs)

        event = None
        if hasattr(self.process_detect, "event"):
            event = self.process_detect.event

        self.process_detect.start()

        if event is None:
            logger.debug("Launched Detector")
            return

        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Detector")
            logger.info("Waiting for Detector... Time out in %s minutes", mins)

        logger.debug("Launched Detector")

    def detect_faces(self, extract_pass="detect"):
        """ Detect faces in an image """
        logger.debug("Running Detection. Pass: '%s'", extract_pass)
        if self.is_parallel or extract_pass == "align":
            out_queue = queue_manager.get_queue("align")
        if not self.is_parallel and extract_pass == "detect":
            out_queue = queue_manager.get_queue("detect")

        while True:
            try:
                faces = out_queue.get(True, 1)
                if faces == "EOF":
                    break
                if isinstance(faces, dict) and faces.get("exception"):
                    pid = faces["exception"][0]
                    t_back = faces["exception"][1].getvalue()
                    err = "Error in child process {}. {}".format(pid, t_back)
                    raise Exception(err)
            except QueueEmpty:
                continue

            yield faces
        logger.debug("Detection Complete")
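

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the faceswap code base.
# Plugins.detect_faces above polls a queue with a short timeout and stops when
# it receives the "EOF" sentinel. The stand-alone version below shows the same
# pattern using only the standard library. Assumptions: queue.Queue stands in
# for the queues handed out by queue_manager, and the names feed_queue and
# drain_queue are hypothetical helpers that do not exist in faceswap.
# ---------------------------------------------------------------------------
import queue


def feed_queue(out_queue, items, sentinel="EOF"):
    """ Put every item on the queue, then the sentinel to signal completion """
    for item in items:
        out_queue.put(item)
    out_queue.put(sentinel)


def drain_queue(in_queue, sentinel="EOF", timeout=1):
    """ Yield items from the queue until the sentinel is received """
    while True:
        try:
            item = in_queue.get(True, timeout)
        except queue.Empty:
            # Nothing ready yet: poll again rather than block indefinitely
            continue
        if item == sentinel:
            break
        yield item


# Example usage (hypothetical data):
#     demo_queue = queue.Queue()
#     feed_queue(demo_queue, [{"filename": "a.png"}, {"filename": "b.png"}])
#     names = [item["filename"] for item in drain_queue(demo_queue)]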