1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/lib/training/generator.py
torzdf 6a3b674bef
Rebase code (#1326)
* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* Setup.py - stdout to utf-8

* Add more OSes to github Actions

* suppress mac-os end to end test
2023-06-27 11:27:47 +01:00

726 lines
32 KiB
Python

#!/usr/bin/env python3
""" Handles Data Augmentation for feeding Faceswap Models """
from __future__ import annotations
import logging
import os
import typing as T
from concurrent import futures
from random import shuffle, choice
import cv2
import numpy as np
import numexpr as ne
from lib.align import AlignedFace, DetectedFace
from lib.align.aligned_face import CenteringType
from lib.image import read_image_batch
from lib.multithreading import BackgroundGenerator
from lib.utils import FaceswapError
from . import ImageAugmentation
from .cache import get_cache, RingBuffer
if T.TYPE_CHECKING:
from collections.abc import Generator
from lib.config import ConfigValueType
from plugins.train.model._base import ModelBase
from .cache import _Cache
logger = logging.getLogger(__name__)
# A single yielded batch: (feed array, list of target arrays at each model output size)
BatchType = tuple[np.ndarray, list[np.ndarray]]
class DataGenerator():
    """ Parent class for Training and Preview Data Generators.

    This class is called from :mod:`plugins.train.trainer._base` and launches a background
    iterator that compiles augmented data, target data and sample data.

    Parameters
    ----------
    model: :class:`~plugins.train.model.ModelBase`
        The model that this data generator is feeding
    config: dict
        The configuration `dict` generated from :file:`config.train.ini` containing the trainer
        plugin configuration options.
    side: {'a' or 'b'}
        The side of the model that this iterator is for.
    images: list
        A list of image paths that will be used to compile the final augmented data from.
    batch_size: int
        The batch size for this iterator. Images will be returned in :class:`numpy.ndarray`
        objects of this size from the iterator.
    """
    def __init__(self,
                 config: dict[str, ConfigValueType],
                 model: ModelBase,
                 side: T.Literal["a", "b"],
                 images: list[str],
                 batch_size: int) -> None:
        """ See the class docstring for parameter documentation. """
        logger.debug("Initializing %s: (model: %s, side: %s, images: %s , "  # type: ignore
                     "batch_size: %s, config: %s)", self.__class__.__name__, model.name, side,
                     len(images), batch_size, config)
        self._config = config
        self._side = side
        self._images = images
        self._batch_size = batch_size
        # The largest size the model consumes or emits; faces are processed at this size
        self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes)
        self._output_sizes = self._get_output_sizes(model)
        self._model_input_size = max(img[1] for img in model.input_shapes)
        self._coverage_ratio = model.coverage_ratio
        # Lower-cased; _set_color_order swaps channels when this is "rgb"
        self._color_order = model.color_order.lower()
        # Masks are only needed when they are learned or used to penalize the loss
        self._use_mask = self._config["mask_type"] and (self._config["penalized_mask_loss"] or
                                                        self._config["learn_mask"])
        self._validate_samples()
        # Ring buffer of pre-allocated uint8 batch arrays (see .cache.RingBuffer)
        self._buffer = RingBuffer(batch_size,
                                  (self._process_size, self._process_size, self._total_channels),
                                  dtype="uint8")
        self._face_cache: _Cache = get_cache(side,
                                             filenames=images,
                                             config=self._config,
                                             size=self._process_size,
                                             coverage_ratio=self._coverage_ratio)
        logger.debug("Initialized %s", self.__class__.__name__)
