#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py """
from __future__ import annotations

import logging
import typing as T

import numpy as np
import tensorflow as tf

# Ignore linting errors from Tensorflow's thoroughly broken import system
from tensorflow.python.keras.engine import compile_utils  # pylint:disable=no-name-in-module
from tensorflow.keras import backend as K  # pylint:disable=import-error

if T.TYPE_CHECKING:
    from collections.abc import Callable

logger = logging.getLogger(__name__)


class FocalFrequencyLoss():  # pylint:disable=too-few-public-methods
    """ Focal Frequency Loss Function.

    A channels last implementation.

    Notes
    -----
    There is a bug in this implementation that will do an incorrect FFT if
    :attr:`patch_factor` > ``1``, which means incorrect loss will be returned, so keep
    patch factor at 1.

    Parameters
    ----------
    alpha: float, Optional
        Scaling factor of the spectrum weight matrix for flexibility. Default: ``1.0``
    patch_factor: int, Optional
        Factor to crop image patches for patch-based focal frequency loss.
        Default: ``1``
    ave_spectrum: bool, Optional
        ``True`` to use minibatch average spectrum otherwise ``False``. Default: ``False``
    log_matrix: bool, Optional
        ``True`` to adjust the spectrum weight matrix by logarithm otherwise ``False``.
        Default: ``False``
    batch_matrix: bool, Optional
        ``True`` to calculate the spectrum weight matrix using batch-based statistics otherwise
        ``False``. Default: ``False``

    References
    ----------
    https://arxiv.org/pdf/2012.12821.pdf
    https://github.com/EndlessSora/focal-frequency-loss
    """
    def __init__(self,
                 alpha: float = 1.0,
                 patch_factor: int = 1,
                 ave_spectrum: bool = False,
                 log_matrix: bool = False,
                 batch_matrix: bool = False) -> None:
        self._alpha = alpha
        # TODO Fix bug where FFT will be incorrect if patch_factor > 1
        self._patch_factor = patch_factor
        self._ave_spectrum = ave_spectrum
        self._log_matrix = log_matrix
        self._batch_matrix = batch_matrix
        self._dims: tuple[int, int] = (0, 0)

    def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
        """ Crop the incoming batch of images into patches as defined by :attr:`_patch_factor`.

        Parameters
        ----------
        inputs: :class:`tf.Tensor`
            A batch of images to be converted into patches

        Returns
        -------
        :class:`tf.Tensor`
            The incoming batch converted into patches
        """
        rows, cols = self._dims
        patch_list = []
        patch_rows = rows // self._patch_factor
        patch_cols = cols // self._patch_factor
        for i in range(self._patch_factor):
            for j in range(self._patch_factor):
                row_from = i * patch_rows
                row_to = (i + 1) * patch_rows
                col_from = j * patch_cols
                col_to = (j + 1) * patch_cols
                patch_list.append(inputs[:, row_from: row_to, col_from: col_to, :])
        retval = K.stack(patch_list, axis=1)
        return retval

    def _tensor_to_frequency_spectrum(self, patch: tf.Tensor) -> tf.Tensor:
        """ Perform FFT to create the orthonormalized DFT frequencies.

        Parameters
        ----------
        patch: :class:`tf.Tensor`
            The incoming batch of patches to convert to the frequency spectrum

        Returns
        -------
        :class:`tf.Tensor`
            The DFT frequencies split into real and imaginary numbers as float32
        """
        # TODO fix this for when self._patch_factor != 1.
        rows, cols = self._dims
        patch = K.permute_dimensions(patch, (0, 1, 4, 2, 3))  # move channels to first
        patch = patch / np.sqrt(rows * cols)  # Orthonormalization
        patch = K.cast(patch, "complex64")
        freq = tf.signal.fft2d(patch)[..., None]
        freq = K.concatenate([tf.math.real(freq), tf.math.imag(freq)], axis=-1)
        freq = K.cast(freq, "float32")
        freq = K.permute_dimensions(freq, (0, 1, 3, 4, 2, 5))  # channels to last
        return freq

    def _get_weight_matrix(self, freq_true: tf.Tensor, freq_pred: tf.Tensor) -> tf.Tensor:
        """ Calculate a continuous, dynamic weight matrix based on the current Euclidean
        distance.

        Parameters
        ----------
        freq_true: :class:`tf.Tensor`
            The real and imaginary DFT frequencies for the true batch of images
        freq_pred: :class:`tf.Tensor`
            The real and imaginary DFT frequencies for the predicted batch of images

        Returns
        -------
        :class:`tf.Tensor`
            The weights matrix for prioritizing hard frequencies
        """
        weights = K.square(freq_pred - freq_true)
        weights = K.sqrt(weights[..., 0] + weights[..., 1])
        weights = K.pow(weights, self._alpha)

        if self._log_matrix:  # adjust the spectrum weight matrix by logarithm
            weights = K.log(weights + 1.0)

        if self._batch_matrix:  # calculate the spectrum weight matrix using batch-based
            # statistics
            weights = weights / K.max(weights)
        else:
            weights = weights / K.max(K.max(weights, axis=-2), axis=-2)[..., None, None, :]

        weights = K.switch(tf.math.is_nan(weights), K.zeros_like(weights), weights)
        weights = K.clip(weights, min_value=0.0, max_value=1.0)
        return weights

    @classmethod
    def _calculate_loss(cls,
                        freq_true: tf.Tensor,
                        freq_pred: tf.Tensor,
                        weight_matrix: tf.Tensor) -> tf.Tensor:
        """ Perform the loss calculation on the DFT spectrum applying the weights matrix.

        Parameters
        ----------
        freq_true: :class:`tf.Tensor`
            The real and imaginary DFT frequencies for the true batch of images
        freq_pred: :class:`tf.Tensor`
            The real and imaginary DFT frequencies for the predicted batch of images
        weight_matrix: :class:`tf.Tensor`
            The weights matrix for prioritizing hard frequencies

        Returns
        -------
        :class:`tf.Tensor`
            The final loss matrix
        """
        tmp = K.square(freq_pred - freq_true)  # freq distance using squared Euclidean distance
        freq_distance = tmp[..., 0] + tmp[..., 1]
        loss = weight_matrix * freq_distance  # dynamic spectrum weighting (Hadamard product)
        return loss

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Call the Focal Frequency Loss Function.

        Parameters
        ----------
        y_true: :class:`tf.Tensor`
            The ground truth batch of images
        y_pred: :class:`tf.Tensor`
            The predicted batch of images

        Returns
        -------
        :class:`tf.Tensor`
            The loss for this batch of images
        """
        if not all(self._dims):
            rows, cols = K.int_shape(y_true)[1:3]
            assert cols % self._patch_factor == 0 and rows % self._patch_factor == 0, (
                "Patch factor must be a divisor of the image height and width")
            self._dims = (rows, cols)

        patches_true = self._get_patches(y_true)
        patches_pred = self._get_patches(y_pred)

        freq_true = self._tensor_to_frequency_spectrum(patches_true)
        freq_pred = self._tensor_to_frequency_spectrum(patches_pred)

        if self._ave_spectrum:  # whether to use minibatch average spectrum
            freq_true = K.mean(freq_true, axis=0, keepdims=True)
            freq_pred = K.mean(freq_pred, axis=0, keepdims=True)

        weight_matrix = self._get_weight_matrix(freq_true, freq_pred)
        return self._calculate_loss(freq_true, freq_pred, weight_matrix)
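
# Illustrative usage sketch (not from the original module; the shapes below are assumptions
# chosen for demonstration). Each loss class in this file is a callable taking
# (y_true, y_pred), so it slots into the standard Keras loss signature:
#
# >>> ffl = FocalFrequencyLoss(alpha=1.0, patch_factor=1)  # keep patch_factor=1 (see Notes)
# >>> y_true = tf.random.uniform((4, 64, 64, 3))           # NHWC batch of ground truth images
# >>> y_pred = tf.random.uniform((4, 64, 64, 3))           # NHWC batch of predictions
# >>> loss = ffl(y_true, y_pred)                           # frequency-weighted loss tensor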


class GeneralizedLoss():  # pylint:disable=too-few-public-methods
    """ Generalized function used to return a large variety of mathematical loss functions.

    The primary benefit is a smooth, differentiable version of L1 loss.

    References
    ----------
    Barron, J. A General and Adaptive Robust Loss Function - https://arxiv.org/pdf/1701.03077.pdf

    Example
    -------
    >>> alpha=1.0, beta=1.0/255.0  # a smoothly differentiable L1 / MAE loss for residuals x >> beta
    >>> alpha=1.999999 (limit as alpha->2), beta=1.0/255.0  # gives L2 / RMSE loss

    Parameters
    ----------
    alpha: float, optional
        Penalty factor. Larger numbers give larger weight to large deviations. Default: ``1.0``
    beta: float, optional
        Scale factor used to adjust to the input scale (i.e. inputs of mean ``1e-4`` or ``256``).
        Default: ``1.0/255.0``
    """
    def __init__(self, alpha: float = 1.0, beta: float = 1.0/255.0) -> None:
        self._alpha = alpha
        self._beta = beta

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Call the Generalized Loss Function.

        Parameters
        ----------
        y_true: :class:`tf.Tensor`
            The ground truth value
        y_pred: :class:`tf.Tensor`
            The predicted value

        Returns
        -------
        :class:`tf.Tensor`
            The loss value from the results of function(y_pred - y_true)
        """
        diff = y_pred - y_true
        second = (K.pow(K.pow(diff/self._beta, 2.) / K.abs(2. - self._alpha) + 1.,
                        (self._alpha / 2.)) - 1.)
        loss = (K.abs(2. - self._alpha)/self._alpha) * second
        loss = K.mean(loss, axis=-1) * self._beta
        return loss
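
# Illustrative sketch of the docstring's alpha settings (values are examples only). Barron's
# general loss interpolates between familiar robust losses as alpha varies, with beta scaling
# the residuals to the input range:
#
# >>> smooth_l1 = GeneralizedLoss(alpha=1.0, beta=1.0 / 255.0)  # smooth L1 / MAE behaviour
# >>> approx_l2 = GeneralizedLoss(alpha=1.999999)               # approaches L2 as alpha -> 2
# >>> loss = smooth_l1(y_true, y_pred)   # mean over the channel axis, rescaled by beta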


class GradientLoss():  # pylint:disable=too-few-public-methods
    """ Gradient Loss Function.

    Calculates the first and second order gradient difference between pixels of an image in the x
    and y dimensions. These gradients are then compared between the ground truth and the predicted
    image and the difference is taken. When used as a loss, its minimization will result in
    predicted images approaching the same level of sharpness / blurriness as the ground truth.

    References
    ----------
    TV+TV2 Regularization with Non-Convex Sparseness-Inducing Penalty for Image Restoration,
    Chengwu Lu & Hua Huang, 2014 - http://downloads.hindawi.com/journals/mpe/2014/790547.pdf
    """
    def __init__(self) -> None:
        self.generalized_loss = GeneralizedLoss(alpha=1.9999)
        self._tv_weight = 1.0
        self._tv2_weight = 1.0

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Call the gradient loss function.

        Parameters
        ----------
        y_true: :class:`tf.Tensor`
            The ground truth value
        y_pred: :class:`tf.Tensor`
            The predicted value

        Returns
        -------
        :class:`tf.Tensor`
            The loss value
        """
        loss = 0.0
        loss += self._tv_weight * (self.generalized_loss(self._diff_x(y_true),
                                                         self._diff_x(y_pred)) +
                                   self.generalized_loss(self._diff_y(y_true),
                                                         self._diff_y(y_pred)))
        loss += self._tv2_weight * (self.generalized_loss(self._diff_xx(y_true),
                                                          self._diff_xx(y_pred)) +
                                    self.generalized_loss(self._diff_yy(y_true),
                                                          self._diff_yy(y_pred)) +
                                    self.generalized_loss(self._diff_xy(y_true),
                                                          self._diff_xy(y_pred)) * 2.)
        loss = loss / (self._tv_weight + self._tv2_weight)
        # TODO simplify to use MSE instead
        return loss

    @classmethod
    def _diff_x(cls, img: tf.Tensor) -> tf.Tensor:
        """ X Difference """
        x_left = img[:, :, 1:2, :] - img[:, :, 0:1, :]
        x_inner = img[:, :, 2:, :] - img[:, :, :-2, :]
        x_right = img[:, :, -1:, :] - img[:, :, -2:-1, :]
        x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
        return x_out * 0.5

    @classmethod
    def _diff_y(cls, img: tf.Tensor) -> tf.Tensor:
        """ Y Difference """
        y_top = img[:, 1:2, :, :] - img[:, 0:1, :, :]
        y_inner = img[:, 2:, :, :] - img[:, :-2, :, :]
        y_bot = img[:, -1:, :, :] - img[:, -2:-1, :, :]
        y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
        return y_out * 0.5

    @classmethod
    def _diff_xx(cls, img: tf.Tensor) -> tf.Tensor:
        """ X-X Difference """
        x_left = img[:, :, 1:2, :] + img[:, :, 0:1, :]
        x_inner = img[:, :, 2:, :] + img[:, :, :-2, :]
        x_right = img[:, :, -1:, :] + img[:, :, -2:-1, :]
        x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
        return x_out - 2.0 * img

    @classmethod
    def _diff_yy(cls, img: tf.Tensor) -> tf.Tensor:
        """ Y-Y Difference """
        y_top = img[:, 1:2, :, :] + img[:, 0:1, :, :]
        y_inner = img[:, 2:, :, :] + img[:, :-2, :, :]
        y_bot = img[:, -1:, :, :] + img[:, -2:-1, :, :]
        y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
        return y_out - 2.0 * img

    @classmethod
    def _diff_xy(cls, img: tf.Tensor) -> tf.Tensor:
        """ X-Y Difference """
        # Xout1
        # Left
        top = img[:, 1:2, 1:2, :] + img[:, 0:1, 0:1, :]
        inner = img[:, 2:, 1:2, :] + img[:, :-2, 0:1, :]
        bottom = img[:, -1:, 1:2, :] + img[:, -2:-1, 0:1, :]
        xy_left = K.concatenate([top, inner, bottom], axis=1)
        # Mid
        top = img[:, 1:2, 2:, :] + img[:, 0:1, :-2, :]
        mid = img[:, 2:, 2:, :] + img[:, :-2, :-2, :]
        bottom = img[:, -1:, 2:, :] + img[:, -2:-1, :-2, :]
        xy_mid = K.concatenate([top, mid, bottom], axis=1)
        # Right
        top = img[:, 1:2, -1:, :] + img[:, 0:1, -2:-1, :]
        inner = img[:, 2:, -1:, :] + img[:, :-2, -2:-1, :]
        bottom = img[:, -1:, -1:, :] + img[:, -2:-1, -2:-1, :]
        xy_right = K.concatenate([top, inner, bottom], axis=1)
        # Concatenate before the Xout2 block below re-uses the intermediate variables,
        # otherwise xy_out1 and xy_out2 would be identical and the difference always zero
        xy_out1 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)

        # Xout2
        # Left
        top = img[:, 0:1, 1:2, :] + img[:, 1:2, 0:1, :]
        inner = img[:, :-2, 1:2, :] + img[:, 2:, 0:1, :]
        bottom = img[:, -2:-1, 1:2, :] + img[:, -1:, 0:1, :]
        xy_left = K.concatenate([top, inner, bottom], axis=1)
        # Mid
        top = img[:, 0:1, 2:, :] + img[:, 1:2, :-2, :]
        mid = img[:, :-2, 2:, :] + img[:, 2:, :-2, :]
        bottom = img[:, -2:-1, 2:, :] + img[:, -1:, :-2, :]
        xy_mid = K.concatenate([top, mid, bottom], axis=1)
        # Right
        top = img[:, 0:1, -1:, :] + img[:, 1:2, -2:-1, :]
        inner = img[:, :-2, -1:, :] + img[:, 2:, -2:-1, :]
        bottom = img[:, -2:-1, -1:, :] + img[:, -1:, -2:-1, :]
        xy_right = K.concatenate([top, inner, bottom], axis=1)
        xy_out2 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)

        return (xy_out1 - xy_out2) * 0.25
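
# Worked sketch of the difference stencils (hand-computed, for illustration only): _diff_x is
# a central difference with one-sided differences at the image borders, halved. For a single
# row of pixels [a, b, c, d] it yields 0.5 * [b - a, c - a, d - b, d - c]. _diff_xx gives the
# standard second-order stencil (left + right - 2 * centre) for interior pixels. Typical use
# mirrors the other losses in this module:
#
# >>> grad_loss = GradientLoss()
# >>> loss = grad_loss(y_true, y_pred)  # weighted mix of first / second order gradient error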


class LaplacianPyramidLoss():  # pylint:disable=too-few-public-methods
    """ Laplacian Pyramid Loss Function

    Notes
    -----
    Channels last implementation on square images only.

    Parameters
    ----------
    max_levels: int, Optional
        The max number of laplacian pyramid levels to use. Default: ``5``
    gaussian_size: int, Optional
        The size of the gaussian kernel. Default: ``5``
    gaussian_sigma: float, optional
        The gaussian sigma. Default: ``1.0``

    References
    ----------
    https://arxiv.org/abs/1707.05776
    https://github.com/nathanaelbosch/generative-latent-optimization/blob/master/utils.py
    """
    def __init__(self,
                 max_levels: int = 5,
                 gaussian_size: int = 5,
                 gaussian_sigma: float = 1.0) -> None:
        self._max_levels = max_levels
        self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
        self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)

    @classmethod
    def _get_gaussian_kernel(cls, size: int, sigma: float) -> tf.Tensor:
        """ Obtain the base gaussian kernel for the Laplacian Pyramid.

        Parameters
        ----------
        size: int, Optional
            The size of the gaussian kernel
        sigma: float
            The gaussian sigma

        Returns
        -------
        :class:`tf.Tensor`
            The base single channel Gaussian kernel
        """
        assert size % 2 == 1, "kernel size must be odd"
        x_1 = np.linspace(- (size // 2), size // 2, size, dtype="float32")
        x_1 /= np.sqrt(2) * sigma
        x_2 = x_1 ** 2
        kernel = np.exp(- x_2[:, None] - x_2[None, :])
        kernel /= kernel.sum()
        kernel = np.reshape(kernel, (size, size, 1, 1))
        return K.constant(kernel)

    def _conv_gaussian(self, inputs: tf.Tensor) -> tf.Tensor:
        """ Perform Gaussian convolution on a batch of images.

        Parameters
        ----------
        inputs: :class:`tf.Tensor`
            The input batch of images to perform Gaussian convolution on.

        Returns
        -------
        :class:`tf.Tensor`
            The convolved images
        """
        channels = K.int_shape(inputs)[-1]
        gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))

        # TF doesn't implement replication padding like pytorch. This is an inefficient way to
        # implement it for a square gaussian kernel
        size = self._gaussian_kernel.shape[1] // 2
        padded_inputs = inputs
        for _ in range(size):
            padded_inputs = tf.pad(padded_inputs,  # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
                                   ([0, 0], [1, 1], [1, 1], [0, 0]),
                                   mode="SYMMETRIC")

        retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
        return retval

    def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> list[tf.Tensor]:
        """ Obtain the Laplacian Pyramid.

        Parameters
        ----------
        inputs: :class:`tf.Tensor`
            The input batch of images to run through the Laplacian Pyramid

        Returns
        -------
        list
            The tensors produced from the Laplacian Pyramid
        """
        pyramid = []
        current = inputs
        for _ in range(self._max_levels):
            gauss = self._conv_gaussian(current)
            diff = current - gauss
            pyramid.append(diff)
            current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
        pyramid.append(current)
        return pyramid

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Calculate the Laplacian Pyramid Loss.

        Parameters
        ----------
        y_true: :class:`tf.Tensor`
            The ground truth value
        y_pred: :class:`tf.Tensor`
            The predicted value

        Returns
        -------
        :class:`tf.Tensor`
            The loss value
        """
        pyramid_true = self._get_laplacian_pyramid(y_true)
        pyramid_pred = self._get_laplacian_pyramid(y_pred)

        losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
                          for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
        loss = K.sum(losses * self._weights)
        return loss
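
# Illustrative usage sketch (shapes are assumptions; square images only per the class Notes).
# Each pyramid level compares band-pass detail at a different scale; the 2^(-2 * level)
# weights mean the finest levels contribute most to the final sum:
#
# >>> lap_loss = LaplacianPyramidLoss(max_levels=5, gaussian_size=5, gaussian_sigma=1.0)
# >>> loss = lap_loss(y_true, y_pred)  # scalar loss summed over all pyramid levels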


class LInfNorm():  # pylint:disable=too-few-public-methods
    """ Calculate the L-inf norm as a loss function. """
    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Call the L-inf norm loss function.

        Parameters
        ----------
        y_true: :class:`tf.Tensor`
            The ground truth value
        y_pred: :class:`tf.Tensor`
            The predicted value

        Returns
        -------
        :class:`tf.Tensor`
            The loss value
        """
        diff = K.abs(y_true - y_pred)
        max_loss = K.max(diff, axis=(1, 2), keepdims=True)
        loss = K.mean(max_loss, axis=-1)
        return loss
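
# Worked sketch (hand-computed example values): if a batch item's absolute error peaks at
# [0.5, 0.2, 0.8] per channel across the spatial axes, K.max over axes (1, 2) keeps those
# peaks and K.mean over the channel axis returns 0.5, so the loss tracks the single worst
# pixel per channel rather than the average error:
#
# >>> linf = LInfNorm()
# >>> loss = linf(y_true, y_pred)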


class LossWrapper(tf.keras.losses.Loss):
    """ A wrapper class for multiple keras losses to enable multiple masked weighted loss
    functions on a single output.

    Notes
    -----
    Whilst Keras does allow for applying multiple weighted loss functions, it does not allow
    for an easy mechanism to add additional data (in our case masks) that are batch specific
    but are not fed in to the model.

    This wrapper receives this additional mask data for the batch stacked onto the end of the
    color channels of the received :attr:`y_true` batch of images. These masks are then split
    off the batch of images and applied to both the :attr:`y_true` and :attr:`y_pred` tensors
    prior to feeding into the loss functions.

    For example, for an image of shape (4, 128, 128, 3), 3 additional masks may be stacked onto
    the end of y_true, meaning we receive an input of shape (4, 128, 128, 6). This wrapper then
    splits off (4, 128, 128, 3:6) from the end of the tensor, leaving the original y_true of
    shape (4, 128, 128, 3) ready for masking and feeding through the loss functions.
    """
    def __init__(self) -> None:
        logger.debug("Initializing: %s", self.__class__.__name__)
        super().__init__(name="LossWrapper")
        self._loss_functions: list[compile_utils.LossesContainer] = []
        self._loss_weights: list[float] = []
        self._mask_channels: list[int] = []
        logger.debug("Initialized: %s", self.__class__.__name__)

    def add_loss(self,
                 function: Callable,
                 weight: float = 1.0,
                 mask_channel: int = -1) -> None:
        """ Add the given loss function with the given weight to the loss function chain.

        Parameters
        ----------
        function: :class:`tf.keras.losses.Loss`
            The loss function to add to the loss chain
        weight: float, optional
            The weighting to apply to the loss function. Default: ``1.0``
        mask_channel: int, optional
            The channel in the ``y_true`` image that the mask exists in. Set to ``-1`` if there
            is no mask for the given loss function. Default: ``-1``
        """
        logger.debug("Adding loss: (function: %s, weight: %s, mask_channel: %s)",
                     function, weight, mask_channel)
        # Loss must be compiled inside a LossesContainer for keras to handle distributed
        # strategies
        self._loss_functions.append(compile_utils.LossesContainer(function))
        self._loss_weights.append(weight)
        self._mask_channels.append(mask_channel)

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Call the sub loss functions for the loss wrapper.

        Loss is returned as the weighted sum of the chosen losses.

        If masks are being applied to the loss function inputs, then they should be included as
        additional channels at the end of :attr:`y_true`, so that they can be split off and
        applied to the actual inputs to the selected loss function(s).

        Parameters
        ----------
        y_true: :class:`tensorflow.Tensor`
            The ground truth batch of images, with any required masks stacked on the end
        y_pred: :class:`tensorflow.Tensor`
            The batch of model predictions

        Returns
        -------
        :class:`tensorflow.Tensor`
            The final weighted loss
        """
        loss = 0.0
        for func, weight, mask_channel in zip(self._loss_functions,
                                              self._loss_weights,
                                              self._mask_channels):
            logger.debug("Processing loss function: (func: %s, weight: %s, mask_channel: %s)",
                         func, weight, mask_channel)
            n_true, n_pred = self._apply_mask(y_true, y_pred, mask_channel)
            loss += (func(n_true, n_pred) * weight)
        return loss

    @classmethod
    def _apply_mask(cls,
                    y_true: tf.Tensor,
                    y_pred: tf.Tensor,
                    mask_channel: int,
                    mask_prop: float = 1.0) -> tuple[tf.Tensor, tf.Tensor]:
        """ Apply the mask to the input y_true and y_pred. If a mask is not required then
        return the unmasked inputs.

        Parameters
        ----------
        y_true: tensor or variable
            The ground truth value
        y_pred: tensor or variable
            The predicted value
        mask_channel: int
            The channel within y_true that the required mask resides in
        mask_prop: float, optional
            The amount of mask propagation. Default: ``1.0``

        Returns
        -------
        tf.Tensor
            The ground truth batch of images, with the required mask applied
        tf.Tensor
            The predicted batch of images with the required mask applied
        """
        if mask_channel == -1:
            logger.debug("No mask to apply")
            return y_true[..., :3], y_pred[..., :3]

        logger.debug("Applying mask from channel %s", mask_channel)

        mask = K.tile(K.expand_dims(y_true[..., mask_channel], axis=-1), (1, 1, 1, 3))
        mask_as_k_inv_prop = 1 - mask_prop
        mask = (mask * mask_prop) + mask_as_k_inv_prop

        m_true = y_true[..., :3] * mask
        m_pred = y_pred[..., :3] * mask
        return m_true, m_pred
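
# Illustrative usage sketch (the loss choices and mask channel here are assumptions for
# demonstration). Masks ride along as extra channels stacked after the RGB channels of
# y_true, exactly as described in the class docstring:
#
# >>> wrapper = LossWrapper()
# >>> wrapper.add_loss(tf.keras.losses.MeanAbsoluteError(), weight=1.0)  # unmasked loss
# >>> wrapper.add_loss(GradientLoss(), weight=0.5, mask_channel=3)       # masked by channel 3
# >>> y_true_with_mask = tf.concat([y_true, mask], axis=-1)              # (N, H, W, 3 + 1)
# >>> loss = wrapper(y_true_with_mask, y_pred)                           # weighted sum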