1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 11:53:26 -04:00
faceswap/plugins/extract/detect/_base.py

341 lines
14 KiB
Python
Executable file

#!/usr/bin/env python3
""" Base class for Face Detector plugins
Plugins should inherit from this class
See the override methods for which methods are
required.
For each source frame, the plugin must pass a dict to finalize containing:
{"filename": <filename of source frame>,
"image": <source image>,
"detected_faces": <list of BoundingBoxes>} (Class defined in /lib/faces_detect)
"""
import logging
import os
import traceback
from io import StringIO
import cv2
from lib.faces_detect import BoundingBox
from lib.gpu_stats import GPUStats
from lib.utils import rotate_landmarks, GetModel
from plugins.extract._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_config(plugin_name, configfile=None):
    """ Build the configuration for a plugin and hand back its settings dictionary.

    plugin_name: name passed straight through to Config
    configfile: optional config file, forwarded to Config as a keyword (default: None)

    Returns the ``config_dict`` attribute of the constructed Config object """
    plugin_config = Config(plugin_name, configfile=configfile)
    return plugin_config.config_dict
class Detector():
""" Detector object """
def __init__(self, loglevel, configfile=None,
             git_model_id=None, model_filename=None, rotation=None, min_size=0):
    """ Store the launch options and set the defaults that plugins are expected to override.

    loglevel: stored as-is on the instance
    configfile: optional config file handed to get_config
    git_model_id, model_filename: handed to get_model to resolve the model path
    rotation: rotation option, converted to a list of angles by get_rotation_angles
    min_size: faces whose diagonal is below this are dropped in finalize (0 disables) """
    class_name = self.__class__.__name__
    logger.debug("Initializing %s: (loglevel: %s, configfile: %s, git_model_id: %s, "
                 "model_filename: %s, rotation: %s, min_size: %s)",
                 class_name, loglevel, configfile, git_model_id,
                 model_filename, rotation, min_size)

    # The config section is named after the last two components of the defining module
    module_parts = self.__module__.split(".")
    self.config = get_config(".".join(module_parts[-2:]), configfile=configfile)

    self.loglevel = loglevel
    self.rotation = self.get_rotation_angles(rotation)
    self.min_size = min_size

    # Run-time state, populated by initialize()
    self.parent_is_pool = False
    self.init = None
    self.error = None

    # Input and output queues for the plugin.
    # lib.queue_manager.QueueManager is where queues are obtained from
    self.queues = {"in": None, "out": None}

    # Location of the model file, or None when the plugin does not need one
    self.model_path = self.get_model(git_model_id, model_filename)

    # Image size the detector should be fed.
    # Either a tuple of dimensions (x, y) or an int pixel count
    self.target = None

    # Rough VRAM requirement for the chosen target. Drives the calculation of
    # how many parallel processes / batches may run.
    # Err on the high side to avoid OOM.
    self.vram = None

    # Plugins that can run on PlaidML flip this to True
    self.supports_plaidml = False

    # Batching detectors set this to the batch size that the available VRAM
    # allows. Parallel processing plugins reuse it for their thread /
    # process count
    self.batch_size = 1
    logger.debug("Initialized _base %s", class_name)
# <<< OVERRIDE METHODS >>> #
def initialize(self, *args, **kwargs):
    """ Initialize the detector.
    Setup that has to happen before any detection is performed.
    Override for specific detector

    Reads the optional "event" and "error" keyword arguments (each falls back to False
    when absent) and the mandatory "in_queue" and "out_queue" keyword arguments """
    logger.debug("initialize %s (PID: %s, args: %s, kwargs: %s)",
                 self.__class__.__name__, os.getpid(), args, kwargs)
    self.init, self.error = kwargs.get("event", False), kwargs.get("error", False)
    # Assigned one at a time so a missing "out_queue" still leaves the in queue stored
    queues = self.queues
    queues["in"] = kwargs["in_queue"]
    queues["out"] = kwargs["out_queue"]