@property
def _total_channels(self) -> int:
"""int: The total number of channels, including mask channels that the target image
should hold. """
channels = 3
if self._config["mask_type"] and (self._config["learn_mask"] or
self._config["penalized_mask_loss"]):
channels += 1
mults = [area for area in ["eye", "mouth"]
if T.cast(int, self._config[f"{area}_multiplier"]) > 1]
if self._config["penalized_mask_loss"] and mults:
channels += len(mults)
return channels
def _get_output_sizes(self, model: ModelBase) -> list[int]:
""" Obtain the size of each output tensor for the model.
Parameters
----------
model: :class:`~plugins.train.model.ModelBase`
The model that this data generator is feeding
Returns
-------
list
A list of integers for the model output size for the current side
"""
out_shapes = model.output_shapes
split = len(out_shapes) // 2
side_out = out_shapes[:split] if self._side == "a" else out_shapes[split:]
retval = [shape[1] for shape in side_out if shape[-1] != 1]
logger.debug("side: %s, model output shapes: %s, output sizes: %s",
self._side, model.output_shapes, retval)
return retval
def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]:
""" A Background iterator to return augmented images, samples and targets.
The exit point from this class and the sole attribute that should be referenced. Called
from :mod:`plugins.train.trainer._base`. Returns an iterator that yields images for
training, preview and time-lapses.
Parameters
----------
do_shuffle: bool, optional
Whether data should be shuffled prior to loading from disk. If true, each time the full
list of filenames are processed, the data will be reshuffled to make sure they are not
returned in the same order. Default: ``True``
Yields
------
feed: list
4-dimensional array of faces to feed the training the model (:attr:`x` parameter for
:func:`keras.models.model.train_on_batch`.). The array returned is in the format
(`batch size`, `height`, `width`, `channels`).
targets: list
List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each
output of the model. The format of these arrays will be (`batch size`, `height`,
`width`, `x`). This is the :attr:`y` parameter for
:func:`keras.models.model.train_on_batch`. The number of channels here will vary.
The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent
channels are area masks (e.g. eye/mouth masks)
"""
logger.debug("do_shuffle: %s", do_shuffle)
args = (do_shuffle, )
batcher = BackgroundGenerator(self._minibatch, args=args)
return batcher.iterator()
# << INTERNAL METHODS >> #
def _validate_samples(self) -> None:
""" Ensures that the total number of images within :attr:`images` is greater or equal to
the selected :attr:`batch_size`.
Raises
------
:class:`FaceswapError`
If the number of images loaded is smaller than the selected batch size
"""
length = len(self._images)
msg = ("Number of images is lower than batch-size (Note that too few images may lead to "
f"bad training). # images: {length}, batch-size: {self._batch_size}")
try:
assert length >= self._batch_size, msg
except AssertionError as err:
msg += ("\nYou should increase the number of images in your training set or lower "
"your batch-size.")
raise FaceswapError(msg) from err
    def _minibatch(self, do_shuffle: bool) -> Generator[BatchType, None, None]:
        """ A generator function that yields the augmented, target and sample images for the
        current batch on the current side.

        Parameters
        ----------
        do_shuffle: bool
            Whether data should be shuffled prior to loading from disk. If true, each time the
            full list of filenames are processed, the data will be reshuffled to make sure they
            are not returned in the same order.

        Yields
        ------
        feed: list
            4-dimensional array of faces to feed the training the model (:attr:`x` parameter
            for :func:`keras.models.model.train_on_batch`.). The array returned is in the
            format (`batch size`, `height`, `width`, `channels`).
        targets: list
            List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each
            output of the model. The format of these arrays will be (`batch size`, `height`,
            `width`, `x`). This is the :attr:`y` parameter for
            :func:`keras.models.model.train_on_batch`. The number of channels here will vary.
            The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any
            subsequent channels are area masks (e.g. eye/mouth masks)
        """
        logger.debug("Loading minibatch generator: (image_count: %s, do_shuffle: %s)",
                     len(self._images), do_shuffle)

        def _img_iter(imgs):
            """ Infinite iterator for recursing through image list and reshuffling at each
            epoch """
            while True:
                if do_shuffle:
                    shuffle(imgs)
                for img in imgs:
                    yield img

        # Iterate over a copy so shuffling never mutates the instance's image list
        img_iter = _img_iter(self._images[:])
        while True:
            img_paths = [next(img_iter)  # pylint:disable=stop-iteration-return
                         for _ in range(self._batch_size)]
            retval = self._process_batch(img_paths)
            yield retval
    def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]:
        """ Obtain the raw face images with associated :class:`DetectedFace` objects for this
        batch.

        If this is the first time a face has been loaded, then its metadata is extracted
        from the png header and added to :attr:`_face_cache`.

        Parameters
        ----------
        filenames: list
            List of full paths to image file names

        Returns
        -------
        raw_faces: :class:`numpy.ndarray`
            The full sized batch of training images for the given filenames
        list
            Batch of :class:`~lib.align.DetectedFace` objects for the given filename including
            the aligned face objects for the model output size
        """
        if not self._face_cache.cache_full:
            # First pass(es): loading via the cache also extracts and stores the png metadata
            raw_faces = self._face_cache.cache_metadata(filenames)
        else:
            raw_faces = read_image_batch(filenames)
        detected_faces = self._face_cache.get_items(filenames)
        logger.trace("filenames: %s, raw_faces: '%s', detected_faces: %s",  # type: ignore
                     filenames, raw_faces.shape, len(detected_faces))
        return raw_faces, detected_faces
