1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/scripts/fsmedia.py

371 lines
15 KiB
Python

#!/usr/bin/env python3
""" Holds the classes for the 3 main Faceswap 'media' objects for
input (extract) and output (convert) tasks. Those being:
Images
Faces
Alignments"""
import logging
import os
from pathlib import Path
import cv2
import imageio
from lib.alignments import Alignments as AlignmentsBase
from lib.face_filter import FaceFilter as FilterFunc
from lib.image import count_frames, read_image
from lib.utils import (camel_case_split, get_image_paths, set_system_verbosity, _video_extensions)
# Module-level logger used by the classes defined in this file.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Utils():
    """ Helper routines that are needed by more than one of the media
    objects """

    @staticmethod
    def set_verbosity(loglevel):
        """ Forward the requested log level to the system verbosity setter """
        set_system_verbosity(loglevel)

    @staticmethod
    def finalize(images_found, num_faces_detected, verify_output):
        """ Log the end-of-run summary.

        Reports the number of images found and faces detected. When
        ``verify_output`` is truthy an additional note is logged asking the
        user to double check the results because multiple faces were
        detected in at least one picture.
        """
        divider = "-------------------------"
        logger.info(divider)
        logger.info("Images found: %s", images_found)
        logger.info("Faces detected: %s", num_faces_detected)
        logger.info(divider)
        if verify_output:
            note_lines = ("Note:",
                          "Multiple faces were detected in one or more pictures.",
                          "Double check your results.",
                          divider)
            for line in note_lines:
                logger.info(line)
        logger.info("Process Succesfully Completed. Shutting Down...")
class Alignments(AlignmentsBase):
    """ Extract/convert specific version of the main alignments class """

    def __init__(self, arguments, is_extract, input_is_video=False):
        cls_name = self.__class__.__name__
        logger.debug("Initializing %s: (is_extract: %s, input_is_video: %s)",
                     cls_name, is_extract, input_is_video)
        self.args = arguments
        self.is_extract = is_extract
        location = self.set_folder_filename(input_is_video)
        super().__init__(location[0], filename=location[1])
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_folder_filename(self, input_is_video):
        """ Work out where the alignments file lives.

        Precedence: an explicitly supplied alignments path, then a file named
        ``<video name>_alignments`` next to the input video, then a file named
        ``alignments`` inside the input folder.

        Returns the ``(folder, filename)`` pair.
        """
        args = self.args
        if args.alignments_path:
            logger.debug("Alignments File provided: '%s'", args.alignments_path)
            folder, filename = os.path.split(str(args.alignments_path))
        elif input_is_video:
            logger.debug("Alignments from Video File: '%s'", args.input_dir)
            folder, video_file = os.path.split(args.input_dir)
            video_stem = os.path.splitext(video_file)[0]
            filename = "{}_alignments".format(video_stem)
        else:
            logger.debug("Alignments from Input Folder: '%s'", args.input_dir)
            folder, filename = str(args.input_dir), "alignments"
        logger.debug("Setting Alignments: (folder: '%s' filename: '%s')", folder, filename)
        return folder, filename

    def load(self):
        """ Replacement for the parent loader that honours the extract
        'skip existing' and 'skip faces' options.

        Outside of extract the parent loader is used when an alignments file
        exists, otherwise an empty dictionary is returned. During extract the
        file is only read when one of the skip options is selected.
        """
        if not self.is_extract:
            return super().load() if self.have_alignments_file else dict()

        skip_existing = bool(getattr(self.args, "skip_existing", False))
        skip_faces = bool(getattr(self.args, "skip_faces", False))

        if not (skip_existing or skip_faces):
            logger.debug("No skipping selected. Returning empty dictionary")
            return dict()

        # At least one skip option is active from here on
        if not self.have_alignments_file:
            logger.warning("Skip Existing/Skip Faces selected, but no alignments file found!")
            return dict()

        data = self.serializer.load(self.file)
        if skip_faces:
            # Drop frames that have no faces from the alignments so that they
            # are picked up for detection again
            empty_frames = [frame for frame, faces in data.items() if not faces]
            logger.debug("Frames with no faces selected for redetection: %s", len(empty_frames))
            for frame in empty_frames:
                logger.trace("Selected for redetection: '%s'", frame)
                del data[frame]
        return data