def detect_faces(self, *args, **kwargs):
    """ Detect faces in rgb image
    Override for specific detector
    Overrides must return a list of BoundingBox's

    This base implementation only makes sure the plugin is initialized: when self.init
    is falsy, initialize() is called with the received arguments. A ValueError raised
    during that initialization is logged and the process exits with status 1. """
    try:
        if not self.init:
            self.initialize(*args, **kwargs)
    except ValueError as err:
        logger.error(err)
        # The bare ``exit`` builtin is injected by the ``site`` module for interactive
        # sessions and is not guaranteed to exist (python -S, frozen builds). Raising
        # SystemExit is what it does internally and needs nothing importing.
        raise SystemExit(1) from err
    logger.debug("Detecting Faces (args: %s, kwargs: %s)", args, kwargs)
# <<< GET MODEL >>> #
@staticmethod
def get_model(git_model_id, model_filename):
    """ Resolve the path to the model file through GetModel.

    Returns None without touching GetModel when either the filename or the git model id
    is None (the filename is checked first). Otherwise a GetModel object is created for
    the ".cache" folder that sits next to this file and its model_path is returned """
    skip_reason = None
    if model_filename is None:
        skip_reason = "No model_filename specified. Returning None"
    elif git_model_id is None:
        skip_reason = "No git_model_id specified. Returning None"
    if skip_reason is not None:
        logger.debug(skip_reason)
        return None
    plugin_folder = os.path.dirname(__file__)
    cache_path = os.path.join(plugin_folder, ".cache")
    return GetModel(model_filename, cache_path, git_model_id).model_path
# <<< DETECTION WRAPPER >>> #
def run(self, *args, **kwargs):
    """ Parent detect process.
    This should always be called as the entry point so exceptions
    are passed back to parent.
    Do not override

    Any Exception raised by detect_faces is logged, its traceback is written into a
    StringIO buffer and {"exception": (pid, buffer)} is put on the out queue (when one
    has been assigned) before the process exits with status 1. """
    try:
        logger.debug("Executing detector run function")
        self.detect_faces(*args, **kwargs)
    except Exception as err:  # pylint: disable=broad-except
        logger.error("Caught exception in child process: %s: %s", os.getpid(), str(err))
        # Display traceback if in initialization stage.
        # self.init is None until initialize() runs and may be False when no event was
        # passed, so it cannot be assumed to have is_set(). Treat a missing event as
        # "still initializing" rather than failing inside the error handler.
        is_set = getattr(self.init, "is_set", None)
        if is_set is None or not is_set():
            logger.exception("Traceback:")
        tb_buffer = StringIO()
        traceback.print_exc(file=tb_buffer)
        logger.trace(tb_buffer.getvalue())
        # The out queue is still None if the failure happened before initialize()
        # assigned it. In that case there is nowhere to report to.
        out_queue = self.queues["out"]
        if out_queue is not None:
            out_queue.put({"exception": (os.getpid(), tb_buffer)})
        # ``exit`` is a site-module convenience for interactive use; raise directly.
        raise SystemExit(1) from err
# <<< FINALIZE METHODS >>> #
def finalize(self, output):
    """ This should be called as the final task of each plugin
    Performs final processing and puts to the out queue

    Dict items are logged without their "image" entry and, when min_size is above 0 and
    the dict holds a non-empty "detected_faces" entry, have their small faces filtered
    out. Anything else is logged and passed through untouched """
    if not isinstance(output, dict):
        logger.trace("Item out: %s", output)
    else:
        loggable = {key: val for key, val in output.items() if key != "image"}
        logger.trace("Item out: %s", loggable)
        if self.min_size > 0 and output.get("detected_faces", None):
            output["detected_faces"] = self.filter_small_faces(output["detected_faces"])
    self.queues["out"].put(output)
