#!/usr/bin/env python3
""" Holds the data cache for training data generators """
from __future__ import annotations
import logging
import os
import typing as T
from threading import Lock
import cv2
import numpy as np
from tqdm import tqdm
from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType
from lib.image import read_image_batch, read_image_meta_batch
from lib.utils import FaceswapError
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict
from lib.config import ConfigValueType
logger = logging.getLogger(__name__)
_FACE_CACHES: dict[str, "_Cache"] = {}
def get_cache(side: T.Literal["a", "b"],
              filenames: list[str] | None = None,
              config: dict[str, ConfigValueType] | None = None,
              size: int | None = None,
              coverage_ratio: float | None = None) -> "_Cache":
    """ Return the face meta information cache for the requested side, creating it on first use.

    Parameters
    ----------
    side: str
        `"a"` or `"b"`. The side of the model to obtain the cache for
    filenames: list
        The filenames of all the images. Either full paths or base names (full paths are
        stripped to base name for use as the cache key). Required on the first call for each
        side; ignored on subsequent calls. Default: ``None``
    config: dict, optional
        The user selected training configuration options. Required on the first call for each
        side; ignored on subsequent calls. Default: ``None``
    size: int, optional
        The largest output size of the model. Required on the first call for each side; ignored
        on subsequent calls. Default: ``None``
    coverage_ratio: float: optional
        The coverage ratio that the model is using. Required on the first call for each side;
        ignored on subsequent calls. Default: ``None``

    Returns
    -------
    :class:`_Cache`
        The face meta information cache for the requested side
    """
    retval = _FACE_CACHES.get(side)
    if retval is None:
        # First call for this side: all optional parameters must be supplied
        assert config is not None, ("config must be provided for first call to cache")
        assert filenames is not None, ("filenames must be provided for first call to cache")
        assert size is not None, ("size must be provided for first call to cache")
        assert coverage_ratio is not None, ("coverage_ratio must be provided for first call to "
                                            "cache")
        logger.debug("Creating cache. side: %s, size: %s, coverage_ratio: %s",
                     side, size, coverage_ratio)
        retval = _Cache(filenames, config, size, coverage_ratio)
        _FACE_CACHES[side] = retval
    return retval
def _check_reset(face_cache: "_Cache") -> bool:
    """ Check whether a given cache needs to be reset because a face centering change has been
    detected in the other cache.

    Parameters
    ----------
    face_cache: :class:`_Cache`
        The cache object that is checking whether it should reset

    Returns
    -------
    bool
        ``True`` if the given object should reset the cache, otherwise ``False``
    """
    # Look for the cache belonging to the opposite side; if one exists, defer to its reset flag
    for cache in _FACE_CACHES.values():
        if cache is not face_cache:
            return cache.check_reset()
    return False
