1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/plugins/extract/mask/_base.py
torzdf 6a3b674bef
Rebase code (#1326)
* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* Setup.py - stdout to utf-8

* Add more OSes to GitHub Actions

* suppress mac-os end to end test
2023-06-27 11:27:47 +01:00

306 lines
13 KiB
Python

#!/usr/bin/env python3
""" Base class for Face Masker plugins
Plugins should inherit from this class
See the override methods for which methods are required.
The plugin will receive a :class:`~plugins.extract.pipeline.ExtractMedia` object.
For each source item, the plugin must pass a dict to finalize containing:
>>> {"filename": <filename of source frame>,
>>> "detected_faces": <list of bounding box dicts from lib/plugins/extract/detect/_base>}
"""
from __future__ import annotations
import logging
import typing as T
from dataclasses import dataclass, field
import cv2
import numpy as np
from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa
from lib.align import AlignedFace, transform_image
from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch, ExtractMedia
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType
logger = logging.getLogger(__name__)
@dataclass
class MaskerBatch(ExtractorBatch):
    """ Dataclass for holding items flowing through the masker.

    Inherits from :class:`~plugins.extract._base.ExtractorBatch`

    Parameters
    ----------
    detected_faces: list
        The :class:`~lib.align.DetectedFace` objects for the batch
    roi_masks: list
        The region of interest masks for the batch
    feed_faces: list
        The :class:`~lib.align.AlignedFace` objects holding the faces fed to the mask model
    """
    detected_faces: list[DetectedFace] = field(default_factory=list)
    roi_masks: list[np.ndarray] = field(default_factory=list)
    feed_faces: list[AlignedFace] = field(default_factory=list)
class Masker(Extractor):  # pylint:disable=abstract-method
    """ Masker plugin _base Object

    All Masker plugins must inherit from this class

    Parameters
    ----------
    git_model_id: int
        The second digit in the github tag that identifies this model. See
        https://github.com/deepfakes-models/faceswap-models for more information
    model_filename: str
        The name of the model file to be loaded

    Other Parameters
    ----------------
    configfile: str, optional
        Path to a custom configuration ``ini`` file. Default: Use system configfile

    See Also
    --------
    plugins.extract.pipeline : The extraction pipeline for calling plugins
    plugins.extract.align : Aligner plugins
    plugins.extract._base : Parent class for all extraction plugins
    plugins.extract.detect._base : Detector parent class for extraction plugins.
    plugins.extract.align._base : Aligner parent class for extraction plugins.
    """
    def __init__(self,
                 git_model_id: int | None = None,
                 model_filename: str | None = None,
                 configfile: str | None = None,
                 instance: int = 0,
                 **kwargs) -> None:
        logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile)
        super().__init__(git_model_id,
                         model_filename,
                         configfile=configfile,
                         instance=instance,
                         **kwargs)
        self._plugin_type = "mask"
        # The name the mask is stored under is derived from the plugin's module name
        self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-")
        self._storage_centering: CenteringType = "face"  # Centering to store the mask at
        self._storage_size = 128  # Size to store masks at. Leave this at default
        # Plugins override the following two attributes for their specific model
        self.input_size = 256
        self.coverage_ratio = 1.0
        logger.debug("Initialized %s", self.__class__.__name__)
    def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]:
        """ Get items for inputting into the masker from the queue in batches

        Items are returned from the ``queue`` in batches of
        :attr:`~plugins.extract._base.Extractor.batchsize`

        Items are received as :class:`~plugins.extract.pipeline.ExtractMedia` objects and
        converted to ``dict`` for internal processing.

        To ensure consistent batch sizes for masker the items are split into separate items
        for each :class:`~lib.align.DetectedFace` object.

        Remember to put ``'EOF'`` to the out queue after processing the final batch

        Outputs items in the following format. All lists are of length
        :attr:`~plugins.extract._base.Extractor.batchsize`:

        >>> {'filename': [<filenames of source frames>],
        >>> 'detected_faces': [[<lib.align.DetectedFace objects]]}

        Parameters
        ----------
        queue : queue.Queue()
            The ``queue`` that the plugin will be fed from.

        Returns
        -------
        exhausted, bool
            ``True`` if queue is exhausted, ``False`` if not
        batch, :class:`~plugins.extract._base.ExtractorBatch`
            The batch object for the current batch
        """
        exhausted = False
        batch = MaskerBatch()
        idx = 0
        while idx < self.batchsize:
            # rollover_collector yields any faces held over from the previous batch first
            item = self.rollover_collector(queue)
            if item == "EOF":
                logger.trace("EOF received")  # type: ignore
                exhausted = True
                break
            # Put frames with no faces into the out queue to keep TQDM consistent
            if not item.detected_faces:
                self._queues["out"].put(item)
                continue
            for f_idx, face in enumerate(item.detected_faces):
                image = item.get_image_copy(self.color_format)
                # Single-channel all-white mask, the full size of the source frame
                roi = np.ones((*item.image_size[:2], 1), dtype="float32")
                if not item.is_aligned:
                    # Add the ROI mask to image so we can get the ROI mask with a single warp
                    image = np.concatenate([image, roi], axis=-1)
                feed_face = AlignedFace(face.landmarks_xy,
                                        image=image,
                                        centering=self._storage_centering,
                                        size=self.input_size,
                                        coverage_ratio=self.coverage_ratio,
                                        dtype="float32",
                                        is_aligned=item.is_aligned)
                assert feed_face.face is not None
                if not item.is_aligned:
                    # Split roi mask from feed face alpha channel
                    roi_mask = feed_face.face[..., 3]
                    feed_face._face = feed_face.face[..., :3]  # pylint:disable=protected-access
                else:
                    # We have to do the warp here as AlignedFace did not perform it
                    roi_mask = transform_image(roi,
                                               feed_face.matrix,
                                               feed_face.size,
                                               padding=feed_face.padding)
                batch.roi_masks.append(roi_mask)
                batch.detected_faces.append(face)
                batch.feed_faces.append(feed_face)
                batch.filename.append(item.filename)
                idx += 1
                if idx == self.batchsize:
                    # Batch is full. Any faces of this frame not yet consumed are rolled
                    # over to be picked up first by the next call to this method
                    frame_faces = len(item.detected_faces)
                    if f_idx + 1 != frame_faces:
                        self._rollover = ExtractMedia(
                            item.filename,
                            item.image,
                            detected_faces=item.detected_faces[f_idx + 1:],
                            is_aligned=item.is_aligned)
                        logger.trace("Rolled over %s faces of %s to next batch "  # type:ignore
                                     "for '%s'", len(self._rollover.detected_faces), frame_faces,
                                     item.filename)
                    break
        if batch:
            logger.trace("Returning batch: %s",  # type:ignore
                         {k: len(v) if isinstance(v, (list, np.ndarray)) else v
                          for k, v in batch.__dict__.items()})
        else:
            logger.trace(item)  # type:ignore
        return exhausted, batch