def filter_small_faces(self, detected_faces):
""" Filter out any faces smaller than the min size threshold """
retval = list()
for face in detected_faces:
face_size = (face.width ** 2 + face.height ** 2) ** 0.5
if face_size < self.min_size:
logger.debug("Removing detected face: (face_size: %s, min_size: %s",
face_size, self.min_size)
continue
retval.append(face)
return retval
# <<< DETECTION IMAGE COMPILATION METHODS >>> #
def compile_detection_image(self, input_image,
                            is_square=False, scale_up=False, to_rgb=False, to_grayscale=False):
    """ Prepare a copy of the input image for the detector.

    to_rgb reverses the last axis of the copy; to_grayscale (only looked at when to_rgb
    is falsy) converts it with cv2.COLOR_BGR2GRAY. The result is then resized by the
    factor that set_scale works out for it.

    Returns a list of [image, scale] """
    detection_image = input_image.copy()
    if to_rgb:
        # presumably the source is BGR, so flipping the channel axis gives RGB
        detection_image = detection_image[:, :, ::-1]
    elif to_grayscale:
        detection_image = cv2.cvtColor(detection_image,  # pylint: disable=no-member
                                       cv2.COLOR_BGR2GRAY)  # pylint: disable=no-member
    scale = self.set_scale(detection_image, is_square=is_square, scale_up=scale_up)
    return [self.scale_image(detection_image, scale), scale]
def set_scale(self, image, is_square=False, scale_up=False):
    """ Set the scale factor for incoming image

    image: array-like exposing ``shape`` with height and width as its first two entries
    is_square: compare the longest image edge against the longest target edge rather
               than comparing pixel counts
    scale_up: allow a factor above 1.0 when the image is smaller than the target

    Note that self.target is rewritten in place to the form the chosen mode needs: a
    (side, side) tuple for square mode, a scalar pixel count otherwise.

    Returns target / source, or 1.0 when the image is already within the target and
    scale_up is not set """
    height, width = image.shape[:2]
    if is_square:
        # Accept float pixel counts as well as ints: collapsing a tuple of float sides
        # (see the else branch) leaves a float here, which previously fell through to
        # max() and raised a TypeError.
        if isinstance(self.target, (int, float)):
            side = self.target ** 0.5
            self.target = (side, side)
        source = max(height, width)
        target = max(self.target)
    else:
        source = (width * height) ** 0.5
        if isinstance(self.target, tuple):
            self.target = self.target[0] * self.target[1]
        target = self.target ** 0.5
    if scale_up or target < source:
        scale = target / source
    else:
        scale = 1.0
    logger.trace("Detector scale: %s", scale)
    return scale
@staticmethod
def scale_image(image, scale):
""" Scale the image """
# pylint: disable=no-member
if scale == 1.0:
return image
height, width = image.shape[:2]
interpln = cv2.INTER_LINEAR if scale > 1.0 else cv2.INTER_AREA
dims = (int(width * scale), int(height * scale))
if scale < 1.0:
logger.trace("Resizing image from %sx%s to %s.",
width, height, "x".join(str(i) for i in dims))
image = cv2.resize(image, dims, interpolation=interpln)
return image
# <<< IMAGE ROTATION METHODS >>> #
@staticmethod
def get_rotation_angles(rotation):
    """ Set the rotation angles. Includes backwards compatibility for the
    'on' and 'off' options:
        - 'on' - increment 90 degrees
        - 'off' - disable
        - 0 is prepended to the list, as whatever happens, we want to
          scan the image in its upright state

    rotation: None / empty, "on", "off" (any case), a single integer step such as "45",
              or a comma separated list of explicit angles such as "90,180"

    A single value is used as the step for range(step, 360, step); a list is appended
    as given. Returns the list of angles, always starting with 0.
    Raises ValueError if an entry cannot be converted to int """
    rotation_angles = [0]
    if not rotation or rotation.lower() == "off":
        logger.debug("Not setting rotation angles")
        return rotation_angles
    if rotation.lower() == "on":
        rotation_angles.extend(range(90, 360, 90))
    else:
        passed_angles = [int(angle)
                         for angle in rotation.split(",")]
        if len(passed_angles) == 1:
            rotation_step_size = passed_angles[0]
            # range() raises ValueError on a step of 0, so "0" used to crash here. A
            # non-positive step adds no angles, leaving just the upright scan.
            if rotation_step_size > 0:
                rotation_angles.extend(range(rotation_step_size,
                                             360,
                                             rotation_step_size))
        elif len(passed_angles) > 1:
            rotation_angles.extend(passed_angles)
    logger.debug("Rotation Angles: %s", rotation_angles)
    return rotation_angles