class _Cache():
    """ A thread safe mechanism for collecting and holding face meta information (masks,
    alignments data etc.) for multiple :class:`TrainingDataGenerator`s.

    Each side may have up to 3 generators (training, preview and time-lapse). To conserve VRAM
    these need to share access to the same face information for the images they are processing.

    As the cache is populated at run-time, thread safe writes are required for the first epoch.
    Following that, the cache is only used for reads, which is thread safe intrinsically.

    It would probably be quicker to set locks on each individual face, but for code complexity
    reasons, and the fact that the lock is only taken up during cache population, and it should
    only be being read multiple times on save iterations, we lock the whole cache during writes.

    Parameters
    ----------
    filenames: list
        The filenames of all the images. This can either be the full path or the base name. If
        the full paths are passed in, they are stripped to base name for use as the cache key.
    config: dict
        The user selected training configuration options
    size: int
        The largest output size of the model
    coverage_ratio: float
        The coverage ratio that the model is using.
    """
    def __init__(self,
                 filenames: list[str],
                 config: dict[str, ConfigValueType],
                 size: int,
                 coverage_ratio: float) -> None:
        logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)",
                     self.__class__.__name__, len(filenames), size, coverage_ratio)
        self._lock = Lock()
        # "cache_full": all faces cached; "has_reset": signals the opposite side to reset
        self._cache_info = {"cache_full": False, "has_reset": False}
        # Keys pre-filled with alignments only (Warp-to-Landmarks) awaiting full caching
        self._partially_loaded: list[str] = []
        self._image_count = len(filenames)
        self._cache: dict[str, DetectedFace] = {}
        self._aligned_landmarks: dict[str, np.ndarray] = {}
        self._extract_version = 0.0
        self._size = size
        assert config["centering"] in T.get_args(CenteringType)
        self._centering: CenteringType = T.cast(CenteringType, config["centering"])
        self._config = config
        self._coverage_ratio = coverage_ratio
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def cache_full(self) -> bool:
        """bool: ``True`` if the cache has been fully populated. ``False`` if there are items
        still to be cached. """
        # Fast path: once True it never reverts within a session, so skip the lock
        if self._cache_info["cache_full"]:
            return self._cache_info["cache_full"]
        with self._lock:
            return self._cache_info["cache_full"]

    @property
    def aligned_landmarks(self) -> dict[str, np.ndarray]:
        """ dict: The filename as key, aligned landmarks as value. """
        # Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely
        # populate all of the aligned landmarks for the entire cache.
        if not self._aligned_landmarks:
            with self._lock:
                # For Warp-To-Landmarks a race condition can occur where this is referenced from
                # the opposite side prior to it being populated, so block on a lock.
                self._aligned_landmarks = {key: face.aligned.landmarks
                                           for key, face in self._cache.items()}
        return self._aligned_landmarks

    @property
    def size(self) -> int:
        """ int: The pixel size of the cropped aligned face """
        return self._size

    def check_reset(self) -> bool:
        """ Check whether this cache has been reset due to a face centering change, and reset
        the flag if it has.

        Returns
        -------
        bool
            ``True`` if the cache has been reset because of a face centering change due to
            legacy alignments, otherwise ``False``. """
        retval = self._cache_info["has_reset"]
        if retval:
            # One-shot flag: consume it so the opposite side only resets once
            logger.debug("Resetting 'has_reset' flag")
            self._cache_info["has_reset"] = False
        return retval

    def get_items(self, filenames: list[str]) -> list[DetectedFace]:
        """ Obtain the cached items for a list of filenames. The returned list is in the same
        order as the provided filenames.

        Parameters
        ----------
        filenames: list
            A list of image filenames to obtain the cached data for

        Returns
        -------
        list
            List of DetectedFace objects holding the cached metadata. The list returns in the
            same order as the filenames received
        """
        return [self._cache[os.path.basename(filename)] for filename in filenames]

    def cache_metadata(self, filenames: list[str]) -> np.ndarray:
        """ Obtain the batch with metadata for items that need caching and cache DetectedFace
        objects to :attr:`_cache`.

        Parameters
        ----------
        filenames: list
            List of full paths to image file names

        Returns
        -------
        :class:`numpy.ndarray`
            The batch of face images loaded from disk
        """
        keys = [os.path.basename(filename) for filename in filenames]
        with self._lock:
            if _check_reset(self):
                self._reset_cache(False)

            needs_cache = [filename
                           for filename, key in zip(filenames, keys)
                           if key not in self._cache or key in self._partially_loaded]
            logger.trace("Needs cache: %s", needs_cache)  # type: ignore

            if not needs_cache:
                # Don't bother reading the metadata if no images in this batch need caching
                logger.debug("All metadata already cached for: %s", keys)
                return read_image_batch(filenames)

            try:
                batch, metadata = read_image_batch(filenames, with_metadata=True)
            except ValueError as err:
                if "inhomogeneous" in str(err):
                    raise FaceswapError(
                        "There was an error loading a batch of images. This is most likely due "
                        "to non-faceswap extracted faces in your training folder."
                        "\nAll training images should be Faceswap extracted faces."
                        "\nAll training images should be the same size."
                        f"\nThe files that caused this error are: {filenames}") from err
                raise

            if len(batch.shape) == 1:
                # An object array indicates mixed image dimensions within the batch
                folder = os.path.dirname(filenames[0])
                details = [
                    f"{key} ({f'{img.shape[1]}px' if isinstance(img, np.ndarray) else type(img)})"
                    for key, img in zip(keys, batch)]
                msg = (f"There are mismatched image sizes in the folder '{folder}'. All training "
                       "images for each side must have the same dimensions.\nThe batch that "
                       f"failed contains the following files:\n{details}.")
                raise FaceswapError(msg)

            # Populate items into cache
            for filename in needs_cache:
                key = os.path.basename(filename)
                meta = metadata[filenames.index(filename)]

                # Version Check
                self._validate_version(meta, filename)
                # Explicit membership test: guards .remove() against a ValueError if the list is
                # populated but this key was never pre-filled
                if key in self._partially_loaded:  # Faces already loaded for Warp-to-landmarks
                    self._partially_loaded.remove(key)
                    detected_face = self._cache[key]
                else:
                    detected_face = self._load_detected_face(filename, meta["alignments"])

                self._prepare_masks(filename, detected_face)
                self._cache[key] = detected_face

            # Update the :attr:`cache_full` attribute
            cache_full = not self._partially_loaded and len(self._cache) == self._image_count
            if cache_full:
                logger.verbose("Cache filled: '%s'", os.path.dirname(filenames[0]))  # type: ignore
                self._cache_info["cache_full"] = cache_full

        return batch

    def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None:
        """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs
        access to the other side's alignments.

        Parameters
        ----------
        filenames: list
            The list of full paths to the images to load the metadata from
        side: str
            `"a"` or `"b"`. The side of the model being cached. Used for info output

        Raises
        ------
        FaceswapError
            If an image does not contain Faceswap alignment data in its PNG header
        """
        with self._lock:
            for filename, meta in tqdm(read_image_meta_batch(filenames),
                                       desc=f"WTL: Caching Landmarks ({side.upper()})",
                                       total=len(filenames),
                                       leave=False):
                if "itxt" not in meta or "alignments" not in meta["itxt"]:
                    # Include the offending file so the user can locate it
                    raise FaceswapError(f"Invalid face image found. Aborting: '{filename}'")

                meta = meta["itxt"]
                key = os.path.basename(filename)
                # Version Check
                self._validate_version(meta, filename)
                detected_face = self._load_detected_face(filename, meta["alignments"])
                self._cache[key] = detected_face
                self._partially_loaded.append(key)

    def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None:
        """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces.

        Parameters
        ----------
        png_meta: dict
            The information held within the Faceswap PNG Header
        filename: str
            The full path to the file being validated

        Raises
        ------
        FaceswapError
            If a version 1.0 face appears in a 2.x set or vice versa
        """
        alignment_version = png_meta["source"]["alignments_version"]

        if not self._extract_version:
            logger.debug("Setting initial extract version: %s", alignment_version)
            self._extract_version = alignment_version
            if alignment_version == 1.0 and self._centering != "legacy":
                # Legacy faces can only be trained with legacy centering
                self._reset_cache(True)
            return

        if (self._extract_version == 1.0 and alignment_version > 1.0) or (
                alignment_version == 1.0 and self._extract_version > 1.0):
            raise FaceswapError("Mixing legacy and full head extracted facesets is not supported. "
                                "The following folder contains a mix of extracted face types: "
                                f"'{os.path.dirname(filename)}'")

        # Track the lowest version seen so legacy handling is applied when required
        self._extract_version = min(alignment_version, self._extract_version)

    def _reset_cache(self, set_flag: bool) -> None:
        """ In the event that a legacy extracted face has been seen, and centering is not legacy
        the cache will need to be reset for legacy centering.

        Parameters
        ----------
        set_flag: bool
            ``True`` if the flag should be set to indicate that the cache is being reset because
            of a legacy face set/centering mismatch. ``False`` if the cache is being reset
            because it has detected a reset flag from the opposite cache.
        """
        if set_flag:
            logger.warning("You are using legacy extracted faces but have selected '%s' centering "
                           "which is incompatible. Switching centering to 'legacy'",
                           self._centering)
        self._config["centering"] = "legacy"
        self._centering = "legacy"
        self._cache = {}
        self._cache_info["cache_full"] = False
        if set_flag:
            self._cache_info["has_reset"] = True

    def _load_detected_face(self,
                            filename: str,
                            alignments: PNGHeaderAlignmentsDict) -> DetectedFace:
        """ Load a :class:`DetectedFace` object and load its associated `aligned` property.

        Parameters
        ----------
        filename: str
            The file path for the current image
        alignments: dict
            The alignments for a single face, extracted from a PNG header

        Returns
        -------
        :class:`lib.align.DetectedFace`
            The loaded Detected Face object
        """
        detected_face = DetectedFace()
        detected_face.from_png_meta(alignments)
        detected_face.load_aligned(None,
                                   size=self._size,
                                   centering=self._centering,
                                   coverage_ratio=self._coverage_ratio,
                                   is_aligned=True,
                                   is_legacy=self._extract_version == 1.0)
        logger.trace("Cached aligned face for: %s", filename)  # type: ignore
        return detected_face

    def _prepare_masks(self, filename: str, detected_face: DetectedFace) -> None:
        """ Prepare the masks required from training, and compile into a single compressed array

        Parameters
        ----------
        filename: str
            The file path for the current image
        detected_face: :class:`lib.align.DetectedFace`
            The detected face object that holds the masks
        """
        masks = [(self._get_face_mask(filename, detected_face))]
        for area in T.get_args(T.Literal["eye", "mouth"]):
            masks.append(self._get_localized_mask(filename, detected_face, area))

        detected_face.store_training_masks(masks, delete_masks=True)
        logger.trace("Stored masks for filename: %s)", filename)  # type: ignore

    def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None:
        """ Obtain the training sized face mask from the :class:`DetectedFace` for the requested
        mask type.

        Parameters
        ----------
        filename: str
            The file path for the current image
        detected_face: :class:`lib.align.DetectedFace`
            The detected face object that holds the masks

        Returns
        -------
        :class:`numpy.ndarray` or ``None``
            The face mask sized to the training output, or ``None`` if masks are not required
            for the current configuration

        Raises
        ------
        FaceswapError
            If the requested mask type is not available an error is returned along with a list
            of available masks
        """
        if not self._config["penalized_mask_loss"] and not self._config["learn_mask"]:
            return None

        if not self._config["mask_type"]:
            logger.debug("No mask selected. Not validating")
            return None

        if self._config["mask_type"] not in detected_face.mask:
            # Include the offending file so the user can locate it
            raise FaceswapError(
                f"You have selected the mask type '{self._config['mask_type']}' but at least one "
                "face does not contain the selected mask.\n"
                f"The face that failed was: '{filename}'\n"
                f"The masks that exist for this face are: {list(detected_face.mask)}")

        mask = detected_face.mask[str(self._config["mask_type"])]
        assert isinstance(self._config["mask_dilation"], float)
        assert isinstance(self._config["mask_blur_kernel"], int)
        assert isinstance(self._config["mask_threshold"], int)
        mask.set_dilation(self._config["mask_dilation"])
        mask.set_blur_and_threshold(blur_kernel=self._config["mask_blur_kernel"],
                                    threshold=self._config["mask_threshold"])

        pose = detected_face.aligned.pose
        mask.set_sub_crop(pose.offset[mask.stored_centering],
                          pose.offset[self._centering],
                          self._centering,
                          self._coverage_ratio)
        face_mask = mask.mask
        if self._size != face_mask.shape[0]:
            # Upscale with cubic, downscale with area interpolation for best quality
            interpolator = cv2.INTER_CUBIC if mask.stored_size < self._size else cv2.INTER_AREA
            face_mask = cv2.resize(face_mask,
                                   (self._size, self._size),
                                   interpolation=interpolator)[..., None]

        logger.trace("Obtained face mask for: %s %s", filename, face_mask.shape)  # type: ignore
        return face_mask

    def _get_localized_mask(self,
                            filename: str,
                            detected_face: DetectedFace,
                            area: T.Literal["eye", "mouth"]) -> np.ndarray | None:
        """ Obtain a localized mask for the given area if it is required for training.

        Parameters
        ----------
        filename: str
            The file path for the current image
        detected_face: :class:`lib.align.DetectedFace`
            The detected face object that holds the masks
        area: str
            `"eye"` or `"mouth"`. The area of the face to obtain the mask for

        Returns
        -------
        :class:`numpy.ndarray` or ``None``
            The localized mask, or ``None`` if the mask is not required for the current
            configuration
        """
        multiplier = self._config[f"{area}_multiplier"]
        assert isinstance(multiplier, int)
        if not self._config["penalized_mask_loss"] or multiplier <= 1:
            return None
        mask = detected_face.get_landmark_mask(area, self._size // 16, 2.5)
        logger.trace("Caching localized '%s' mask for: %s %s",  # type: ignore
                     area, filename, mask.shape)
        return mask
class RingBuffer():  # pylint: disable=too-few-public-methods
    """ Rolling buffer for holding training/preview batches

    All arrays are pre-allocated once at initialization so that obtaining the next batch array
    never allocates inside the training loop.

    Parameters
    ----------
    batch_size: int
        The batch size to create the buffer for
    image_shape: tuple
        The height/width/channels shape of a single image in the batch
    buffer_size: int, optional
        The number of arrays to hold in the rolling buffer. Default: `2`
    dtype: str, optional
        The datatype to create the buffer as. Default: `"uint8"`
    """
    def __init__(self,
                 batch_size: int,
                 image_shape: tuple[int, int, int],
                 buffer_size: int = 2,
                 dtype: str = "uint8") -> None:
        # Closing parenthesis restored in the log message (matches the _Cache.__init__ style)
        logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, "
                     "dtype: %s)", self.__class__.__name__, batch_size, image_shape, buffer_size,
                     dtype)
        self._max_index = buffer_size - 1
        self._index = 0
        self._buffer = [np.empty((batch_size, *image_shape), dtype=dtype)
                        for _ in range(buffer_size)]
        logger.debug("Initialized: %s", self.__class__.__name__)

    def __call__(self) -> np.ndarray:
        """ Obtain the next array from the ring buffer

        Returns
        -------
        :class:`np.ndarray`
            A pre-allocated numpy array from the buffer
        """
        retval = self._buffer[self._index]
        # Advance and wrap back to the first slot after the last one is handed out
        self._index = (self._index + 1) % (self._max_index + 1)
        return retval