1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/scripts/extract.py
torzdf 66ed005ef3
Optimize Data Augmentation (#881)
* Move image utils to lib.image
* Add .pylintrc file
* Remove some cv2 pylint ignores
* TrainingData: Load images from disk in batches
* TrainingData: get_landmarks to batch
* TrainingData: transform and flip to batches
* TrainingData: Optimize color augmentation
* TrainingData: Optimize target and random_warp
* TrainingData - Convert _get_closest_match for batching
* TrainingData: Warp To Landmarks optimized
* Save models to threadpoolexecutor
* Move stack_images, Rename ImageManipulation. ImageAugmentation Docstrings
* Masks: Set dtype and threshold for lib.masks based on input face
* Docstrings and Documentation
2019-09-24 12:16:05 +01:00

262 lines
11 KiB
Python

#!/usr/bin python3
""" The script to run the extract process of faceswap """
import logging
import os
import sys
from pathlib import Path
from tqdm import tqdm
from lib.image import encode_image_with_hash
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import get_folder, deprecation_warning
from plugins.extract.pipeline import Extractor
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
# Module-level setup. Setting tqdm's monitor interval to 0 is the documented tqdm switch
# for its background monitor thread; here it is used to silence TqdmSynchronisationWarning.
tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Extract():
    """ The extract process.

    Images are loaded in a background thread and fed to the :class:`Extractor` pipeline,
    which may run over several passes. On the final pass the aligned faces are encoded and
    handed to a background save thread, and their alignments are recorded.

    Parameters
    ----------
    arguments: object
        The parsed command line arguments for the extract process
    """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self.args = arguments
        Utils.set_verbosity(self.args.loglevel)
        self.output_dir = get_folder(self.args.output_dir)
        logger.info("Output Directory: %s", self.args.output_dir)
        self.images = Images(self.args)
        self.alignments = Alignments(self.args, True, self.images.is_video)
        self.post_process = PostProcess(arguments)

        configfile = getattr(self.args, "configfile", None)
        # The literal string "none" on the command line means "no normalization"
        normalization = None if self.args.normalization == "none" else self.args.normalization
        self.extractor = Extractor(self.args.detector,
                                   self.args.aligner,
                                   configfile=configfile,
                                   multiprocess=not self.args.singleprocess,
                                   rotate_images=self.args.rotate_images,
                                   min_size=self.args.min_size,
                                   normalize_method=normalization)

        self.save_queue = queue_manager.get_queue("extract_save")
        self.threads = []
        # Flipped to True by output_processing when any frame yields more than one face
        self.verify_output = False
        # Optional: save the alignments every N frames during the final pass
        self.save_interval = getattr(self.args, "save_interval", None)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def skip_num(self):
        """ int: Number of frames to skip if extract_every_n is passed, otherwise 1. """
        return getattr(self.args, "extract_every_n", 1)

    def process(self):
        """ Perform the extraction process.

        Starts the load and save threads, runs the extraction, waits for every I/O thread to
        finish, then saves the alignments and reports the results via ``Utils.finalize``.
        """
        logger.info('Starting, this may take a while...')
        # queue_manager.debug_monitor(3)
        self.threaded_io("load")
        self.threaded_io("save")
        self.run_extraction()
        for thread in self.threads:
            thread.join()
        self.alignments.save()
        Utils.finalize(self.images.images_found // self.skip_num,
                       self.alignments.faces_count,
                       self.verify_output)

    def threaded_io(self, task, io_args=None):
        """ Perform an I/O task in a single background thread.

        Parameters
        ----------
        task: str
            One of ``"load"``, ``"save"`` or ``"reload"``
        io_args: optional
            A single argument to pass to the task function (for example the detected faces
            for ``"reload"``). ``None`` passes no arguments

        Raises
        ------
        ValueError
            If ``task`` is not one of the recognised tasks
        """
        logger.debug("Threading task: (Task: '%s')", task)
        io_args = tuple() if io_args is None else (io_args, )
        tasks = {"load": self.load_images,
                 "save": self.save_faces,
                 "reload": self.reload_images}
        if task not in tasks:
            # Previously an unknown task surfaced as an UnboundLocalError on ``func``
            raise ValueError("Unrecognised I/O task: '{}'. Expected one of: {}".format(
                task, sorted(tasks)))
        io_thread = MultiThread(tasks[task], *io_args, thread_count=1)
        io_thread.start()
        self.threads.append(io_thread)

    def load_images(self):
        """ Load the images and put them to the extractor's input queue.

        Every :attr:`skip_num`-th image is queued as a ``{"filename": ..., "image": ...}``
        dict. Images that failed to load, or whose basename already exists in the alignments
        data, are skipped. An ``"EOF"`` sentinel is always put once loading stops, including
        when a shutdown has been signalled on the queue.
        """
        logger.debug("Load Images: Start")
        load_queue = self.extractor.input_queue
        for idx, (filename, image) in enumerate(self.images.load(), start=1):
            if load_queue.shutdown.is_set():
                logger.debug("Load Queue: Stop signal received. Terminating")
                break
            if idx % self.skip_num != 0:
                logger.trace("Skipping image '%s' due to extract_every_n = %s",
                             filename, self.skip_num)
                continue
            if image is None or (not image.any() and image.ndim not in (2, 3)):
                # All black frames will return not np.any() so check dims too
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            imagename = os.path.basename(filename)
            if imagename in self.alignments.data:
                logger.trace("Skipping image: '%s'", filename)
                continue
            load_queue.put({"filename": filename, "image": image})
        load_queue.put("EOF")
        logger.debug("Load Images: Complete")

    def reload_images(self, detected_faces):
        """ Reload the images and pair each one with its detected faces from a previous pass.

        Parameters
        ----------
        detected_faces: dict
            Mapping of filename to the faces dict produced by the previous pass. Entries are
            popped from this dict as they are consumed

        Each matched faces dict has its ``"image"`` restored and is put to the extractor's
        input queue, followed by an ``"EOF"`` sentinel.
        """
        logger.debug("Reload Images: Start. Detected Faces Count: %s", len(detected_faces))
        load_queue = self.extractor.input_queue
        for idx, (filename, image) in enumerate(self.images.load(), start=1):
            if load_queue.shutdown.is_set():
                logger.debug("Reload Queue: Stop signal received. Terminating")
                break
            if idx % self.skip_num != 0:
                logger.trace("Skipping image '%s' due to extract_every_n = %s",
                             filename, self.skip_num)
                continue
            logger.trace("Reloading image: '%s'", filename)
            detect_item = detected_faces.pop(filename, None)
            if not detect_item:
                logger.warning("Couldn't find faces for: %s", filename)
                continue
            detect_item["image"] = image
            load_queue.put(detect_item)
        load_queue.put("EOF")
        logger.debug("Reload Images: Complete")

    def save_faces(self):
        """ Save the generated faces.

        Consumes ``(filename, encoded_bytes)`` items from the save queue and writes each one
        to disk in binary mode until ``"EOF"`` is received or a shutdown is signalled. Write
        failures are logged and skipped, so one bad file does not stop the remaining saves.
        """
        logger.debug("Save Faces: Start")
        while True:
            if self.save_queue.shutdown.is_set():
                logger.debug("Save Queue: Stop signal received. Terminating")
                break
            item = self.save_queue.get()
            logger.trace(item)
            if item == "EOF":
                break
            filename, face = item
            logger.trace("Saving face: '%s'", filename)
            try:
                with open(filename, "wb") as out_file:
                    out_file.write(face)
            except Exception as err:  # pylint: disable=broad-except
                # Deliberate best-effort: report and carry on with the next face
                logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
        logger.debug("Save Faces: Complete")

    def process_item_count(self):
        """ Return the number of items to be processed.

        Input frames whose basename already appears in the alignments data count as
        processed and are excluded; the remainder is divided by :attr:`skip_num`. If nothing
        is left to process, the queues are terminated and the process exits with status 0.

        Returns
        -------
        int
            The number of frames that will be fed to the extractor
        """
        processed = sum(os.path.basename(frame) in self.alignments.data
                        for frame in self.images.input_images)
        logger.debug("Items already processed: %s", processed)
        if processed != 0 and self.args.skip_existing:
            logger.info("Skipping previously extracted frames: %s", processed)
        if processed != 0 and self.args.skip_faces:
            logger.info("Skipping frames with detected faces: %s", processed)

        to_process = (self.images.images_found - processed) // self.skip_num
        logger.debug("Items to be Processed: %s", to_process)
        if to_process <= 0:
            logger.error("No frames to process. Exiting")
            queue_manager.terminate_queues()
            # sys.exit rather than the interactive-only ``exit`` builtin from ``site``
            sys.exit(0)
        return to_process

    def run_extraction(self):
        """ Run the extraction pipeline over each of its passes.

        For each pass the extractor is launched and its output is consumed behind a progress
        bar. On the final pass faces are aligned, post-processed and queued for saving. On
        earlier passes the detected faces are collected (minus their image), and the images
        are reloaded in a background thread to feed the next pass. No further passes are run
        once the extractor reports an exception.
        """
        to_process = self.process_item_count()
        size = getattr(self.args, "size", 256)
        exception = False
        for phase in range(self.extractor.passes):
            if exception:
                break
            is_final = self.extractor.final_pass
            detected_faces = {}
            self.extractor.launch()
            self.check_thread_error()
            desc = "Running pass {} of {}: {}".format(phase + 1,
                                                      self.extractor.passes,
                                                      self.extractor.phase.title())
            status_bar = tqdm(self.extractor.detected_faces(),
                              total=to_process,
                              file=sys.stdout,
                              desc=desc)
            try:
                # Iterating the tqdm object advances the bar itself, so there must be no
                # additional manual status_bar.update() call (that double-counted progress)
                for idx, faces in enumerate(status_bar):
                    self.check_thread_error()
                    exception = faces.get("exception", False)
                    if exception:
                        break
                    filename = faces["filename"]

                    if self.extractor.final_pass:
                        self.output_processing(faces, size, filename)
                        self.output_faces(filename, faces)
                        if self.save_interval and (idx + 1) % self.save_interval == 0:
                            self.alignments.save()
                    else:
                        # Drop the image to save memory; it is reloaded for the next pass
                        del faces["image"]
                        detected_faces[filename] = faces
            finally:
                # tqdm only auto-closes on exhaustion; a break or raise would leave it open
                status_bar.close()

            if is_final:
                logger.debug("Putting EOF to save")
                self.save_queue.put("EOF")
            else:
                logger.debug("Reloading images")
                self.threaded_io("reload", detected_faces)

    def check_thread_error(self):
        """ Check each background I/O thread and re-raise any error it has captured. """
        for thread in self.threads:
            thread.check_and_raise_error()

    def output_processing(self, faces, size, filename):
        """ Prepare faces for output.

        Aligns the faces, runs the post-processing actions and logs frames with no faces.
        :attr:`verify_output` is set once any frame contains more than one face.
        """
        self.align_face(faces, size, filename)
        self.post_process.do_actions(faces)

        faces_count = len(faces["detected_faces"])
        if faces_count == 0:
            logger.verbose("No faces were detected in image: %s", os.path.basename(filename))

        if not self.verify_output and faces_count > 1:
            self.verify_output = True

    def align_face(self, faces, size, filename):
        """ Align the detected faces and add the destination file path.

        Calls ``load_aligned`` on every detected face, then replaces
        ``faces["detected_faces"]`` with a list of ``{"file_location": ..., "face": ...}``
        dicts, where ``file_location`` is the output folder joined with the frame's stem.
        """
        image = faces["image"]
        # Identical for every face in this frame, so compute it once
        file_location = self.output_dir / Path(filename).stem
        final_faces = []
        for face in faces["detected_faces"]:
            face.load_aligned(image, size=size)
            final_faces.append({"file_location": file_location, "face": face})
        faces["detected_faces"] = final_faces

    def output_faces(self, filename, faces):
        """ Output faces to the save thread and record their alignments.

        Each aligned face is encoded using the source frame's file extension and queued to be
        saved as ``<file_location>_<index><extension>``. The face's ``hash`` is set from the
        encoder. The frame's entry in the alignments data (keyed by basename) is replaced with
        the list of face alignments.
        """
        extension = Path(filename).suffix
        final_faces = []
        for idx, detected_face in enumerate(faces["detected_faces"]):
            out_filename = "{}_{}{}".format(str(detected_face["file_location"]), idx, extension)
            face = detected_face["face"]
            face.hash, img = encode_image_with_hash(face.aligned_face, extension)
            self.save_queue.put((out_filename, img))
            final_faces.append(face.to_alignment())
        self.alignments.data[os.path.basename(filename)] = final_faces