def rotate_image(self, image, angle):
""" Rotate the image by given angle and return
Image with rotation matrix """
if angle == 0:
return image, None
return self.rotate_image_by_angle(image, angle)
@staticmethod
def rotate_rect(bounding_box, rotation_matrix):
    """ Pass a BoundingBox and rotation matrix to rotate_landmarks and return its result """
    logger.trace("Rotating BoundingBox")
    return rotate_landmarks(bounding_box, rotation_matrix)
@staticmethod
def rotate_image_by_angle(image, angle,
                          rotated_width=None, rotated_height=None):
    """ Rotate an image about its centre and return it with the matrix that was used.
    From: https://stackoverflow.com/questions/22041699

    The negated angle is given to cv2.getRotationMatrix2D with a scale of 1. Any output
    dimension left as None is derived from the source size and the absolute values of
    the first-row matrix terms. The translation column is then shifted by the difference
    between the output centre and the source centre.

    Returns a tuple of (warped image, rotation matrix) """
    logger.trace("Rotating image: (angle: %s, rotated_width: %s, rotated_height: %s)",
                 angle, rotated_width, rotated_height)
    rows, cols = image.shape[:2]
    centre = (cols/2, rows/2)
    matrix = cv2.getRotationMatrix2D(centre, -1.*angle, 1.)  # pylint: disable=no-member
    if rotated_width is None or rotated_height is None:
        cos_term = abs(matrix[0, 0])
        sin_term = abs(matrix[0, 1])
        if rotated_width is None:
            rotated_width = int(rows*sin_term + cols*cos_term)
        if rotated_height is None:
            rotated_height = int(rows*cos_term + cols*sin_term)
    # Move the source centre onto the centre of the output canvas
    matrix[0, 2] += rotated_width/2 - centre[0]
    matrix[1, 2] += rotated_height/2 - centre[1]
    logger.trace("Rotated image: (rotation_matrix: %s", matrix)
    output_size = (rotated_width, rotated_height)
    rotated = cv2.warpAffine(image, matrix, output_size)  # pylint: disable=no-member
    return (rotated, matrix)
# <<< QUEUE METHODS >>> #
def get_item(self):
    """ Take one item from the in queue and return it.

    Dict items are logged by their "filename" entry, anything else is logged whole.
    An "EOF" item is put back onto the in queue before being returned so that other
    threads reading the same queue see it as well """
    in_queue = self.queues["in"]
    item = in_queue.get()
    loggable = item["filename"] if isinstance(item, dict) else item
    logger.trace("Item in: %s", loggable)
    if item == "EOF":
        logger.debug("In Queue Exhausted")
        # Leave the marker in place for the other consumers
        in_queue.put(item)
    return item
def get_batch(self):
    """ Collect up to self.batch_size items from the in queue.

    Returns a tuple of (exhausted, batch):
        - exhausted is True when "EOF" was read while filling the batch
        - batch is the list of items read before that point ("EOF" is not included)
    Remember to put "EOF" to the out queue after processing
    the final batch """
    collected = []
    hit_eof = False
    for _ in range(self.batch_size):
        candidate = self.get_item()
        if candidate == "EOF":
            hit_eof = True
            break
        collected.append(candidate)
    logger.trace("Returning batch size: %s", len(collected))
    return (hit_eof, collected)
# <<< MISC METHODS >>> #
def get_vram_free(self):
    """ Query GPUStats for the card with the most free VRAM.

    Returns a tuple of ints: (card id, free VRAM, total VRAM), taken from the
    "card_id", "free" and "total" entries reported by get_card_most_free """
    card = GPUStats().get_card_most_free(supports_plaidml=self.supports_plaidml)
    device = card["device"]
    free = int(card["free"])
    total = int(card["total"])
    logger.verbose("Using device %s with %sMB free of %sMB", device, free, total)
    return int(card["card_id"]), free, total
@staticmethod
def set_predetected(width, height):
    """ Return a single-item list holding a BoundingBox built from (0, 0, width, height) """
    # The sort tool relies on predetected faces.
    # Re-extracting landmarks from a predetected face should be avoided:
    # face data has been lost, so the result varies widely from an
    # extract performed on the original image
    logger.debug("Setting predetected face")
    whole_frame = BoundingBox(0, 0, width, height)
    return [whole_frame]