mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* tools.alignments - add export job * plugins.extract: Update __repr__ for ExtractorBatch dataclass * plugins.extract: Initial implementation of external import plugins * plugins.extract: Disable lm masks on ROI alignment data import * lib.align: Add `landmark_type` property to AlignedFace and return dummy data for ROI Landmarks pose estimate * plugins.extract: Add centering config item for align import and fix filename mapping for images * plugins.extract: Log warning on downstream plugins on limited alignment data * tools: Fix plugins for 4 point ROI landmarks (alignments, sort, mask) * tools.manual: Fix for 2D-4 ROI landmarks * training: Fix for 4 point ROI landmarks * lib.convert: Average color plugin. Avoid divide by zero errors * extract - external: - Default detector to 'external' when importing alignments - Handle different frame origin co-ordinates * alignments: Store video extension in alignments file * plugins.extract.external: Handle video file keys * plugins.extract.external: Output warning if missing data * locales + docs * plugins.extract.align.external: Roll the corner points to top-left for different origins * Clean up * linting fix
878 lines
38 KiB
Python
878 lines
38 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Return a requested detector/aligner/masker pipeline
|
|
|
|
Tensorflow does not like to release GPU VRAM, so parallel plugins need to be managed to work
|
|
together.
|
|
|
|
This module sets up a pipeline for the extraction workflow, loading detect, align and mask
|
|
plugins either in parallel or in series, giving easy access to input and output.
|
|
"""
|
|
from __future__ import annotations
|
|
import logging
|
|
import os
|
|
import typing as T
|
|
|
|
from lib.align import LandmarkType
|
|
from lib.gpu_stats import GPUStats
|
|
from lib.logger import parse_class_init
|
|
from lib.queue_manager import EventQueue, queue_manager, QueueEmpty
|
|
from lib.serializer import get_serializer
|
|
from lib.utils import get_backend, FaceswapError
|
|
from plugins.plugin_loader import PluginLoader
|
|
|
|
if T.TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
from ._base import Extractor as PluginExtractor
|
|
from .align._base import Aligner
|
|
from .align.external import Align as AlignImport
|
|
from .detect._base import Detector
|
|
from .detect.external import Detect as DetectImport
|
|
from .mask._base import Masker
|
|
from .recognition._base import Identity
|
|
from . import ExtractMedia
|
|
|
|
logger = logging.getLogger(__name__)
|
|
_INSTANCES = -1 # Tracking for multiple instances of pipeline
|
|
|
|
|
|
def _get_instance():
    """ Bump the module level :attr:`_INSTANCES` counter and hand back the new value.

    Returns
    -------
    int
        The value of :attr:`_INSTANCES` after it has been incremented by 1
    """
    global _INSTANCES  # pylint:disable=global-statement
    _INSTANCES = _INSTANCES + 1
    return _INSTANCES
|
|
|
|
|
|
class Extractor():
|
|
""" Creates a :mod:`~plugins.extract.detect`/:mod:`~plugins.extract.align``/\
|
|
:mod:`~plugins.extract.mask` pipeline and yields results frame by frame from the
|
|
:attr:`detected_faces` generator
|
|
|
|
:attr:`input_queue` is dynamically set depending on the current :attr:`phase` of extraction
|
|
|
|
Parameters
|
|
----------
|
|
detector: str or ``None``
|
|
The name of a detector plugin as exists in :mod:`plugins.extract.detect`
|
|
aligner: str or ``None``
|
|
The name of an aligner plugin as exists in :mod:`plugins.extract.align`
|
|
masker: str or list or ``None``
|
|
The name of a masker plugin(s) as exists in :mod:`plugins.extract.mask`.
|
|
This can be a single masker or a list of multiple maskers
|
|
recognition: str or ``None``
|
|
The name of the recognition plugin to use. ``None`` to not do face recognition.
|
|
Default: ``None``
|
|
configfile: str, optional
|
|
The path to a custom ``extract.ini`` configfile. If ``None`` then the system
|
|
:file:`config/extract.ini` file will be used.
|
|
multiprocess: bool, optional
|
|
Whether to attempt processing the plugins in parallel. This may get overridden
|
|
internally depending on the plugin combination. Default: ``False``
|
|
exclude_gpus: list, optional
|
|
A list of indices correlating to connected GPUs that Tensorflow should not use. Pass
|
|
``None`` to not exclude any GPUs. Default: ``None``
|
|
rotate_images: str, optional
|
|
Used to set the :attr:`plugins.extract.detect.rotation` attribute. Pass in a single number
|
|
to use increments of that size up to 360, or pass in a ``list`` of ``ints`` to enumerate
|
|
exactly what angles to check. Can also pass in ``'on'`` to increment at 90 degree
|
|
intervals. Default: ``None``
|
|
min_size: int, optional
|
|
Used to set the :attr:`plugins.extract.detect.min_size` attribute. Filters out faces
|
|
detected below this size. Length, in pixels across the diagonal of the bounding box. Set
|
|
to ``0`` for off. Default: ``0``
|
|
normalize_method: {`None`, 'clahe', 'hist', 'mean'}, optional
|
|
Used to set the :attr:`plugins.extract.align.normalize_method` attribute. Normalize the
|
|
images fed to the aligner. Default: ``None``
|
|
re_feed: int
|
|
The number of times to re-feed a slightly adjusted bounding box into the aligner.
|
|
Default: `0`
|
|
re_align: bool, optional
|
|
``True`` to obtain landmarks by passing the initially aligned face back through the
|
|
aligner. Default ``False``
|
|
disable_filter: bool, optional
|
|
Disable all aligner filters regardless of config option. Default: ``False``
|
|
|
|
Attributes
|
|
----------
|
|
phase: str
|
|
The current phase that the pipeline is running. Used in conjunction with :attr:`passes` and
|
|
:attr:`final_pass` to indicate to the caller which phase is being processed
|
|
"""
|
|
def __init__(self,
             detector: str | None,
             aligner: str | None,
             masker: str | list[str] | None,
             recognition: str | None = None,
             configfile: str | None = None,
             multiprocess: bool = False,
             exclude_gpus: list[int] | None = None,
             rotate_images: str | None = None,
             min_size: int = 0,
             normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
             re_feed: int = 0,
             re_align: bool = False,
             disable_filter: bool = False) -> None:
    # locals() has to be captured before any further local name is bound, so that only the
    # call arguments are logged
    logger.debug(parse_class_init(locals()))
    self._instance = _get_instance()

    # The rest of the pipeline always works on a list of maskers, so wrap a lone entry
    maskers: list[str | None]
    if isinstance(masker, list):
        maskers = T.cast(list[str | None], masker)
    else:
        maskers = [T.cast(str | None, masker)]

    self._flow = self._set_flow(detector, aligner, maskers, recognition)
    self._exclude_gpus = exclude_gpus
    # We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting
    # for queue) at each point. Adding more just stacks RAM with no speed benefit.
    self._queue_size = 1
    # TODO Calculate scaling for more plugins than currently exist in _parallel_scaling
    self._scaling_fallback = 0.4
    self._vram_stats = self._get_vram_stats()

    # Plugins are loaded before the phases are planned, as planning reads each plugin's vram
    self._detect = self._load_detect(detector,
                                     aligner,
                                     rotation=rotate_images,
                                     min_size=min_size,
                                     configfile=configfile)
    self._align = self._load_align(aligner,
                                   configfile=configfile,
                                   normalize_method=normalize_method,
                                   re_feed=re_feed,
                                   re_align=re_align,
                                   disable_filter=disable_filter)
    self._recognition = self._load_recognition(recognition, configfile)
    self._mask = [self._load_mask(mask_name, configfile) for mask_name in maskers]

    self._is_parallel = self._set_parallel_processing(multiprocess)
    self._phases = self._set_phases(multiprocess)
    self._phase_index = 0
    self._set_extractor_batchsize()
    self._queues = self._add_queues()
    logger.debug("Initialized %s", self.__class__.__name__)
|
|
|
|
@property
def input_queue(self) -> EventQueue:
    """ queue: The input queue that belongs to the phase that is currently running.

    The input queue is the entry point into the extraction pipeline. An :class:`ExtractMedia`
    object should be put to the queue.

    For detect/single phase operations the :attr:`ExtractMedia.filename` and
    :attr:`~ExtractMedia.image` attributes should be populated.

    For align/mask (2nd/3rd pass operations) the :attr:`ExtractMedia.detected_faces` should
    also be populated by calling :func:`ExtractMedia.set_detected_faces`.
    """
    # The first plugin listed for the current phase is the one that receives the input
    entry_phase = self._current_phase[0]
    name = f"extract{self._instance}_{entry_phase}_in"
    queue = self._queues[name]
    logger.trace("%s: %s", name, queue)  # type: ignore
    return queue
|
|
|
|
@property
def passes(self) -> int:
    """ int: The total number of passes the extractor needs to make.

    This is the number of entries in the planned phases list, which is calculated on several
    factors (vram available, plugin choice, :attr:`multiprocess` etc.). It is useful for
    iterating over the pipeline and handling accordingly.

    Example
    -------
    >>> for phase in range(extractor.passes):
    >>>     if phase == 0:
    >>>         extract_media = ExtractMedia("path/to/image/file", image)
    >>>         extractor.input_queue.put(extract_media)
    >>>     else:
    >>>         extract_media.set_image(image)
    >>>         extractor.input_queue.put(extract_media)
    """
    total = len(self._phases)
    logger.trace(total)  # type: ignore
    return total
|
|
|
|
@property
def phase_text(self) -> str:
    """ str: The plugins that are running in the current phase, formatted for info text
    output.

    Each plugin type appears once (several maskers collapse to a single ``Mask`` entry), title
    cased and comma separated, in the order that the types first occur in the current phase.
    """
    # dict.fromkeys de-duplicates whilst keeping first-seen order. A set would make the order
    # of the text depend on string hashing, which differs from one interpreter run to the next
    plugin_types = dict.fromkeys(self._get_plugin_type_and_index(phase)[0]
                                 for phase in self._current_phase)
    retval = ", ".join(plugin_type.title() for plugin_type in plugin_types)
    logger.trace(retval)  # type: ignore
    return retval
|
|
|
|
@property
def final_pass(self) -> bool:
    """ bool: ``True`` when the phase index points at the last planned phase, otherwise
    ``False``

    Useful for iterating over the pipeline :attr:`passes` or :func:`detected_faces` and
    handling accordingly.

    Example
    -------
    >>> for face in extractor.detected_faces():
    >>>     if extractor.final_pass:
    >>>         <do final processing>
    >>>     else:
    >>>         extract_media.set_image(image)
    >>>         <do intermediate processing>
    >>>         extractor.input_queue.put(extract_media)
    """
    last_index = len(self._phases) - 1
    is_final = self._phase_index == last_index
    logger.trace(is_final)  # type:ignore[attr-defined]
    return is_final
|
|
|
|
@property
def aligner(self) -> Aligner:
    """ :class:`~plugins.extract.align._base.Aligner`: The aligner plugin that was loaded for
    this pipeline. Asserts that an aligner has been loaded. """
    plugin = self._align
    assert plugin is not None
    return plugin
|
|
|
|
@property
def recognition(self) -> Identity:
    """ :class:`~plugins.extract.recognition._base.Identity`: The recognition plugin that was
    loaded for this pipeline. Asserts that a recognition plugin has been loaded. """
    plugin = self._recognition
    assert plugin is not None
    return plugin
|
|
|
|
def reset_phase_index(self) -> None:
    """ Point the pipeline back at its first phase by setting the phase index to 0.

    Used for when batch processing is used in extract.
    """
    self._phase_index = 0
|
|
|
|
def set_batchsize(self,
                  plugin_type: T.Literal["align", "detect"],
                  batchsize: int) -> None:
    """ Override the batch size of the plugin held for the given :attr:`plugin_type`.

    This should be set prior to :func:`launch` if the batch size is to be manually overridden

    Parameters
    ----------
    plugin_type: {'align', 'detect'}
        The plugin_type to be overridden
    batchsize: int
        The batch size to use for this plugin type
    """
    logger.debug("Overriding batchsize for plugin_type: %s to: %s", plugin_type, batchsize)
    # Plugins are stored on private attributes named after their type (_align, _detect)
    target = getattr(self, f"_{plugin_type}")
    setattr(target, "batchsize", batchsize)
|
|
|
|
def launch(self) -> None:
    """ Launches the plugin(s)

    This launches the plugins held in the pipeline, and should be called at the beginning
    of each :attr:`phase`. To ensure VRAM is conserved, it will only launch the plugin(s)
    required for the currently running phase

    Example
    -------
    >>> for phase in range(extractor.passes):
    >>>     extractor.launch()
    >>>     <do processing>
    """
    current = self._current_phase
    for plugin_phase in current:
        self._launch_plugin(plugin_phase)
|
|
|
|
def detected_faces(self) -> Generator[ExtractMedia, None, None]:
    """ Generator that returns results, frame by frame from the extraction pipeline

    This is the exit point for the extraction pipeline and is used to obtain the output
    of any pipeline :attr:`phase`. Once the output queue delivers ``"EOF"`` the plugin threads
    of the phase are joined and the pipeline either finalizes (final pass) or advances to the
    next phase.

    Yields
    ------
    faces: :class:`~plugins.extract.extract_media.ExtractMedia`
        The populated extracted media object.

    Example
    -------
    >>> for extract_media in extractor.detected_faces():
    >>>     filename = extract_media.filename
    >>>     image = extract_media.image
    >>>     detected_faces = extract_media.detected_faces
    """
    logger.debug("Running Detection. Phase: '%s'", self._current_phase)
    # If not multiprocessing, intercept the align in queue for
    # detection phase
    out_queue = self._output_queue
    while True:
        try:
            # Surface any plugin thread error before blocking on the queue again
            self._check_and_raise_error()
            item = out_queue.get(True, 1)
        except QueueEmpty:
            continue
        if item == "EOF":
            break
        yield item

    self._join_threads()
    if not self.final_pass:
        self._phase_index += 1
        logger.debug("Switching to phase: %s", self._current_phase)
        return

    for plugin in self._all_plugins:
        plugin.on_completion()
    logger.debug("Detection Complete")
|
|
|
|
def _disable_lm_maskers(self) -> None:
    """ Disable any 68 point landmark based maskers if alignment data is not 2D 68
    point landmarks and update the process flow/phases accordingly

    Maskers whose ``landmark_type`` is :attr:`LandmarkType.LM_2D_68` are removed from
    :attr:`_mask`. The ``mask_<index>`` entries in :attr:`_flow` and :attr:`_phases` are
    re-numbered so that they keep pointing at the surviving maskers, phases that end up empty
    are dropped and the pipeline queues are rebuilt to match the new flow.

    If no landmark based masker is loaded, the pipeline is left untouched.
    """
    logger.warning("Alignment data is not 68 point 2D landmarks. Some Faceswap functionality "
                   "will be unavailable for these faces")

    rem_maskers = [m.name for m in self._mask
                   if m is not None and m.landmark_type == LandmarkType.LM_2D_68]
    if not rem_maskers:
        # Nothing to remove, so there is no need to rebuild the flow or the queues
        logger.debug("No landmark based maskers loaded. Pipeline left unchanged")
        return

    kept = [(idx, m) for idx, m in enumerate(self._mask)
            if m is None or m.name not in rem_maskers]
    self._mask = [m for _, m in kept]

    # Flow names embed the position of the masker inside :attr:`_mask`. Removing maskers shifts
    # those positions, so map each surviving masker's old flow name to its new one. ``None``
    # placeholders (no masker selected in that slot) never appear in the flow.
    renamed = {f"mask_{old_idx}": f"mask_{new_idx}"
               for new_idx, (old_idx, m) in enumerate(kept) if m is not None}

    def _remap(items: list[str]) -> list[str]:
        """ Rename surviving mask items, drop removed ones, leave other items untouched """
        retval = []
        for item in items:
            if not item.startswith("mask"):
                retval.append(item)
            elif item in renamed:
                retval.append(renamed[item])
        return retval

    self._flow = _remap(self._flow)
    # Surviving plugins stay in the phase they were planned for. Phases left empty are removed
    self._phases = [phase for phase in (_remap(p) for p in self._phases) if phase]

    for queue in self._queues:
        queue_manager.del_queue(queue)
    del self._queues
    self._queues = self._add_queues()

    logger.warning("The following maskers have been disabled due to unsupported landmarks: %s",
                   rem_maskers)
|
|
|
|
def import_data(self, input_location: str) -> None:
    """ Import json data to the detector and/or aligner if 'import' plugin has been selected

    Does nothing when neither the detector nor the aligner is the ``external`` plugin. The
    json file is looked up by each plugin's ``file_name`` config item, inside the input folder
    (or the folder that contains the input file). If the imported aligner data is not 2D 68
    point landmarks, the landmark based maskers are disabled.

    Parameters
    ----------
    input_location: str
        Full path to the input location for the extract process

    Raises
    ------
    FaceswapError
        If the import file for a plugin cannot be found
    """
    assert self._detect is not None
    candidates = (self._detect, self.aligner)
    import_plugins: list[DetectImport | AlignImport] = [
        plugin for plugin in candidates  # type:ignore[misc]
        if T.cast(str, plugin.name).lower() == "external"]
    if not import_plugins:
        return

    # The detector needs to know the co-ordinate origin used by an imported aligner file
    assert self.aligner.name is not None
    if self.aligner.name.lower() == "external":
        align_origin = self.aligner.config["origin"]
    else:
        align_origin = None

    plugin_names = [plugin.__class__.__name__ for plugin in import_plugins]
    logger.info("Importing external data for %s from json file...",
                " and ".join(plugin_names))

    if os.path.isdir(input_location):
        folder = input_location
    else:
        folder = os.path.dirname(input_location)

    last_fname = ""
    is_68_point = True
    for plugin in import_plugins:
        plugin_type = plugin.__class__.__name__
        path = os.path.join(folder, plugin.config["file_name"])
        if not os.path.isfile(path):
            raise FaceswapError(f"{plugin_type} import file could not be found at '{path}'")

        # Only hit the disk again when this plugin reads from a different file to the last one
        if path != last_fname:
            last_fname = path
            data = get_serializer("json").load(path)

        if plugin_type != "Detect":
            plugin.import_data(data)  # type:ignore[call-arg]
            is_68_point = plugin.landmark_type == LandmarkType.LM_2D_68  # type:ignore[union-attr] # noqa:E501 # pylint:disable="line-too-long"
            continue
        plugin.import_data(data, align_origin)  # type:ignore[call-arg]

    if not is_68_point:
        self._disable_lm_maskers()

    logger.info("Imported external data")
|
|
|
|
# <<< INTERNAL METHODS >>> #
|
|
@property
def _parallel_scaling(self) -> dict[int, float]:
    """ dict: key is number of parallel plugins being loaded, value is the scaling factor that
    the total base vram for those plugins should be scaled by

    Notes
    -----
    VRAM for parallel plugins does not stack in a linear manner. Calculating the precise
    scaling for any given plugin combination is non trivial, however the following are
    calculations based on running 2-5 plugins in parallel using s3fd, fan, unet, vgg-clear
    and vgg-obstructed. The worst ratio is selected for each combination, plus a little extra
    to ensure that vram is not used up.

    If OOM errors are being reported, then these ratios should be relaxed some more
    """
    # Position in the tuple is the number of plugins that are loaded in parallel (0 to 5)
    factors = (1.0, 1.0, 0.7, 0.55, 0.5, 0.4)
    retval = dict(enumerate(factors))
    logger.trace(retval)  # type: ignore
    return retval
|
|
|
|
@property
def _vram_per_phase(self) -> dict[str, float]:
    """ dict: The ``vram`` attribute of the plugin behind each item in :attr:`_flow`, keyed by
    the flow item name. """
    retval = {}
    for flow_item in self._flow:
        plugin_type, idx = self._get_plugin_type_and_index(flow_item)
        plugin = getattr(self, f"_{plugin_type}")
        if idx is not None:
            # Indexed plugin types (maskers) are stored as a list
            plugin = plugin[idx]
        retval[flow_item] = plugin.vram
    logger.trace(retval)  # type: ignore
    return retval
|
|
|
|
@property
def _total_vram_required(self) -> float:
    """ float: The summed vram of every plugin in the flow, scaled by the parallel scaling
    factor for the number of plugins that report a vram requirement above zero. """
    vrams = self._vram_per_phase
    vram_required_count = len([vram for vram in vrams.values() if vram > 0])
    logger.debug("VRAM requirements: %s. Plugins requiring VRAM: %s",
                 vrams, vram_required_count)
    base_vram = sum(vrams.values())
    # Fall back to the most conservative factor when the plugin count is not in the table
    scaling = self._parallel_scaling.get(vram_required_count, self._scaling_fallback)
    retval = base_vram * scaling
    logger.debug("Total VRAM required: %s", retval)
    return retval
|
|
|
|
@property
def _current_phase(self) -> list[str]:
    """ list: The entry of :attr:`_phases` selected by the current phase index, i.e. the flow
    items that run together in the phase that is currently being processed. """
    active_phase = self._phases[self._phase_index]
    logger.trace(active_phase)  # type: ignore
    return active_phase
|
|
|
|
@property
def _final_phase(self) -> str:
    """ str: The last item in the :attr:`_flow` list """
    last_item = self._flow[-1]
    logger.trace(last_item)  # type: ignore
    return last_item
|
|
|
|
@property
def _output_queue(self) -> EventQueue:
    """ :class:`~lib.queue_manager.EventQueue`: The queue that the current phase outputs to.

    On the final pass this is the ``out`` queue of the final flow item. On any earlier pass
    it is the ``in`` queue of the first plugin of the following phase.
    """
    if self.final_pass:
        target, direction = self._final_phase, "out"
    else:
        target, direction = self._phases[self._phase_index + 1][0], "in"
    qname = f"extract{self._instance}_{target}_{direction}"
    queue = self._queues[qname]
    logger.trace("%s: %s", qname, queue)  # type: ignore
    return queue
|
|
|
|
@property
def _all_plugins(self) -> list[PluginExtractor]:
    """ list: The plugin object behind every item in :attr:`_flow`, in flow order """
    plugins = []
    for flow_item in self._flow:
        plugin_type, idx = self._get_plugin_type_and_index(flow_item)
        holder = getattr(self, f"_{plugin_type}")
        # Indexed plugin types (maskers) are stored as a list
        plugins.append(holder if idx is None else holder[idx])
    logger.trace("All Plugins: %s", plugins)  # type: ignore
    return plugins
|
|
|
|
@property
def _active_plugins(self) -> list[PluginExtractor]:
    """ list: The plugin object behind every item in the current phase, in phase order """
    plugins = []
    for flow_item in self._current_phase:
        plugin_type, idx = self._get_plugin_type_and_index(flow_item)
        holder = getattr(self, f"_{plugin_type}")
        if idx is None:
            plugins.append(holder)
        else:
            # Indexed plugin types (maskers) are stored as a list
            plugins.append(holder[idx])
    logger.trace("Active plugins: %s", plugins)  # type: ignore
    return plugins
|
|
|
|
@staticmethod
def _set_flow(detector: str | None,
              aligner: str | None,
              masker: list[str | None],
              recognition: str | None) -> list[str]:
    """ Set the flow list based on the input plugins

    A plugin counts as selected when its name is neither ``None`` nor the text "none" (in any
    casing). The flow is ordered detect, align, recognition and then one ``mask_<index>``
    entry per selected masker, where index is the masker's position in :attr:`masker`.

    Parameters
    ----------
    detector: str or ``None``
        The name of a detector plugin as exists in :mod:`plugins.extract.detect`
    aligner: str or ``None``
        The name of an aligner plugin as exists in :mod:`plugins.extract.align`
    masker: list
        The name(s) of the masker plugin(s) as exists in :mod:`plugins.extract.mask`.
        Entries may be ``None``
    recognition: str or ``None``
        The name of the recognition plugin to use. ``None`` to not do face recognition.

    Returns
    -------
    list
        The names of the flow items to run, in processing order
    """
    logger.debug("detector: %s, aligner: %s, masker: %s recognition: %s",
                 detector, aligner, masker, recognition)

    def _selected(name: str | None) -> bool:
        """ ``True`` if the given plugin name requests a real plugin """
        return name is not None and name.lower() != "none"

    singles = (("detect", detector), ("align", aligner), ("recognition", recognition))
    retval = [flow_item for flow_item, name in singles if _selected(name)]
    for idx, mask in enumerate(masker):
        if _selected(mask):
            retval.append(f"mask_{idx}")
    logger.debug("flow: %s", retval)
    return retval
|
|
|
|
@staticmethod
|
|
def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]:
|
|
""" Obtain the plugin type and index for the plugin for the given flow phase.
|
|
|
|
When multiple plugins for the same phase are allowed (e.g. Mask) this will return
|
|
the plugin type and the index of the plugin required. If only one plugin is allowed
|
|
then the plugin type will be returned and the index will be ``None``.
|
|
|
|
Parameters
|
|
----------
|
|
flow_phase: str
|
|
The phase within :attr:`_flow` that is to have the plugin type and index returned
|
|
|
|
Returns
|
|
-------
|
|
plugin_type: str
|
|
The plugin type for the given flow phase
|
|
index: int
|
|
The index of this plugin type within the flow, if there are multiple plugins in use
|
|
otherwise ``None`` if there is only 1 plugin in use for the given phase
|
|
"""
|
|
sidx = flow_phase.split("_")[-1]
|
|
if sidx.isdigit():
|
|
idx: int | None = int(sidx)
|
|
plugin_type = "_".join(flow_phase.split("_")[:-1])
|
|
else:
|
|
plugin_type = flow_phase
|
|
idx = None
|
|
return plugin_type, idx
|
|
|
|
def _add_queues(self) -> dict[str, EventQueue]:
    """ Register the processing queues with Queue Manager and collect them.

    One ``in`` queue is created for each item in :attr:`_flow`, plus a single ``out`` queue
    for the final flow item.

    Returns
    -------
    dict
        The queue name to the corresponding queue object
    """
    prefix = f"extract{self._instance}_"
    names = [f"{prefix}{flow_item}_in" for flow_item in self._flow]
    names.append(f"{prefix}{self._final_phase}_out")

    queues = {}
    for name in names:
        # Limit queue size to avoid stacking ram
        queue_manager.add_queue(name, maxsize=self._queue_size)
        queues[name] = queue_manager.get_queue(name)
    logger.debug("Queues: %s", queues)
    return queues
|
|
|
|
@staticmethod
def _get_vram_stats() -> dict[str, int | str]:
    """ Collect statistics for the GPU with the most free VRAM, holding back a constant buffer
    from the reported free amount.

    Returns
    -------
    dict
        The ``count`` of devices, the ``device`` name and its ``vram_free`` (less the buffer)
        and ``vram_total`` figures
    """
    vram_buffer = 256  # Leave a buffer for VRAM allocation
    gpu_stats = GPUStats()
    card = gpu_stats.get_card_most_free()
    retval: dict[str, int | str] = dict(count=gpu_stats.device_count,
                                        device=card.device,
                                        vram_free=int(card.free - vram_buffer),
                                        vram_total=int(card.total))
    logger.debug(retval)
    return retval
|
|
|
|
def _set_parallel_processing(self, multiprocess: bool) -> bool:
    """ Decide whether detect, align, and mask run together or separately.

    Parameters
    ----------
    multiprocess: bool
        ``True`` if the single-process command line flag has not been set otherwise ``False``

    Returns
    -------
    bool
        ``True`` if the plugins can all be run in parallel. ``False`` if parallel processing
        was not requested or there is not enough free VRAM to hold every plugin at once
    """
    if not multiprocess:
        logger.debug("Parallel processing disabled by cli.")
        return False

    stats = self._vram_stats
    if stats["count"] == 0:
        # Without a GPU there is no VRAM limit to respect
        logger.debug("No GPU detected. Enabling parallel processing.")
        return True

    logger.verbose("%s - %sMB free of %sMB",  # type: ignore
                   stats["device"],
                   stats["vram_free"],
                   stats["vram_total"])
    vram_free = T.cast(int, stats["vram_free"])
    if vram_free <= self._total_vram_required:
        logger.warning("Not enough free VRAM for parallel processing. "
                       "Switching to serial")
        return False
    return True
|
|
|
|
def _set_phases(self, multiprocess: bool) -> list[list[str]]:
    """ If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit
    into VRAM, otherwise return the single flow.

    Flow items are added to the current phase for as long as the scaled VRAM requirement of
    the phase fits into the free VRAM. When :attr:`multiprocess` is ``False`` every flow item
    is placed in a phase of its own.

    Parameters
    ----------
    multiprocess: bool
        ``True`` if the single-process command line flag has not been set otherwise ``False``

    Returns
    -------
    list:
        The jobs to be undertaken split into phases that fit into GPU RAM
    """
    # Both of these are properties that build a fresh dict on every access. Neither changes
    # whilst the phases are being planned, so look them up once rather than per flow item
    vram_per_phase = self._vram_per_phase
    scaling_table = self._parallel_scaling

    phases: list[list[str]] = []
    current_phase: list[str] = []
    available = T.cast(int, self._vram_stats["vram_free"])
    for phase in self._flow:
        candidate = current_phase + [phase]
        num_plugins = sum(1 for p in candidate if vram_per_phase[p] > 0)
        scaling = scaling_table.get(num_plugins, self._scaling_fallback)
        required = sum(vram_per_phase[p] for p in candidate) * scaling
        logger.debug("Num plugins for phase: %s, scaling: %s, vram required: %s",
                     num_plugins, scaling, required)

        if required <= available and multiprocess:
            logger.debug("Required: %s, available: %s. Adding phase '%s' to current phase: %s",
                         required, available, phase, current_phase)
            current_phase.append(phase)
        elif len(current_phase) == 0 or not multiprocess:
            # Amount of VRAM required to run a single plugin is greater than available. We add
            # it anyway, and hope it will run with warnings, as the alternative is to not run
            # at all.
            # This will also run if forcing single process
            logger.debug("Required: %s, available: %s. Single plugin has higher requirements "
                         "than available or forcing single process: '%s'",
                         required, available, phase)
            phases.append([phase])
        else:
            # The current phase is full. Close it and start a new one with this flow item
            logger.debug("Required: %s, available: %s. Adding phase to flow: %s",
                         required, available, current_phase)
            phases.append(current_phase)
            current_phase = [phase]
    if current_phase:
        phases.append(current_phase)
    logger.debug("Total phases: %s, Phases: %s", len(phases), phases)
    return phases
|
|
|
|
# << INTERNAL PLUGIN HANDLING >> #
|
|
def _load_align(self,
                aligner: str | None,
                configfile: str | None,
                normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None,
                re_feed: int,
                re_align: bool,
                disable_filter: bool) -> Aligner | None:
    """ Set global arguments and load aligner plugin

    Parameters
    ----------
    aligner: str
        The aligner plugin to load or ``None`` for no aligner
    configfile: str
        Optional full path to custom config file
    normalize_method: str
        Optional normalization method to use
    re_feed: int
        The number of times to adjust the image and re-feed to get an average score
    re_align: bool
        ``True`` to obtain landmarks by passing the initially aligned face back through the
        aligner.
    disable_filter: bool
        Disable all aligner filters regardless of config option

    Returns
    -------
    Aligner plugin if one is specified otherwise ``None``
    """
    if aligner is None or aligner.lower() == "none":
        logger.debug("No aligner selected. Returning None")
        return None

    # Plugin names use underscores and lower case
    aligner_name = aligner.replace("-", "_").lower()
    logger.debug("Loading Aligner: '%s'", aligner_name)
    plugin_class = PluginLoader.get_aligner(aligner_name)
    return plugin_class(exclude_gpus=self._exclude_gpus,
                        configfile=configfile,
                        normalize_method=normalize_method,
                        re_feed=re_feed,
                        re_align=re_align,
                        disable_filter=disable_filter,
                        instance=self._instance)
|
|
|
|
def _load_detect(self,
                 detector: str | None,
                 aligner: str | None,
                 rotation: str | None,
                 min_size: int,
                 configfile: str | None) -> Detector | None:
    """ Set global arguments and load detector plugin

    If the 'external' aligner has been selected, then the detector is switched to 'external'
    as well, with a warning, as no other detector is supported for imported alignments.

    Parameters
    ----------
    detector: str | None
        The name of the face detection plugin to use. ``None`` for no detection
    aligner: str | None
        The name of the face aligner plugin to use. ``None`` for no aligner
    rotation: str | None
        The rotation to perform on detection. ``None`` for no rotation
    min_size: int
        The minimum size of detected faces to accept
    configfile: str | None
        Full path to a custom config file to use. ``None`` for default config

    Returns
    -------
    :class:`~plugins.extract.detect._base.Detector` | None
        The face detection plugin to use, or ``None`` if no detection to be performed
    """
    if detector is None or detector.lower() == "none":
        logger.debug("No detector selected. Returning None")
        return None
    detector_name = detector.replace("-", "_").lower()

    # Normalize the aligner name in the same way as the aligner loader does, so that the
    # external aligner is recognised regardless of the casing that it was requested with
    aligner_name = None if aligner is None else aligner.replace("-", "_").lower()
    if aligner_name == "external" and detector_name != "external":
        logger.warning("Unsupported '%s' detector selected for 'External' aligner. Switching "
                       "detector to 'External'", detector_name)
        detector_name = aligner_name

    logger.debug("Loading Detector: '%s'", detector_name)
    plugin = PluginLoader.get_detector(detector_name)(exclude_gpus=self._exclude_gpus,
                                                      rotation=rotation,
                                                      min_size=min_size,
                                                      configfile=configfile,
                                                      instance=self._instance)
    return plugin
|
|
|
|
def _load_mask(self,
               masker: str | None,
               configfile: str | None) -> Masker | None:
    """ Set global arguments and load masker plugin

    Parameters
    ----------
    masker: str or ``None``
        The name of the masker plugin to use or ``None`` if no masker
    configfile: str
        Full path to custom config.ini file or ``None`` to use default

    Returns
    -------
    :class:`~plugins.extract.mask._base.Masker` or ``None``
        The masker plugin to use or ``None`` if no masker selected
    """
    if masker is None or masker.lower() == "none":
        logger.debug("No masker selected. Returning None")
        return None

    # Plugin names use underscores and lower case
    masker_name = masker.replace("-", "_").lower()
    logger.debug("Loading Masker: '%s'", masker_name)
    plugin_class = PluginLoader.get_masker(masker_name)
    return plugin_class(exclude_gpus=self._exclude_gpus,
                        configfile=configfile,
                        instance=self._instance)
|
|
|
|
def _load_recognition(self,
                      recognition: str | None,
                      configfile: str | None) -> Identity | None:
    """ Set global arguments and load recognition plugin

    Parameters
    ----------
    recognition: str or ``None``
        The name of the recognition plugin to use or ``None`` for no face recognition
    configfile: str
        Full path to custom config file or ``None`` to use default

    Returns
    -------
    :class:`~plugins.extract.recognition._base.Identity` or ``None``
        The recognition plugin to use or ``None`` if no recognition plugin selected
    """
    if recognition is None or recognition.lower() == "none":
        logger.debug("No recognition selected. Returning None")
        return None

    # Plugin names use underscores and lower case
    recognition_name = recognition.replace("-", "_").lower()
    logger.debug("Loading Recognition: '%s'", recognition_name)
    plugin_class = PluginLoader.get_recognition(recognition_name)
    return plugin_class(exclude_gpus=self._exclude_gpus,
                        configfile=configfile,
                        instance=self._instance)
|
|
|
|
def _launch_plugin(self, phase: str) -> None:
    """ Launch an extraction plugin

    The plugin is wired to its own ``in`` queue. Its output goes to the ``in`` queue of the
    flow item that follows it, or to the final ``out`` queue when it is the last flow item.

    Parameters
    ----------
    phase: str
        The name of the item within :attr:`_flow` to launch the plugin for
    """
    logger.debug("Launching %s plugin", phase)
    prefix = f"extract{self._instance}_"
    in_qname = f"{prefix}{phase}_in"
    if phase == self._final_phase:
        out_qname = f"{prefix}{self._final_phase}_out"
    else:
        follower = self._flow[self._flow.index(phase) + 1]
        out_qname = f"{prefix}{follower}_in"
    logger.debug("in_qname: %s, out_qname: %s", in_qname, out_qname)

    plugin_type, idx = self._get_plugin_type_and_index(phase)
    plugin = getattr(self, f"_{plugin_type}")
    if idx is not None:
        # Indexed plugin types (maskers) are stored as a list
        plugin = plugin[idx]
    plugin.initialize(in_queue=self._queues[in_qname], out_queue=self._queues[out_qname])
    plugin.start()
    logger.debug("Launched %s plugin", phase)
|
|
|
|
def _set_extractor_batchsize(self) -> None:
    """
    Sets the batch size of the requested plugins based on their vram, their
    vram_per_batch requirements and the number of plugins being loaded in the current phase.
    Only adjusts if the configured batch size requires more vram than is available. Applies
    to the nvidia, directml and rocm backends only.
    """
    supported_backends = ("nvidia", "directml", "rocm")
    backend = get_backend()
    if backend not in supported_backends:
        logger.debug("Not updating batchsize requirements for backend: '%s'", backend)
        return

    total_vram = 0
    for plugin in self._active_plugins:
        total_vram += plugin.vram
    if total_vram == 0:
        logger.debug("No plugins use VRAM. Not updating batchsize requirements.")
        return

    batch_required = 0
    for plugin in self._active_plugins:
        batch_required += plugin.vram_per_batch * plugin.batchsize

    gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
    scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
    base_required = 0
    for gpu_plugin in gpu_plugins:
        base_required += self._vram_per_phase[gpu_plugin]
    plugins_required = base_required * scaling

    vram_free = T.cast(int, self._vram_stats["vram_free"])
    if plugins_required + batch_required <= vram_free:
        logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
                     "vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
        return

    # Hacky split across plugins that use vram
    available_vram = (vram_free - plugins_required) // len(gpu_plugins)
    self._set_plugin_batchsize(gpu_plugins, available_vram)
|
|
|
|
def _set_plugin_batchsize(self, gpu_plugins: list[str], available_vram: float) -> None:
    """ Set the batch size for the given plugins based on given available vram.

    Every plugin first receives the same starting batch size that the available vram allows
    for, capped at its requested batch size and never below 1. Left over vram is then handed
    out one batch item at a time, most expensive plugin first, until either every plugin is
    back at its requested batch size or nothing more fits.

    Plugins which have a vram_per_batch of 0 (CPU plugins) are not updated, as they would
    lead to a zero division error.

    Parameters
    ----------
    gpu_plugins: list
        The names of the items in the current phase that require vram
    available_vram: float
        The amount of vram that is available to share between batches
    """
    active_plugins = self._active_plugins
    plugins = [active_plugins[idx]
               for idx, phase in enumerate(self._current_phase)
               if phase in gpu_plugins and active_plugins[idx].vram_per_batch > 0]
    if not plugins:
        logger.debug("No plugins with a per batch vram requirement. Not updating batch sizes")
        return

    vram_per_batch = [plugin.vram_per_batch for plugin in plugins]
    total_per_batch = sum(vram_per_batch)
    ratios = [vram / total_per_batch for vram in vram_per_batch]
    requested_batchsizes = [plugin.batchsize for plugin in plugins]
    batchsizes = [min(requested, max(1, int((available_vram * ratio) / plugin.vram_per_batch)))
                  for ratio, plugin, requested in zip(ratios, plugins, requested_batchsizes)]
    remaining = available_vram - sum(batchsize * plugin.vram_per_batch
                                     for batchsize, plugin in zip(batchsizes, plugins))
    # Hand out left over vram to the most expensive plugins first
    sorted_indices = [i[0] for i in sorted(enumerate(plugins),
                                           key=lambda x: x[1].vram_per_batch, reverse=True)]

    logger.debug("requested_batchsizes: %s, batchsizes: %s, remaining vram: %s",
                 requested_batchsizes, batchsizes, remaining)

    smallest_step = min(vram_per_batch)
    while remaining > smallest_step and requested_batchsizes != batchsizes:
        incremented = False
        for idx in sorted_indices:
            plugin = plugins[idx]
            if plugin.vram_per_batch > remaining:
                logger.debug("Not enough VRAM to increase batch size of %s. Required: %sMB, "
                             "Available: %sMB", plugin, plugin.vram_per_batch, remaining)
                continue
            if requested_batchsizes[idx] == batchsizes[idx]:
                logger.debug("Threshold reached for %s. Batch size: %s",
                             plugin, plugin.batchsize)
                continue
            logger.debug("Incrementing batch size of %s to %s", plugin, batchsizes[idx] + 1)
            batchsizes[idx] += 1
            remaining -= plugin.vram_per_batch
            incremented = True
            logger.debug("Remaining VRAM to allocate: %sMB", remaining)
        if not incremented:
            # Every plugin is either at its requested size or too expensive for what is left.
            # Neither loop condition can change any more, so stop rather than spin forever
            logger.debug("No further batch size increases possible. Remaining VRAM: %sMB",
                         remaining)
            break

    if batchsizes != requested_batchsizes:
        text = ", ".join([f"{plugin.__class__.__name__}: {batchsize}"
                          for plugin, batchsize in zip(plugins, batchsizes)])
        for plugin, batchsize in zip(plugins, batchsizes):
            plugin.batchsize = batchsize
        logger.info("Reset batch sizes due to available VRAM: %s", text)
|
|
|
|
def _join_threads(self):
|
|
""" Join threads for current pass """
|
|
for plugin in self._active_plugins:
|
|
plugin.join()
|
|
|
|
def _check_and_raise_error(self) -> None:
|
|
""" Check all threads for errors and raise if one occurs """
|
|
for plugin in self._active_plugins:
|
|
plugin.check_and_raise_error()
|