#!/usr/bin/env python3
""" The script to run the extract process of faceswap """

import os
import sys
from pathlib import Path

import cv2
from tqdm import tqdm

from lib.faces_detect import DetectedFace
from lib.gpu_stats import GPUStats
from lib.multithreading import MultiThread, PoolProcess, SpawnProcess
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import get_folder
from plugins.plugin_loader import PluginLoader
from scripts.fsmedia import Alignments, Images, PostProcess, Utils

tqdm.monitor_interval = 0  # workaround for TqdmSynchronisationWarning


class Extract():
    """ The extract process. """
    def __init__(self, arguments):
        self.args = arguments
        self.output_dir = get_folder(self.args.output_dir)
        print("Output Directory: {}".format(self.args.output_dir))
        self.images = Images(self.args)
        self.alignments = Alignments(self.args, True)
        self.plugins = Plugins(self.args)
        self.post_process = PostProcess(arguments)

        self.export_face = True
        self.verify_output = False
        self.save_interval = None
        if hasattr(self.args, "save_interval"):
            self.save_interval = self.args.save_interval
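
    # The extract run is a small pipeline: a background thread feeds frames
    # into the "load" queue, the detector/aligner plugins consume and refill
    # the queues, and a second background thread writes the aligned faces out.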
    def process(self):
        """ Perform the extraction process """
        print('Starting, this may take a while...')
        Utils.set_verbosity(self.args.verbose)
        # queue_manager.debug_monitor(1)
        self.threaded_io("load")
        save_thread = self.threaded_io("save")
        self.run_extraction(save_thread)
        self.alignments.save()
        Utils.finalize(self.images.images_found,
                       self.alignments.faces_count,
                       self.verify_output)

    def threaded_io(self, task, io_args=None):
        """ Run the load, save or reload task in a background thread """
        io_args = tuple() if io_args is None else (io_args, )
        if task == "load":
            func = self.load_images
        elif task == "save":
            func = self.save_faces
        elif task == "reload":
            func = self.reload_images
        io_thread = MultiThread(thread_count=1)
        io_thread.in_thread(func, *io_args)
        return io_thread

    def load_images(self):
        """ Load the images """
        load_queue = queue_manager.get_queue("load")
        for filename, image in self.images.load():
            imagename = os.path.basename(filename)
            if imagename in self.alignments.data.keys():
                continue
            load_queue.put((filename, image))
        load_queue.put("EOF")

    def reload_images(self, detected_faces):
        """ Reload the images and pair to detected face """
        load_queue = queue_manager.get_queue("detect")
        for filename, image in self.images.load():
            detect_item = detected_faces.pop(filename, None)
            if not detect_item:
                continue
            detect_item["image"] = image
            load_queue.put(detect_item)
        load_queue.put("EOF")

    def save_faces(self):
        """ Save the generated faces """
        if not self.export_face:
            return

        save_queue = queue_manager.get_queue("save")
        while True:
            item = save_queue.get()
            if item == "EOF":
                break
            filename, output_file, resized_face, idx = item
            out_filename = "{}_{}{}".format(str(output_file),
                                            str(idx),
                                            Path(filename).suffix)
            # pylint: disable=no-member
            cv2.imwrite(out_filename, resized_face)
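
    # In parallel mode the detector and aligner run side by side; otherwise a
    # detection-only pass runs first and its results are re-queued (with the
    # original images re-attached) for a separate alignment pass.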
    def run_extraction(self, save_thread):
        """ Run the face extraction loop """
        to_process = self.process_item_count()
        frame_no = 0
        if self.plugins.is_parallel:
            self.plugins.launch_aligner()
            self.plugins.launch_detector()

        if not self.plugins.is_parallel:
            self.run_detection(to_process)
            self.plugins.launch_aligner()

        for faces in tqdm(self.plugins.detect_faces(extract_pass="align"),
                          total=to_process,
                          file=sys.stdout,
                          desc="Extracting faces"):
            exception = faces.get("exception", False)
            if exception:
                exit(1)

            filename = faces["filename"]
            faces["output_file"] = self.output_dir / Path(filename).stem

            self.post_process.do_actions(faces)

            faces_count = len(faces["detected_faces"])
            if self.args.verbose and faces_count == 0:
                print("Warning: No faces were detected in image: "
                      "{}".format(os.path.basename(filename)))

            if not self.verify_output and faces_count > 1:
                self.verify_output = True

            self.process_faces(filename, faces)

            frame_no += 1
            if frame_no == self.save_interval:
                self.alignments.save()
                frame_no = 0

        if self.export_face:
            queue_manager.get_queue("save").put("EOF")
        save_thread.join_threads()

    def process_item_count(self):
        """ Return the number of items to be processed """
        processed = sum(os.path.basename(frame) in self.alignments.data.keys()
                        for frame in self.images.input_images)
        if processed != 0 and self.args.skip_existing:
            print("Skipping {} previously extracted frames".format(processed))
        if processed != 0 and self.args.skip_faces:
            print("Skipping {} frames with detected faces".format(processed))

        to_process = self.images.images_found - processed
        if to_process == 0:
            print("No frames to process. Exiting")
            queue_manager.terminate_queues()
            exit(0)
        return to_process

    def run_detection(self, to_process):
        """ Run detection only """
        self.plugins.launch_detector()
        detected_faces = dict()
        for detected in tqdm(self.plugins.detect_faces(extract_pass="detect"),
                             total=to_process,
                             file=sys.stdout,
                             desc="Detecting faces"):
            exception = detected.get("exception", False)
            if exception:
                break
            del detected["image"]
            filename = detected["filename"]
            detected_faces[filename] = detected

        self.threaded_io("reload", detected_faces)

    def process_faces(self, filename, faces):
        """ Perform processing on found faces """
        final_faces = list()
        save_queue = queue_manager.get_queue("save")
        filename = faces["filename"]
        output_file = faces["output_file"]

        for idx, face in enumerate(faces["detected_faces"]):
            if self.export_face:
                save_queue.put((filename,
                                output_file,
                                face.aligned_face,
                                idx))
            final_faces.append(face.to_alignment())

        self.alignments.data[os.path.basename(filename)] = final_faces


class Plugins():
    """ Detector and Aligner Plugins and queues """
    def __init__(self, arguments):
        self.args = arguments
        self.detector = self.load_detector()
        self.aligner = self.load_aligner()
        self.is_parallel = self.set_parallel_processing()
        self.add_queues()

    def set_parallel_processing(self):
        """ Set whether to run detect and align together or separately """
        detector_vram = self.detector.vram
        aligner_vram = self.aligner.vram
        gpu_stats = GPUStats()
        if (detector_vram == 0
                or aligner_vram == 0
                or gpu_stats.device_count == 0):
            return True

        if hasattr(self.args, "multiprocess") and not self.args.multiprocess:
            print("\nNB: Parallel processing disabled.\nYou may get faster "
                  "extraction speeds by enabling it with the -mp switch\n")
            return False

        required_vram = detector_vram + aligner_vram + 320  # 320MB buffer
        stats = gpu_stats.get_card_most_free()
        free_vram = int(stats["free"])
        if self.args.verbose:
            print("{} - {}MB free of {}MB".format(stats["device"],
                                                  free_vram,
                                                  int(stats["total"])))

        if free_vram <= required_vram:
            if self.args.verbose:
                print("Not enough free VRAM for parallel processing. "
                      "Switching to serial")
            return False
        return True
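
    # The "load" queue (and the "detect" queue when running serially) is
    # capped at 100 items, which keeps the image loader from running
    # arbitrarily far ahead of the GPU stages; the other queues are unbounded.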
    def add_queues(self):
        """ Add the required processing queues to Queue Manager """
        for task in ("load", "detect", "align", "save"):
            size = 0
            if task == "load" or (not self.is_parallel and task == "detect"):
                size = 100
            queue_manager.add_queue(task, maxsize=size)

    def load_detector(self):
        """ Set global arguments and load detector plugin """
        detector_name = self.args.detector.replace("-", "_").lower()

        # Rotation
        rotation = None
        if hasattr(self.args, "rotate_images"):
            rotation = self.args.rotate_images

        detector = PluginLoader.get_detector(detector_name)(
            verbose=self.args.verbose,
            rotation=rotation)
        return detector

    def load_aligner(self):
        """ Set global arguments and load aligner plugin """
        aligner_name = self.args.aligner.replace("-", "_").lower()

        # Align Eyes
        align_eyes = False
        if hasattr(self.args, 'align_eyes'):
            align_eyes = self.args.align_eyes

        # Extracted Face Size
        size = 256
        if hasattr(self.args, 'size'):
            size = self.args.size

        aligner = PluginLoader.get_aligner(aligner_name)(
            verbose=self.args.verbose,
            align_eyes=align_eyes,
            size=size)
        return aligner

    def launch_aligner(self):
        """ Launch the face aligner """
        out_queue = queue_manager.get_queue("align")
        kwargs = {"in_queue": queue_manager.get_queue("detect"),
                  "out_queue": out_queue}

        align_process = SpawnProcess()
        event = align_process.event
        align_process.in_process(self.aligner.align, **kwargs)

        # Wait for the Aligner to take its VRAM
        # The first ever load of the model for FAN has reportedly taken
        # up to 3-4 minutes, hence the high timeout.
        # TODO investigate why this is and fix if possible
        event.wait(300)
        if not event.is_set():
            raise ValueError("Error initializing Aligner")

        err = None
        try:
            err = out_queue.get(True, 1)
        except QueueEmpty:
            pass
        if err:
            if isinstance(err, str):
                queue_manager.terminate_queues()
                print(err)
                exit(1)
            else:
                queue_manager.get_queue("detect").put(err)

    def launch_detector(self):
        """ Launch the face detector """
        out_queue = queue_manager.get_queue("detect")
        kwargs = {"in_queue": queue_manager.get_queue("load"),
                  "out_queue": out_queue,
                  # DetectedFace passed in to avoid a race condition
                  "detected_face": DetectedFace()}

        if self.args.detector == "mtcnn":
            mtcnn_kwargs = self.detector.validate_kwargs(
                self.get_mtcnn_kwargs())
            kwargs["mtcnn_kwargs"] = mtcnn_kwargs

        if self.detector.parent_is_pool:
            detect_process = PoolProcess(self.detector.detect_faces)
        else:
            detect_process = SpawnProcess()

        event = None
        if hasattr(detect_process, "event"):
            event = detect_process.event

        detect_process.in_process(self.detector.detect_faces, **kwargs)

        if not event:
            return

        event.wait(60)
        if not event.is_set():
            raise ValueError("Error initializing Detector")

    def get_mtcnn_kwargs(self):
        """ Add the mtcnn arguments into a kwargs dictionary """
        mtcnn_threshold = [float(thr.strip())
                           for thr in self.args.mtcnn_threshold]
        return {"minsize": self.args.mtcnn_minsize,
                "threshold": mtcnn_threshold,
                "factor": self.args.mtcnn_scalefactor}
    def detect_faces(self, extract_pass="detect"):
        """ Detect faces from an image """
        if self.is_parallel or extract_pass == "align":
            out_queue = queue_manager.get_queue("align")
        if not self.is_parallel and extract_pass == "detect":
            out_queue = queue_manager.get_queue("detect")

        while True:
            try:
                faces = out_queue.get(True, 1)
                if faces == "EOF":
                    break
                exception = faces.get("exception", None)
                if exception is not None:
                    queue_manager.terminate_queues()
                    yield faces
                    break
            except QueueEmpty:
                continue

            yield faces