def _predict(self, batch: BatchType) -> MaskerBatch:
""" Just return the masker's predict function """
assert isinstance(batch, MaskerBatch)
assert self.name is not None
try:
# slightly hacky workaround to deal with landmarks based masks:
if self.name.lower() in ("components", "extended"):
feed = np.empty(2, dtype="object")
feed[0] = batch.feed
feed[1] = batch.feed_faces
else:
feed = batch.feed
batch.prediction = self.predict(feed)
return batch
except tf_errors.ResourceExhaustedError as err:
msg = ("You do not have enough GPU memory available to run detection at the "
"selected batch size. You can try a number of things:"
"\n1) Close any other application that is using your GPU (web browsers are "
"particularly bad for this)."
"\n2) Lower the batchsize (the amount of images fed into the model) by "
"editing the plugin settings (GUI: Settings > Configure extract settings, "
"CLI: Edit the file faceswap/config/extract.ini)."
"\n3) Enable 'Single Process' mode.")
raise FaceswapError(msg) from err
    def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
        """ Finalize the output from Masker

        This should be called as the final task of each `plugin`.

        Pairs the detected faces back up with their original frame before yielding each frame.

        Parameters
        ----------
        batch : dict
            The final ``dict`` from the `plugin` process. It must contain the `keys`:
            ``detected_faces``, ``filename``, ``feed_faces``, ``roi_masks``

        Yields
        ------
        :class:`~plugins.extract.pipeline.ExtractMedia`
            The :attr:`DetectedFaces` list will be populated for this class with the bounding
            boxes, landmarks and masks for the detected faces found in the frame.
        """
        assert isinstance(batch, MaskerBatch)
        # Attach each predicted mask to its detected face, zeroing out any region that fell
        # outside the original frame first
        for mask, face, feed_face, roi_mask in zip(batch.prediction,
                                                   batch.detected_faces,
                                                   batch.feed_faces,
                                                   batch.roi_masks):
            self._crop_out_of_bounds(mask, roi_mask)
            face.add_mask(self._storage_name,
                          mask,
                          feed_face.adjusted_matrix,
                          feed_face.interpolators[1],
                          storage_size=self._storage_size,
                          storage_centering=self._storage_centering)
        del batch.feed  # Feed images are no longer needed; free the memory
        logger.trace("Item out: %s",  # type: ignore
                     {key: val.shape if isinstance(val, np.ndarray) else val
                      for key, val in batch.__dict__.items()})
        # Accumulate faces per source filename; only yield a frame once all of its expected
        # faces (tracked in self._faces_per_filename) have been processed
        for filename, face in zip(batch.filename, batch.detected_faces):
            self._output_faces.append(face)
            if len(self._output_faces) != self._faces_per_filename[filename]:
                continue
            output = self._extract_media.pop(filename)
            output.add_detected_faces(self._output_faces)
            self._output_faces = []
            logger.trace("Yielding: (filename: '%s', image: %s, "  # type:ignore
                         "detected_faces: %s)", output.filename, output.image_shape,
                         len(output.detected_faces))
            yield output
# <<< PROTECTED ACCESS METHODS >>> #
@classmethod
def _resize(cls, image: np.ndarray, target_size: int) -> np.ndarray:
""" resize input and output of mask models appropriately """
height, width, channels = image.shape
image_size = max(height, width)
scale = target_size / image_size
if scale == 1.:
return image
method = cv2.INTER_CUBIC if scale > 1. else cv2.INTER_AREA # pylint: disable=no-member
resized = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=method)
resized = resized if channels > 1 else resized[..., None]
return resized
@classmethod
def _crop_out_of_bounds(cls, mask: np.ndarray, roi_mask: np.ndarray) -> None:
""" Un-mask any area of the predicted mask that falls outside of the original frame.
Parameters
----------
masks: :class:`numpy.ndarray`
The predicted masks from the plugin
roi_mask: :class:`numpy.ndarray`
The roi mask. In frame is white, out of frame is black
"""
if np.all(roi_mask):
return # The whole of the face is within the frame
roi_mask = roi_mask[..., None] if mask.ndim == 3 else roi_mask
mask *= roi_mask