def _crop_to_coverage(self,
filenames: list[str],
images: np.ndarray,
detected_faces: list[DetectedFace],
batch: np.ndarray) -> None:
""" Crops the training image out of the full extract image based on the centering and
coveage used in the user's configuration settings.
If legacy extract images are being used then this just returns the extracted batch with
their corresponding landmarks.
Uses thread pool execution for about a 33% speed increase @ 64 batch size
Parameters
----------
filenames: list
The list of filenames that correspond to this batch
images: :class:`numpy.ndarray`
The batch of faces that have been loaded from disk
detected_faces: list
The list of :class:`lib.align.DetectedFace` items corresponding to the batch
batch: :class:`np.ndarray`
The pre-allocated array to hold this batch
"""
logger.trace("Cropping training images info: (filenames: %s, side: '%s')", # type: ignore
filenames, self._side)
with futures.ThreadPoolExecutor() as executor:
proc = {executor.submit(face.aligned.extract_face, img): idx
for idx, (face, img) in enumerate(zip(detected_faces, images))}
for future in futures.as_completed(proc):
batch[proc[future], ..., :3] = future.result()
def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None:
""" Applies the masks to the 4th channel of the batch.
If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1
then these masks are applied to the final channels of the batch respectively.
If masks are not being used then this function returns having done nothing
Parameters
----------
detected_face: list
The list of :class:`~lib.align.DetectedFace` objects corresponding to the batch
batch: :class:`numpy.ndarray`
The preallocated array to apply masks to
side: str
'"a"' or '"b"' the side that is being processed
"""
if not self._use_mask:
return
masks = np.array([face.get_training_masks() for face in detected_faces])
batch[..., 3:] = masks
logger.trace("side: %s, masks: %s, batch: %s", # type: ignore
self._side, masks.shape, batch.shape)
    def _process_batch(self, filenames: list[str]) -> BatchType:
        """ Prepares data for feeding through subclassed methods.

        If this is the first time a face has been loaded, then its metadata is extracted from
        the png header and added to :attr:`_face_cache`

        Parameters
        ----------
        filenames: list
            List of full paths to image file names for a single batch

        Returns
        -------
        :class:`numpy.ndarray`
            4-dimensional array of faces to feed the training the model.
        list
            List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will
            vary. The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any
            subsequent channels are area masks (e.g. eye/mouth masks)
        """
        raw_faces, detected_faces = self._get_images_with_meta(filenames)
        # Obtain a pre-allocated array from the ring buffer rather than allocating per batch
        batch = self._buffer()
        self._crop_to_coverage(filenames, raw_faces, detected_faces, batch)
        self._apply_mask(detected_faces, batch)
        # Defer to the subclass for training/preview specific processing
        feed, targets = self.process_batch(filenames, raw_faces, detected_faces, batch)
        logger.trace("Processed %s batch side %s. (filenames: %s, feed: %s, "  # type: ignore
                     "targets: %s)", self.__class__.__name__, self._side, filenames,
                     feed.shape, [t.shape for t in targets])
        return feed, targets
    def process_batch(self,
                      filenames: list[str],
                      images: np.ndarray,
                      detected_faces: list[DetectedFace],
                      batch: np.ndarray) -> BatchType:
        """ Override for processing the batch for the current generator.

        Parameters
        ----------
        filenames: list
            List of full paths to image file names for a single batch
        images: :class:`numpy.ndarray`
            The batch of faces corresponding to the filenames
        detected_faces: list
            List of :class:`~lib.align.DetectedFace` objects with aligned data and masks
            loaded for the current batch
        batch: :class:`numpy.ndarray`
            The pre-allocated batch with images and masks populated for the selected coverage
            and centering

        Returns
        -------
        list
            4-dimensional array of faces to feed the training the model.
        list
            List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will
            vary. The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any
            subsequent channels are area masks (e.g. eye/mouth masks)

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the implementation
        """
        raise NotImplementedError()
