1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/plugins/extract/detect/_base.py
torzdf 66ed005ef3
Optimize Data Augmentation (#881)
* Move image utils to lib.image
* Add .pylintrc file
* Remove some cv2 pylint ignores
* TrainingData: Load images from disk in batches
* TrainingData: get_landmarks to batch
* TrainingData: transform and flip to batches
* TrainingData: Optimize color augmentation
* TrainingData: Optimize target and random_warp
* TrainingData - Convert _get_closest_match for batching
* TrainingData: Warp To Landmarks optimized
* Save models to threadpoolexecutor
* Move stack_images, Rename ImageManipulation. ImageAugmentation Docstrings
* Masks: Set dtype and threshold for lib.masks based on input face
* Docstrings and Documentation
2019-09-24 12:16:05 +01:00

395 lines
16 KiB
Python
Executable file

#!/usr/bin/env python3
""" Base class for Face Detector plugins
All Detector Plugins should inherit from this class.
See the override methods for which methods are required.
For each source frame, the plugin must pass a dict to finalize containing:
>>> {'filename': <filename of source frame>,
>>> 'image': <source image>,
>>> 'detected_faces': <list of DetectedFace objects containing bounding box points>}
To get a :class:`~lib.faces_detect.DetectedFace` object use the function:
>>> face = self.to_detected_face(<face left>, <face top>, <face right>, <face bottom>)
"""
import cv2
import numpy as np
from lib.faces_detect import DetectedFace, rotate_landmarks
from plugins.extract._base import Extractor, logger
class Detector(Extractor):
    """ Detector Object
    Parent class for all Detector plugins
    Parameters
    ----------
    git_model_id: int
        The second digit in the github tag that identifies this model. See
        https://github.com/deepfakes-models/faceswap-models for more information
    model_filename: str
        The name of the model file to be loaded
    rotation: str, optional
        Pass in a single number to use increments of that size up to 360, or pass in a ``list`` of
        ``ints`` to enumerate exactly what angles to check. Can also pass in ``'on'`` to increment
        at 90 degree intervals. Default: ``None``
    min_size: int, optional
        Filters out faces detected below this size. Length, in pixels across the diagonal of the
        bounding box. Set to ``0`` for off. Default: ``0``
    Other Parameters
    ----------------
    configfile: str, optional
        Path to a custom configuration ``ini`` file. Default: Use system configfile
    See Also
    --------
    plugins.extract.pipeline : The extraction pipeline for calling plugins
    plugins.extract.detect : Detector plugins
    plugins.extract._base : Parent class for all extraction plugins
    plugins.extract.align._base : Aligner parent class for extraction plugins.
    """
    def __init__(self, git_model_id=None, model_filename=None,
                 configfile=None, rotation=None, min_size=0):
        logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__,
                     rotation, min_size)
        super().__init__(git_model_id,
                         model_filename,
                         configfile=configfile)
        # List of angles to scan; always starts with 0 (the upright pass)
        self.rotation = self._get_rotation_angles(rotation)
        self.min_size = min_size
        self._plugin_type = "detect"
        logger.debug("Initialized _base %s", self.__class__.__name__)
    # <<< QUEUE METHODS >>> #
    def get_batch(self, queue):
        """ Get items for inputting to the detector plugin in batches
        Items are returned from the ``queue`` in batches of
        :attr:`~plugins.extract._base.Extractor.batchsize`
        Remember to put ``'EOF'`` to the out queue after processing
        the final batch
        Outputs items in the following format. All lists are of length
        :attr:`~plugins.extract._base.Extractor.batchsize`:
        >>> {'filename': [<filenames of source frames>],
        >>> 'image': [<source images>],
        >>> 'scaled_image': <np.array of images standardized for prediction>,
        >>> 'scale': [<scaling factors for each image>],
        >>> 'pad': [<padding for each image>],
        >>> 'detected_faces': [[<lib.faces_detect.DetectedFace objects]]}
        Parameters
        ----------
        queue : queue.Queue()
            The ``queue`` that the batch will be fed from. This will be a queue that loads
            images.
        Returns
        -------
        exhausted, bool
            ``True`` if queue is exhausted, ``False`` if not.
        batch, dict
            A dictionary of lists of :attr:`~plugins.extract._base.Extractor.batchsize`.
        """
        exhausted = False
        batch = dict()
        for _ in range(self.batchsize):
            item = self._get_item(queue)
            if item == "EOF":
                exhausted = True
                break
            # Accumulate each incoming item's fields into per-key lists
            for key, val in item.items():
                batch.setdefault(key, []).append(val)
            # Pre-process each frame for the model (color convert, scale, pad)
            scaled_image, scale, pad = self._compile_detection_image(item["image"])
            batch.setdefault("scaled_image", []).append(scaled_image)
            batch.setdefault("scale", []).append(scale)
            batch.setdefault("pad", []).append(pad)
        if batch:
            # Stack the scaled images into a single float32 feed array
            batch["scaled_image"] = np.array(batch["scaled_image"], dtype="float32")
            logger.trace("Returning batch: %s", {k: v.shape if isinstance(v, np.ndarray) else v
                                                 for k, v in batch.items() if k != "image"})
        else:
            # Queue returned EOF before any items were collected
            logger.trace(item)
        return exhausted, batch
    # <<< FINALIZE METHODS>>> #
    def finalize(self, batch):
        """ Finalize the output from Detector
        This should be called as the final task of each ``plugin``.
        It strips unneeded items from the :attr:`batch` ``dict`` and performs standard final
        processing on each item
        Outputs items in the format:
        >>> {'image': [<original frame>],
        >>> 'filename': [<frame filename>),
        >>> 'detected_faces': [<lib.faces_detect.DetectedFace objects>]}
        Parameters
        ----------
        batch : dict
            The final ``dict`` from the `plugin` process. It must contain the keys ``image``,
            ``filename``, ``faces``
        Yields
        ------
        dict
            A ``dict`` for each frame containing the ``image``, ``filename`` and ``list`` of
            ``detected_faces``
        """
        if not isinstance(batch, dict):
            logger.trace("Item out: %s", batch)
            # NOTE(review): this function is a generator (it yields below), so this ``return``
            # ends iteration rather than handing ``batch`` back to the caller. Confirm callers
            # never rely on receiving the non-dict value here.
            return batch
        logger.trace("Item out: %s", {k: v.shape if isinstance(v, np.ndarray) else v
                                      for k, v in batch.items()})
        # Convert raw model predictions (l, t, r, b) into DetectedFace objects
        batch_faces = [[self.to_detected_face(face[0], face[1], face[2], face[3])
                        for face in faces]
                       for faces in batch["prediction"]]
        # Rotations: map boxes found on rotated feeds back to the upright frame
        if any(m.any() for m in batch["rotmat"]) and any(batch_faces):
            batch_faces = [[self._rotate_rect(face, rotmat) if rotmat.any() else face
                            for face in faces]
                           for faces, rotmat in zip(batch_faces, batch["rotmat"])]
        # Scale back out to original frame (undo padding then scaling)
        batch["detected_faces"] = [[self.to_detected_face((face.left - pad[0]) / scale,
                                                          (face.top - pad[1]) / scale,
                                                          (face.right - pad[0]) / scale,
                                                          (face.bottom - pad[1]) / scale)
                                    for face in faces]
                                   for scale, pad, faces in zip(batch["scale"],
                                                                batch["pad"],
                                                                batch_faces)]
        # Remove zero sized faces
        self._remove_zero_sized_faces(batch)
        if self.min_size > 0 and batch.get("detected_faces", None):
            batch["detected_faces"] = self._filter_small_faces(batch["detected_faces"])
        self._remove_invalid_keys(batch, ("detected_faces", "filename", "image"))
        batch = self._dict_lists_to_list_dicts(batch)
        for item in batch:
            logger.trace("final output: %s", {k: v.shape if isinstance(v, np.ndarray) else v
                                              for k, v in item.items()})
            yield item
    @staticmethod
    def to_detected_face(left, top, right, bottom):
        """ Return a :class:`~lib.faces_detect.DetectedFace` object for the bounding box """
        return DetectedFace(x=int(round(left)),
                            w=int(round(right - left)),
                            y=int(round(top)),
                            h=int(round(bottom - top)))
    # <<< PROTECTED ACCESS METHODS >>> #
    # <<< PREDICT WRAPPER >>> #
    def _predict(self, batch):
        """ Wrap models predict function in rotations """
        # One empty rotation matrix / result placeholder per image in the feed
        batch["rotmat"] = [np.array([]) for _ in range(len(batch["feed"]))]
        found_faces = [np.array([]) for _ in range(len(batch["feed"]))]
        for angle in self.rotation:
            # Rotate the batch and insert placeholders for already found faces
            self._rotate_batch(batch, angle)
            batch = self.predict(batch)
            if angle != 0 and any(face.any() for face in batch["prediction"]):
                logger.verbose("found face(s) by rotating image %s degrees", angle)
            # Keep the first non-empty prediction for each image
            found_faces = [face if not found.any() else found
                           for face, found in zip(batch["prediction"], found_faces)]
            if all(face.any() for face in found_faces):
                logger.trace("Faces found for all images")
                break
        batch["prediction"] = found_faces
        logger.trace("detect_prediction output: (filenames: %s, prediction: %s, rotmat: %s)",
                     batch["filename"], batch["prediction"], batch["rotmat"])
        return batch
    # <<< DETECTION IMAGE COMPILATION METHODS >>> #
    def _compile_detection_image(self, input_image):
        """ Compile the detection image for feeding into the model"""
        image = self._convert_color(input_image)
        image_size = image.shape[:2]
        scale = self._set_scale(image_size)
        pad = self._set_padding(image_size, scale)
        image = self._scale_image(image, image_size, scale)
        image = self._pad_image(image)
        logger.trace("compiled: (images shape: %s, scale: %s, pad: %s)", image.shape, scale, pad)
        return image, scale, pad
    def _set_scale(self, image_size):
        """ Set the scale factor for incoming image """
        # Fit the longest edge to the model's input size
        scale = self.input_size / max(image_size)
        logger.trace("Detector scale: %s", scale)
        return scale
    def _set_padding(self, image_size, scale):
        """ Set the image padding for non-square images """
        # Center the scaled image inside the square model input
        pad_left = int(self.input_size - int(image_size[1] * scale)) // 2
        pad_top = int(self.input_size - int(image_size[0] * scale)) // 2
        return pad_left, pad_top
    @staticmethod
    def _scale_image(image, image_size, scale):
        """ Scale the image and optional pad to given size """
        # Cubic for upscaling, area for downscaling (standard cv2 practice)
        interpln = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA  # pylint:disable=no-member
        if scale != 1.0:
            dims = (int(image_size[1] * scale), int(image_size[0] * scale))
            logger.trace("Resizing detection image from %s to %s. Scale=%s",
                         "x".join(str(i) for i in reversed(image_size)),
                         "x".join(str(i) for i in dims), scale)
            image = cv2.resize(image, dims, interpolation=interpln)  # pylint:disable=no-member
        logger.trace("Resized image shape: %s", image.shape)
        return image
    def _pad_image(self, image):
        """ Pad a resized image to input size """
        height, width = image.shape[:2]
        if width < self.input_size or height < self.input_size:
            # Split the padding evenly, putting any odd pixel on the right/bottom
            pad_l = (self.input_size - width) // 2
            pad_r = (self.input_size - width) - pad_l
            pad_t = (self.input_size - height) // 2
            pad_b = (self.input_size - height) - pad_t
            image = cv2.copyMakeBorder(  # pylint:disable=no-member
                image,
                pad_t,
                pad_b,
                pad_l,
                pad_r,
                cv2.BORDER_CONSTANT)  # pylint:disable=no-member
        logger.trace("Padded image shape: %s", image.shape)
        return image
    # <<< FINALIZE METHODS >>> #
    @staticmethod
    def _remove_zero_sized_faces(batch):
        """ Remove items from dict where detected face is of zero size
            or face falls entirely outside of image """
        dims = [img.shape[:2] for img in batch["image"]]
        logger.trace("image dims: %s", dims)
        # Keep a face only if its box overlaps the frame at all
        batch["detected_faces"] = [[face for face in faces
                                    if face.right > 0 and face.left < dim[1]
                                    and face.bottom > 0 and face.top < dim[0]]
                                   for dim, faces in zip(dims,
                                                         batch.get("detected_faces", list()))]
    def _filter_small_faces(self, detected_faces):
        """ Filter out any faces smaller than the min size threshold """
        retval = []
        for faces in detected_faces:
            this_image = []
            for face in faces:
                # Size is the diagonal of the bounding box, in pixels
                face_size = (face.w ** 2 + face.h ** 2) ** 0.5
                if face_size < self.min_size:
                    logger.debug("Removing detected face: (face_size: %s, min_size: %s)",
                                 face_size, self.min_size)
                    continue
                this_image.append(face)
            retval.append(this_image)
        return retval
    # <<< IMAGE ROTATION METHODS >>> #
    @staticmethod
    def _get_rotation_angles(rotation):
        """ Set the rotation angles. Includes backwards compatibility for the
            'on' and 'off' options:
            - 'on' - increment 90 degrees
            - 'off' - disable
            - 0 is prepended to the list, as whatever happens, we want to
            scan the image in it's upright state """
        rotation_angles = [0]
        if not rotation or rotation.lower() == "off":
            logger.debug("Not setting rotation angles")
            return rotation_angles
        if rotation.lower() == "on":
            rotation_angles.extend(range(90, 360, 90))
        else:
            # NOTE(review): rotation is expected to be a comma-separated string here
            # (e.g. "45" or "45,90") despite the class docstring mentioning ints
            passed_angles = [int(angle)
                             for angle in rotation.split(",")
                             if int(angle) != 0]
            if len(passed_angles) == 1:
                # A single angle is treated as a step size up to 360
                rotation_step_size = passed_angles[0]
                rotation_angles.extend(range(rotation_step_size,
                                             360,
                                             rotation_step_size))
            elif len(passed_angles) > 1:
                rotation_angles.extend(passed_angles)
        logger.debug("Rotation Angles: %s", rotation_angles)
        return rotation_angles
    def _rotate_batch(self, batch, angle):
        """ Rotate images in a batch by given angle
            if any faces have already been detected for a batch, store the existing rotation
            matrix and replace the feed image with a placeholder """
        if angle == 0:
            # Set the initial batch so we always rotate from zero
            batch["initial_feed"] = batch["feed"].copy()
            return
        retval = dict()
        for img, faces, rotmat in zip(batch["initial_feed"], batch["prediction"], batch["rotmat"]):
            if faces.any():
                # Face already found: feed a blank placeholder, keep the matrix
                image = np.zeros_like(img)
                matrix = rotmat
            else:
                image, matrix = self._rotate_image_by_angle(img, angle)
            retval.setdefault("feed", []).append(image)
            retval.setdefault("rotmat", []).append(matrix)
        batch["feed"] = np.array(retval["feed"], dtype="float32")
        batch["rotmat"] = retval["rotmat"]
    @staticmethod
    def _rotate_rect(bounding_box, rotation_matrix):
        """ Rotate a bounding box dict based on the rotation_matrix"""
        logger.trace("Rotating bounding box")
        bounding_box = rotate_landmarks(bounding_box, rotation_matrix)
        return bounding_box
    def _rotate_image_by_angle(self, image, angle):
        """ Rotate an image by a given angle.
            From: https://stackoverflow.com/questions/22041699 """
        logger.trace("Rotating image: (image: %s, angle: %s)", image.shape, angle)
        # assumes channels-first feeds have <= 4 channels -- TODO confirm for all plugins
        channels_first = image.shape[0] <= 4
        if channels_first:
            image = np.moveaxis(image, 0, 2)
        height, width = image.shape[:2]
        image_center = (width/2, height/2)
        rotation_matrix = cv2.getRotationMatrix2D(  # pylint: disable=no-member
            image_center, -1.*angle, 1.)
        # Shift the translation so the rotated image stays centered in the model input
        rotation_matrix[0, 2] += self.input_size / 2 - image_center[0]
        rotation_matrix[1, 2] += self.input_size / 2 - image_center[1]
        logger.trace("Rotated image: (rotation_matrix: %s", rotation_matrix)
        image = cv2.warpAffine(image,  # pylint: disable=no-member
                               rotation_matrix,
                               (self.input_size, self.input_size))
        if channels_first:
            image = np.moveaxis(image, 2, 0)
        return image, rotation_matrix