class Images():
    """ Holds the full frames/images.

    The input location (``arguments.input_dir``) is either a folder of images
    or a single video file. Frames are served as ``(filename, image)`` pairs
    by :func:`load` or individually by :func:`load_one_image`.
    """

    def __init__(self, arguments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = arguments
        self.is_video = self.check_input_folder()
        self.input_images = self.get_input_images()
        self.images_found = self.count_images()
        logger.debug("Initialized %s", self.__class__.__name__)

    def count_images(self):
        """ Return the number of images (folder input) or frames (video input) """
        if self.is_video:
            return int(count_frames(self.args.input_dir, fast=True))
        return len(self.input_images)

    def check_input_folder(self):
        """ Check whether the input is a folder or video.

        Returns ``True`` when the input location is a file with a known video
        extension, ``False`` otherwise.

        Raises ``SystemExit`` (exit code 1) when the input location does not
        exist.
        """
        input_dir = self.args.input_dir
        if not os.path.exists(input_dir):
            logger.error("Input location %s not found.", input_dir)
            # The bare ``exit`` builtin is injected by the ``site`` module and is
            # not guaranteed to exist, so raise the exception directly.
            raise SystemExit(1)
        is_video = (os.path.isfile(input_dir) and
                    os.path.splitext(input_dir)[1].lower() in _video_extensions)
        if is_video:
            logger.info("Input Video: %s", input_dir)
        else:
            logger.info("Input Directory: %s", input_dir)
        return is_video

    def get_input_images(self):
        """ Return the list of image paths, or the video file path, that is to
        be processed """
        if self.is_video:
            return self.args.input_dir
        return get_image_paths(self.args.input_dir)

    def load(self):
        """ Yield each image/frame along with its filename """
        iterator = self.load_video_frames if self.is_video else self.load_disk_frames
        yield from iterator()

    def load_disk_frames(self):
        """ Yield ``(filename, image)`` for every input image on disk.

        Images for which ``read_image`` returns ``None`` are skipped.
        """
        logger.debug("Input is separate Frames. Loading images")
        for filename in self.input_images:
            image = read_image(filename, raise_error=False)
            if image is None:
                continue
            yield filename, image

    def load_video_frames(self):
        """ Yield ``(filename, frame)`` for every frame of the input video.

        Filenames are generated as ``<video name>_<6 digit frame number>.png``
        with frame numbers starting at 1.
        """
        logger.debug("Input is video. Capturing frames")
        vidname = os.path.splitext(os.path.basename(self.args.input_dir))[0]
        reader = imageio.get_reader(self.args.input_dir, "ffmpeg")
        # try/finally so the reader is released even when the consumer stops
        # iterating early or an error is raised mid-stream
        try:
            for i, frame in enumerate(reader):
                # Convert to BGR for cv2 compatibility
                frame = frame[:, :, ::-1]
                filename = "{}_{:06d}.png".format(vidname, i + 1)
                logger.trace("Loading video frame: '%s'", filename)
                yield filename, frame
        finally:
            reader.close()

    def load_one_image(self, filename):
        """ Load the requested image.

        For video input ``filename`` is either a string of digits (the frame
        number) or a generated frame filename whose frame number follows the
        final underscore. For folder input the file is read from disk.
        """
        logger.trace("Loading image: '%s'", filename)
        if not self.is_video:
            return read_image(filename, raise_error=True)
        if filename.isdigit():
            frame_no = filename
        else:
            frame_no = os.path.splitext(filename)[0][filename.rfind("_") + 1:]
            logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
        return self.load_one_video_frame(int(frame_no))

    def load_one_video_frame(self, frame_no):
        """ Load a single frame from the video file.

        ``frame_no`` is 1-based; it is converted to the reader's zero-based
        index. The frame's channel order is reversed before it is returned.
        """
        logger.trace("Loading video frame: %s", frame_no)
        reader = imageio.get_reader(self.args.input_dir, "ffmpeg")
        # Always release the reader, even if seeking or decoding fails
        try:
            reader.set_image_index(frame_no - 1)
            frame = reader.get_next_data()[:, :, ::-1]
        finally:
            reader.close()
        return frame
class PostProcess():
    """ Builds and runs the optional post processing tasks """

    def __init__(self, arguments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = arguments
        self.actions = self.set_actions()
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_actions(self):
        """ Instantiate every requested action and return the valid ones.

        Each key returned by :func:`get_items` is looked up by name in this
        module's globals and instantiated with the optional ``args`` (tuple)
        and ``kwargs`` (dict) from its options. Tasks that flag themselves as
        not valid are discarded.
        """
        tasks = list()
        for name, options in self.get_items().items():
            if options is None:
                options = dict()
            pos_args = options.get("args", tuple())
            kw_args = options.get("kwargs", dict())
            # Anything of the wrong type is replaced by an empty container
            if not isinstance(pos_args, tuple):
                pos_args = tuple()
            if not isinstance(kw_args, dict):
                kw_args = dict()
            task = globals()[name](*pos_args, **kw_args)
            if not task.valid:
                continue
            logger.debug("Adding Postprocess action: '%s'", task)
            tasks.append(task)

        for task in tasks:
            words = camel_case_split(task.__class__.__name__)
            logger.info("Adding post processing item: %s", " ".join(words))

        return tasks

    def get_items(self):
        """ Collect the post processing actions selected in the arguments.

        Returns a dictionary keyed by action class name. The value is either
        ``None`` or a dictionary holding the ``kwargs`` for the action.
        """
        args = self.args
        items = dict()

        # Debug Landmarks
        if getattr(args, "debug_landmarks", False):
            items["DebugLandmarks"] = None

        # Face Filter post processing
        filter_types = ("filter", "nfilter")
        if any(getattr(args, f_type, None) is not None for f_type in filter_types):
            face_filter = dict()
            for plugin in ("detector", "aligner"):
                if hasattr(args, plugin):
                    face_filter[plugin] = getattr(args, plugin).replace("-", "_").lower()
                else:
                    face_filter[plugin] = "cv2_dnn"
            face_filter["multiprocess"] = not args.singleprocess
            if hasattr(args, "ref_threshold"):
                face_filter["ref_threshold"] = args.ref_threshold
            # Empty/missing filter arguments are normalised to None
            face_filter["filter_lists"] = {f_type: (getattr(args, f_type, None) or None)
                                           for f_type in filter_types}
            items["FaceFilter"] = {"kwargs": face_filter}

        logger.debug("Postprocess Items: %s", items)
        return items

    def do_actions(self, output_item):
        """ Run every configured post processing action on the output item """
        for task in self.actions:
            logger.debug("Performing postprocess action: '%s'", task.__class__.__name__)
            task.process(output_item)
class PostProcessAction():  # pylint: disable=too-few-public-methods
    """ Base class for Post Processing Actions
    Usable in Extract or Convert or both
    depending on context """

    def __init__(self, *args, **kwargs):
        cls_name = self.__class__.__name__
        logger.debug("Initializing %s: (args: %s, kwargs: %s)", cls_name, args, kwargs)
        # Subclasses set this to False when invalid params are passed in, which
        # disables the action
        self.valid = True
        logger.debug("Initialized base class %s", cls_name)

    def process(self, output_item):
        """ To be overridden by each specific post processing action """
        raise NotImplementedError
class DebugLandmarks(PostProcessAction):  # pylint: disable=too-few-public-methods
    """ Draw debug landmarks on face
    Extract Only """

    def process(self, output_item):
        """ Draw a filled circle on the aligned face at each aligned landmark
        of every detected face in the output item """
        basename = os.path.basename(output_item["filename"])
        frame = os.path.splitext(basename)[0]
        for face_idx, face in enumerate(output_item["detected_faces"]):
            logger.trace("Drawing Landmarks. Frame: '%s'. Face: %s", frame, face_idx)
            landmarks = face.aligned_landmarks
            for point in landmarks:
                pos_x, pos_y = point
                cv2.circle(face.aligned_face, (pos_x, pos_y), 2, (0, 0, 255), -1)
class FaceFilter(PostProcessAction):
    """ Filter in or out faces based on input image(s)
    Extract or Convert """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        logger.info("Extracting and aligning face for Face Filter...")
        self.filter = self.load_face_filter(**kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def load_face_filter(self, filter_lists, ref_threshold, aligner, detector,
                         multiprocess):
        """ Build the face filter from the reference images.

        Returns ``None`` when no filter images were requested. When images
        were requested but none of them exist, the action is flagged as not
        valid and ``None`` is returned.
        """
        if not any(filter_lists.values()):
            return None

        wanted = self.set_face_filter("filter", filter_lists["filter"])
        unwanted = self.set_face_filter("nfilter", filter_lists["nfilter"])

        if not (wanted or unwanted):
            self.valid = False
            return None

        facefilter = FilterFunc(wanted,
                                unwanted,
                                detector,
                                aligner,
                                multiprocess,
                                ref_threshold)
        logger.debug("Face filter: %s", facefilter)
        return facefilter

    @staticmethod
    def set_face_filter(f_type, f_args):
        """ Return the list of existing files for the given filter type.

        ``f_args`` may be a single path or a list of paths. Paths that do not
        exist are dropped, and a warning is logged when none remain.
        """
        if not f_args:
            return list()

        logger.info("%s: %s", f_type.title(), f_args)
        candidates = f_args if isinstance(f_args, list) else [f_args]
        existing = [fpath for fpath in candidates if Path(fpath).exists()]
        if not existing:
            logger.warning("Face %s files were requested, but no files could be found. This "
                           "filter will not be applied.", f_type)
        logger.debug("Face Filter files: %s", existing)
        return existing

    def process(self, output_item):
        """ Keep only the detected faces that pass the face filter.

        The ``detected_faces`` entry of ``output_item`` is replaced by the
        list of accepted faces. Nothing is changed when no filter is loaded.
        """
        if not self.filter:
            return

        accepted = list()
        for idx, detect_face in enumerate(output_item["detected_faces"]):
            # Entries are either the face object itself or a dict wrapping it
            if isinstance(detect_face, dict):
                check_item = detect_face["face"]
            else:
                check_item = detect_face
            check_item.load_aligned(output_item["image"])
            if self.filter.check(check_item):
                logger.trace("Accepting recognised face. Frame: %s. Face: %s",
                             output_item["filename"], idx)
                accepted.append(detect_face)
            else:
                logger.verbose("Skipping not recognized face: (Frame: %s Face %s)",
                               output_item["filename"], idx)
        output_item["detected_faces"] = accepted