def _set_color_order(self, batch) -> None:
""" Set the color order correctly for the model's input type.
batch: :class:`numpy.ndarray`
The pre-allocated batch with images in the first 3 channels in BGR order
"""
if self._color_order == "rgb":
batch[..., :3] = batch[..., [2, 1, 0]]
    def _to_float32(self, in_array: np.ndarray) -> np.ndarray:
        """ Cast an UINT8 array in 0-255 range to float32 in 0.0-1.0 range.

        The divide is evaluated with numexpr (presumably faster than plain numpy here —
        unverified).

        in_array: :class:`numpy.ndarray`
            The input uint8 array
        """
        return ne.evaluate("x / c",
                           local_dict={"x": in_array, "c": np.float32(255)},
                           casting="unsafe")
class TrainingDataGenerator(DataGenerator):  # pylint:disable=too-few-public-methods
    """ A Training Data Generator for compiling data for feeding to a model.

    This class is called from :mod:`plugins.train.trainer._base` and launches a background
    iterator that compiles augmented data, target data and sample data.

    Parameters
    ----------
    model: :class:`~plugins.train.model.ModelBase`
        The model that this data generator is feeding
    config: dict
        The configuration `dict` generated from :file:`config.train.ini` containing the trainer
        plugin configuration options.
    side: {'a' or 'b'}
        The side of the model that this iterator is for.
    images: list
        A list of image paths that will be used to compile the final augmented data from.
    batch_size: int
        The batch size for this iterator. Images will be returned in :class:`numpy.ndarray`
        objects of this size from the iterator.
    """
    def __init__(self,
                 config: dict[str, ConfigValueType],
                 model: ModelBase,
                 side: T.Literal["a", "b"],
                 images: list[str],
                 batch_size: int) -> None:
        """ See the class docstring for parameter documentation. """
        super().__init__(config, model, side, images, batch_size)
        # Augmentation switches sourced from the user's command line arguments
        self._augment_color = not model.command_line_arguments.no_augment_color
        self._no_flip = model.command_line_arguments.no_flip
        self._no_warp = model.command_line_arguments.no_warp
        self._warp_to_landmarks = (not self._no_warp
                                   and model.command_line_arguments.warp_to_landmarks)
        if self._warp_to_landmarks:
            # Warping to landmarks requires the face cache to be populated up front
            self._face_cache.pre_fill(images, side)
        self._processing = ImageAugmentation(batch_size,
                                             self._process_size,
                                             self._config)
        # Cache of filename -> tuple of closest-matched landmark keys from the other side
        self._nearest_landmarks: dict[str, tuple[str, ...]] = {}
        logger.debug("Initialized %s", self.__class__.__name__)
def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]:
""" Compile target images, with masks, for the model output sizes.
Parameters
----------
batch: :class:`numpy.ndarray`
This should be a 4-dimensional array of training images in the format (`batch size`,
`height`, `width`, `channels`). Targets should be requested after performing image
transformations but prior to performing warps. The 4th channel should be the mask.
Any channels above the 4th should be any additional area masks (e.g. eye/mouth) that
are required.
Returns
-------
list
List of 4-dimensional target images, at all model output sizes, with masks compiled
into channels 4+ for each output size
"""
logger.trace("Compiling targets: batch shape: %s", batch.shape) # type: ignore
if len(self._output_sizes) == 1 and self._output_sizes[0] == self._process_size:
# Rolling buffer here makes next to no difference, so just create array on the fly
retval = [self._to_float32(batch)]
else:
retval = [self._to_float32(np.array([cv2.resize(image, (size, size), cv2.INTER_AREA)
for image in batch]))
for size in self._output_sizes]
logger.trace("Processed targets: %s", [t.shape for t in retval]) # type: ignore
return retval
def process_batch(self,
filenames: list[str],
images: np.ndarray,
detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType:
""" Performs the augmentation and compiles target images and samples.
Parameters
----------
filenames: list
List of full paths to image file names for a single batch
images: :class:`numpy.ndarray`
The batch of faces corresponding to the filenames
detected_faces: list
List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for
the current batch
batch: :class:`numpy.ndarray`
The pre-allocated batch with images and masks populated for the selected coverage and
centering
Returns
-------
feed: :class:`numpy.ndarray`
4-dimensional array of faces to feed the training the model (:attr:`x` parameter for
:func:`keras.models.model.train_on_batch`.). The array returned is in the format
(`batch size`, `height`, `width`, `channels`).
targets: list
List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each
output of the model. The format of these arrays will be (`batch size`, `height`,
`width`, `x`). This is the :attr:`y` parameter for
:func:`keras.models.model.train_on_batch`. The number of channels here will vary.
The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent
channels are area masks (e.g. eye/mouth masks)
"""
logger.trace("Process training: (side: '%s', filenames: '%s', images: %s, " # type:ignore
"batch: %s, detected_faces: %s)", self._side, filenames, images.shape,
batch.shape, len(detected_faces))
# Color Augmentation of the image only
if self._augment_color:
batch[..., :3] = self._processing.color_adjust(batch[..., :3])
# Random Transform and flip
self._processing.transform(batch)
if not self._no_flip:
self._processing.random_flip(batch)
# Switch color order for RGB models
self._set_color_order(batch)
# Get Targets
targets = self._create_targets(batch)
# TODO Look at potential for applying mask on input
# Random Warp
if self._warp_to_landmarks:
landmarks = np.array([face.aligned.landmarks for face in detected_faces])
batch_dst_pts = self._get_closest_match(filenames, landmarks)
warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts}
else:
warp_kwargs = {}
warped = batch[..., :3] if self._no_warp else self._processing.warp(
batch[..., :3],
self._warp_to_landmarks,
**warp_kwargs)
if self._model_input_size != self._process_size:
feed = self._to_float32(np.array([cv2.resize(image,
(self._model_input_size,
self._model_input_size),
cv2.INTER_AREA)
for image in warped]))
else:
feed = self._to_float32(warped)
return feed, targets
    def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray:
        """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest
        matched 68 point landmarks from the opposite training set.

        Parameters
        ----------
        filenames: list
            Filenames for current batch
        batch_src_points: :class:`np.ndarray`
            The source landmarks for the current batch

        Returns
        -------
        :class:`np.ndarray`
            Randomly selected closest matches from the other side's landmarks
        """
        logger.trace("Retrieving closest matched landmarks: (filenames: '%s', "  # type: ignore
                     "src_points: '%s')", filenames, batch_src_points)
        lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b"
        other_cache = get_cache(lm_side)
        landmarks = other_cache.aligned_landmarks
        try:
            closest_matches = [self._nearest_landmarks[os.path.basename(filename)]
                               for filename in filenames]
        except KeyError:
            # At least one face in the batch has not been matched yet, so the closest matches
            # are (re)calculated for the whole batch and cached for subsequent epochs.
            # Resize mismatched training image size landmarks
            sizes = {side: cache.size for side, cache in zip((self._side, lm_side),
                                                             (self._face_cache, other_cache))}
            if len(set(sizes.values())) > 1:
                # Sides were extracted at different sizes: scale the other side's landmarks
                # into this side's coordinate space
                scale = sizes[self._side] / sizes[lm_side]
                landmarks = {key: lms * scale for key, lms in landmarks.items()}
            closest_matches = self._cache_closest_matches(filenames, batch_src_points, landmarks)
        # Pick one of the cached closest matches at random for each batch item
        batch_dst_points = np.array([landmarks[choice(fname)] for fname in closest_matches])
        logger.trace("Returning: (batch_dst_points: %s)", batch_dst_points.shape)  # type: ignore
        return batch_dst_points
def _cache_closest_matches(self,
filenames: list[str],
batch_src_points: np.ndarray,
landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]:
""" Cache the nearest landmarks for this batch
Parameters
----------
filenames: list
Filenames for current batch
batch_src_points: :class:`np.ndarray`
The source landmarks for the current batch
landmarks: dict
The destination landmarks with associated filenames
"""
logger.trace("Caching closest matches") # type:ignore
dst_landmarks = list(landmarks.items())
dst_points = np.array([lm[1] for lm in dst_landmarks])
batch_closest_matches: list[tuple[str, ...]] = []
for filename, src_points in zip(filenames, batch_src_points):
closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10]
closest_matches = tuple(dst_landmarks[i][0] for i in closest)
self._nearest_landmarks[os.path.basename(filename)] = closest_matches
batch_closest_matches.append(closest_matches)
logger.trace("Cached closest matches") # type:ignore
return batch_closest_matches
class PreviewDataGenerator(DataGenerator):
    """ Generator for compiling images for generating previews.

    This class is called from :mod:`plugins.train.trainer._base` and launches a background
    iterator that compiles sample preview data for feeding the model's predict function and
    for display.

    Parameters
    ----------
    model: :class:`~plugins.train.model.ModelBase`
        The model that this data generator is feeding
    config: dict
        The configuration `dict` generated from :file:`config.train.ini` containing the trainer
        plugin configuration options.
    side: {'a' or 'b'}
        The side of the model that this iterator is for.
    images: list
        A list of image paths that will be used to compile the final images.
    batch_size: int
        The batch size for this iterator. Images will be returned in :class:`numpy.ndarray`
        objects of this size from the iterator.
    """
def _create_samples(self,
images: np.ndarray,
detected_faces: list[DetectedFace]) -> list[np.ndarray]:
""" Compile the 'sample' images. These are the 100% coverage images which hold the model
output in the preview window.
Parameters
----------
images: :class:`numpy.ndarray`
The original batch of images as loaded from disk.
detected_faces: list
List of :class:`~lib.align.DetectedFace` for the current batch
Returns
-------
list
List of 4-dimensional target images, at final model output size
"""
logger.trace("Compiling samples: images shape: %s, detected_faces: %s ", # type: ignore
images.shape, len(detected_faces))
output_size = self._output_sizes[-1]
full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2))
assert self._config["centering"] in T.get_args(CenteringType)
retval = np.empty((full_size, full_size, 3), dtype="float32")
retval = self._to_float32(np.array([
AlignedFace(face.landmarks_xy,
image=images[idx],
centering=T.cast(CenteringType,
self._config["centering"]),
size=full_size,
dtype="uint8",
is_aligned=True).face
for idx, face in enumerate(detected_faces)]))
logger.trace("Processed samples: %s", retval.shape) # type: ignore
return [retval]
def process_batch(self,
filenames: list[str],
images: np.ndarray,
detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType:
""" Creates the full size preview images and the sub-cropped images for feeding the model's
predict function.
Parameters
----------
filenames: list
List of full paths to image file names for a single batch
images: :class:`numpy.ndarray`
The batch of faces corresponding to the filenames
detected_faces: list
List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for
the current batch
batch: :class:`numpy.ndarray`
The pre-allocated batch with images and masks populated for the selected coverage and
centering
Returns
-------
feed: :class:`numpy.ndarray`
List of 4-dimensional :class:`numpy.ndarray` objects at model output size for feeding
the model's predict function. The first 3 channels are (rgb/bgr). The 4th channel is
the face mask.
samples: list
4-dimensional array containing the 100% coverage images at the model's centering for
for generating previews. The array returned is in the format
(`batch size`, `height`, `width`, `channels`).
"""
logger.trace("Process preview: (side: '%s', filenames: '%s', images: %s, " # type:ignore
"batch: %s, detected_faces: %s)", self._side, filenames, images.shape,
batch.shape, len(detected_faces))
# Switch color order for RGB models
self._set_color_order(batch)
self._set_color_order(images)
if not self._use_mask:
mask = np.zeros_like(batch[..., 0])[..., None] + 255
batch = np.concatenate([batch, mask], axis=-1)
feed = self._to_float32(batch[..., :4]) # Don't resize here: we want masks at output res.
# If user sets model input size as larger than output size, the preview will error, so
# resize in these rare instances
out_size = max(self._output_sizes)
if self._process_size > out_size:
feed = np.array([cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_AREA)
for img in feed])
samples = self._create_samples(images, detected_faces)
return feed, samples