mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* Remove tensorflow_probability requirement * setup.py - fix progress bars * requirements.txt: Remove pre python 3.9 packages * update apple requirements.txt * update INSTALL.md * Remove python<3.9 code * setup.py - fix Windows Installer * typing: python3.9 compliant * Update pytest and readthedocs python versions * typing fixes * Python Version updates - Reduce max version to 3.10 - Default to 3.10 in installers - Remove incompatible 3.11 tests * Update dependencies * Downgrade imageio dep for Windows * typing: merge optional unions and fixes * Updates - min python version 3.10 - typing to python 3.10 spec - remove pre-tf2.10 code - Add conda tests * train: re-enable optimizer saving * Update dockerfiles * Update setup.py - Apple Conda deps to setup.py - Better Cuda + dependency handling * bugfix: Patch logging to prevent Autograph errors * Update dockerfiles * Setup.py - Setup.py - stdout to utf-8 * Add more OSes to github Actions * suppress mac-os end to end test
403 lines
16 KiB
Python
403 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
""" Custom Feature Map Loss Functions for faceswap.py """
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass, field
|
|
import logging
|
|
import typing as T
|
|
|
|
# Ignore linting errors from Tensorflow's thoroughly broken import system
|
|
import tensorflow as tf
|
|
from tensorflow.keras import applications as kapp # pylint:disable=import-error
|
|
from tensorflow.keras.layers import Dropout, Conv2D, Input, Layer, Resizing # noqa,pylint:disable=no-name-in-module,import-error
|
|
from tensorflow.keras.models import Model # pylint:disable=no-name-in-module,import-error
|
|
import tensorflow.keras.backend as K # pylint:disable=no-name-in-module,import-error
|
|
|
|
import numpy as np
|
|
|
|
from lib.model.nets import AlexNet, SqueezeNet
|
|
from lib.utils import GetModel
|
|
|
|
if T.TYPE_CHECKING:
|
|
from collections.abc import Callable
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class NetInfo:
    """ Data class for holding information about Trunk and Linear Layer nets.

    Parameters
    ----------
    model_id: int
        The model ID for the model stored in the deepfakes Model repo
    model_name: str
        The filename of the decompressed model/weights file
    net: callable, Optional
        The net definition to load, if any. Default:``None``
    init_kwargs: dict, optional
        Keyword arguments to initialize any :attr:`net`. Default: empty ``dict``
    needs_init: bool, optional
        True if the net needs initializing otherwise False. Default: ``True``
    outputs: list, optional
        The names of the net's feature-output layers to expose (passed to
        ``model.get_layer``). Default: empty ``list``
    """
    model_id: int = 0
    model_name: str = ""
    net: Callable | None = None
    init_kwargs: dict[str, T.Any] = field(default_factory=dict)
    needs_init: bool = True
    # Layer *names*, not Layer objects: populated with strings (e.g. "features.0")
    # and consumed by model.get_layer(name)
    outputs: list[str] = field(default_factory=list)
|
|
|
|
|
|
class _LPIPSTrunkNet():  # pylint:disable=too-few-public-methods
    """ Trunk neural network loader for LPIPS Loss function.

    Parameters
    ----------
    net_name: str
        The name of the trunk network to load. One of "alex", "squeeze" or "vgg16"
    eval_mode: bool
        ``True`` for evaluation mode, ``False`` for training mode
    load_weights: bool
        ``True`` if pretrained trunk network weights should be loaded, otherwise ``False``

    Raises
    ------
    ValueError
        If the requested trunk network name is not valid
    """
    def __init__(self, net_name: str, eval_mode: bool, load_weights: bool) -> None:
        logger.debug("Initializing: %s (net_name '%s', eval_mode: %s, load_weights: %s)",
                     self.__class__.__name__, net_name, eval_mode, load_weights)
        self._eval_mode = eval_mode
        self._load_weights = load_weights
        self._net_name = net_name
        if net_name not in self._nets:
            # Surface a clear configuration error rather than an opaque KeyError
            raise ValueError(f"Invalid trunk network '{net_name}'. Must be one of "
                             f"{sorted(self._nets)}")
        self._net = self._nets[net_name]
        logger.debug("Initialized: %s ", self.__class__.__name__)

    @property
    def _nets(self) -> dict[str, NetInfo]:
        """ dict[str, :class:`NetInfo`]: The selectable trunk networks and the names of the
        feature layers to be output from each. """
        return {
            "alex": NetInfo(model_id=15,
                            model_name="alexnet_imagenet_no_top_v1.h5",
                            net=AlexNet,
                            outputs=[f"features.{idx}" for idx in (0, 3, 6, 8, 10)]),
            "squeeze": NetInfo(model_id=16,
                               model_name="squeezenet_imagenet_no_top_v1.h5",
                               net=SqueezeNet,
                               outputs=[f"features.{idx}" for idx in (0, 4, 7, 9, 10, 11, 12)]),
            "vgg16": NetInfo(model_id=17,
                             model_name="vgg16_imagenet_no_top_v1.h5",
                             net=kapp.vgg16.VGG16,
                             init_kwargs={"include_top": False, "weights": None},
                             outputs=[f"block{i + 1}_conv{2 if i < 2 else 3}"
                                      for i in range(5)])}

    @classmethod
    def _normalize_output(cls, inputs: tf.Tensor, epsilon: float = 1e-10) -> tf.Tensor:
        """ Normalize the output tensors from the trunk network.

        Parameters
        ----------
        inputs: :class:`tensorflow.Tensor`
            An output tensor from the trunk model
        epsilon: float, optional
            Epsilon to apply to the normalization operation. Default: `1e-10`

        Returns
        -------
        :class:`tensorflow.Tensor`
            The input tensor unit-normalized along the channels axis
        """
        norm_factor = K.sqrt(K.sum(K.square(inputs), axis=-1, keepdims=True))
        return inputs / (norm_factor + epsilon)

    def _process_weights(self, model: Model) -> Model:
        """ Save and lock weights if requested.

        Parameters
        ----------
        model :class:`keras.models.Model`
            The loaded trunk or linear network

        Returns
        -------
        :class:`keras.models.Model`
            The network with weights loaded/not loaded and layers locked/unlocked
        """
        if self._load_weights:
            weights = GetModel(self._net.model_name, self._net.model_id).model_path
            model.load_weights(weights)

        if self._eval_mode:
            # Freeze the model and each layer explicitly for inference-only use
            model.trainable = False
            for layer in model.layers:
                layer.trainable = False
        return model

    def __call__(self) -> Model:
        """ Load the Trunk net, add normalization to feature outputs, load weights and set
        trainable state.

        Returns
        -------
        :class:`tensorflow.keras.models.Model`
            The trunk net with normalized feature output layers

        Raises
        ------
        ValueError
            If no net definition exists for the selected trunk network
        """
        if self._net.net is None:
            raise ValueError("No net loaded")

        model = self._net.net(**self._net.init_kwargs)
        # VGG16 from keras applications is already instantiated; the custom nets need calling
        model = model if self._net_name == "vgg16" else model()
        out_layers = [self._normalize_output(model.get_layer(name).output)
                      for name in self._net.outputs]
        model = Model(inputs=model.input, outputs=out_layers)
        model = self._process_weights(model)
        return model
|
|
|
|
|
|
class _LPIPSLinearNet(_LPIPSTrunkNet):  # pylint:disable=too-few-public-methods
    """ The Linear Network to be applied to the difference between the true and predicted outputs
    of the trunk network.

    Parameters
    ----------
    net_name: str
        The name of the trunk network in use. One of "alex", "squeeze" or "vgg16"
    eval_mode: bool
        ``True`` for evaluation mode, ``False`` for training mode
    load_weights: bool
        ``True`` if pretrained linear network weights should be loaded, otherwise ``False``
    trunk_net: :class:`keras.models.Model`
        The trunk net to place the linear layer on.
    use_dropout: bool
        ``True`` if a dropout layer should be used in the Linear network otherwise ``False``
    """
    def __init__(self,
                 net_name: str,
                 eval_mode: bool,
                 load_weights: bool,
                 trunk_net: Model,
                 use_dropout: bool) -> None:
        logger.debug(
            "Initializing: %s (trunk_net: %s, use_dropout: %s)", self.__class__.__name__,
            trunk_net, use_dropout)
        super().__init__(net_name=net_name, eval_mode=eval_mode, load_weights=load_weights)

        self._trunk = trunk_net
        self._use_dropout = use_dropout

        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def _nets(self) -> dict[str, NetInfo]:
        """ dict[str, :class:`NetInfo`]: The pretrained linear-layer weight files available for
        each trunk network. """
        weight_files = {"alex": (18, "alexnet_lpips_v1.h5"),
                        "squeeze": (19, "squeezenet_lpips_v1.h5"),
                        "vgg16": (20, "vgg16_lpips_v1.h5")}
        return {name: NetInfo(model_id=mid, model_name=fname)
                for name, (mid, fname) in weight_files.items()}

    def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """ Build a linear block for a trunk network output.

        Parameters
        ----------
        net_output_layer: :class:`tensorflow.Tensor`
            An output from the selected trunk network

        Returns
        -------
        :class:`tensorflow.Tensor`
            The input to the linear block
        :class:`tensorflow.Tensor`
            The output from the linear block
        """
        feature_input = Input(K.int_shape(net_output_layer)[1:])
        # Optional dropout, then a 1x1 bias-free convolution to weight the channels
        node = Dropout(rate=0.5)(feature_input) if self._use_dropout else feature_input
        node = Conv2D(1, 1, strides=1, padding="valid", use_bias=False)(node)
        return feature_input, node

    def __call__(self) -> Model:
        """ Build the linear network for the given trunk network's outputs. Load in trained weights
        and set the model's trainable parameters.

        Returns
        -------
        :class:`tensorflow.keras.models.Model`
            The compiled Linear Net model
        """
        # One linear block per trunk feature output
        blocks = [self._linear_block(feature) for feature in self._trunk.outputs]
        model = Model(inputs=[block[0] for block in blocks],
                      outputs=[block[1] for block in blocks])
        return self._process_weights(model)
|
|
|
|
|
|
class LPIPSLoss():  # pylint:disable=too-few-public-methods
    """ LPIPS Loss Function.

    A perceptual loss function that uses linear outputs from pretrained CNNs feature layers.

    Notes
    -----
    Channels Last implementation. All trunks implemented from the original paper.

    References
    ----------
    https://richzhang.github.io/PerceptualSimilarity/

    Parameters
    ----------
    trunk_network: str
        The name of the trunk network to use. One of "alex", "squeeze" or "vgg16"
    trunk_pretrained: bool, optional
        ``True`` Load the imagenet pretrained weights for the trunk network. ``False`` randomly
        initialize the trunk network. Default: ``True``
    trunk_eval_mode: bool, optional
        ``True`` for running inference on the trunk network (standard mode), ``False`` for training
        the trunk network. Default: ``True``
    linear_pretrained: bool, optional
        ``True`` loads the pretrained weights for the linear network layers. ``False`` randomly
        initializes the layers. Default: ``True``
    linear_eval_mode: bool, optional
        ``True`` for running inference on the linear network (standard mode), ``False`` for
        training the linear network. Default: ``True``
    linear_use_dropout: bool, optional
        ``True`` if a dropout layer should be used in the Linear network otherwise ``False``.
        Default: ``True``
    lpips: bool, optional
        ``True`` to use linear network on top of the trunk network. ``False`` to just average the
        output from the trunk network. Default ``True``
    spatial: bool, optional
        ``True`` output the loss in the spatial domain (i.e. as a grayscale tensor of height and
        width of the input image). ``False`` reduce the spatial dimensions for loss calculation.
        Default: ``False``
    normalize: bool, optional
        ``True`` if the input Tensor needs to be normalized from the 0. to 1. range to the -1. to
        1. range. Default: ``True``
    ret_per_layer: bool, optional
        ``True`` to return the loss value per feature output layer otherwise ``False``.
        Default: ``False``
    """
    def __init__(self,  # pylint:disable=too-many-arguments
                 trunk_network: str,
                 trunk_pretrained: bool = True,
                 trunk_eval_mode: bool = True,
                 linear_pretrained: bool = True,
                 linear_eval_mode: bool = True,
                 linear_use_dropout: bool = True,
                 lpips: bool = True,
                 spatial: bool = False,
                 normalize: bool = True,
                 ret_per_layer: bool = False) -> None:
        logger.debug(
            "Initializing: %s (trunk_network '%s', trunk_pretrained: %s, trunk_eval_mode: %s, "
            "linear_pretrained: %s, linear_eval_mode: %s, linear_use_dropout: %s, lpips: %s, "
            "spatial: %s, normalize: %s, ret_per_layer: %s)", self.__class__.__name__,
            trunk_network, trunk_pretrained, trunk_eval_mode, linear_pretrained, linear_eval_mode,
            linear_use_dropout, lpips, spatial, normalize, ret_per_layer)

        self._spatial = spatial
        self._use_lpips = lpips
        self._normalize = normalize
        self._ret_per_layer = ret_per_layer
        # Per-channel shift/scale constants from the reference LPIPS implementation,
        # broadcast over (batch, height, width, channels)
        self._shift = K.constant(np.array([-.030, -.088, -.188],
                                          dtype="float32")[None, None, None, :])
        self._scale = K.constant(np.array([.458, .448, .450],
                                          dtype="float32")[None, None, None, :])

        # Loss needs to be done as fp32. We could cast at output, but better to update the model
        switch_mixed_precision = tf.keras.mixed_precision.global_policy().name == "mixed_float16"
        if switch_mixed_precision:
            logger.debug("Temporarily disabling mixed precision")
            tf.keras.mixed_precision.set_global_policy("float32")

        self._trunk_net = _LPIPSTrunkNet(trunk_network, trunk_eval_mode, trunk_pretrained)()
        self._linear_net = _LPIPSLinearNet(trunk_network,
                                           linear_eval_mode,
                                           linear_pretrained,
                                           self._trunk_net,
                                           linear_use_dropout)()
        if switch_mixed_precision:
            logger.debug("Re-enabling mixed precision")
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
        logger.debug("Initialized: %s", self.__class__.__name__)

    def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]:
        """ Perform processing on the Trunk Network outputs.

        If :attr:`use_lpips` is enabled, process the diff values through the linear network,
        otherwise return the diff values summed on the channels axis.

        Parameters
        ----------
        inputs: list
            List of the squared difference of the true and predicted outputs from the trunk network

        Returns
        -------
        list
            List of either the linear network outputs (when using lpips) or summed network outputs
        """
        if self._use_lpips:
            return self._linear_net(inputs)
        return [K.sum(x, axis=-1) for x in inputs]

    def _process_output(self, inputs: tf.Tensor, output_dims: tuple) -> tf.Tensor:
        """ Process an individual output based on whether :attr:`is_spatial` has been selected.

        When spatial output is selected, all outputs are sized to the shape of the original True
        input Tensor. When not selected, the mean across the spatial axes (h, w) are returned

        Parameters
        ----------
        inputs: :class:`tensorflow.Tensor`
            An individual diff output tensor from the linear network or summed output
        output_dims: tuple
            The (height, width) of the original true image

        Returns
        -------
        :class:`tensorflow.Tensor`
            Either the original tensor resized to the true image dimensions, or the mean
            value across the height, width axes.
        """
        if self._spatial:
            return Resizing(*output_dims, interpolation="bilinear")(inputs)
        return K.mean(inputs, axis=(1, 2), keepdims=True)

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Perform the LPIPS Loss Function.

        Parameters
        ----------
        y_true: :class:`tensorflow.Tensor`
            The ground truth batch of images
        y_pred: :class:`tensorflow.Tensor`
            The predicted batch of images

        Returns
        -------
        :class:`tensorflow.Tensor`
            The final loss value. When :attr:`ret_per_layer` is ``True``, a tuple of
            (total loss, list of per-layer losses) is returned instead
        """
        if self._normalize:
            # Map inputs from the 0..1 range to the -1..1 range expected by the trunk
            y_true = (y_true * 2.0) - 1.0
            y_pred = (y_pred * 2.0) - 1.0

        y_true = (y_true - self._shift) / self._scale
        y_pred = (y_pred - self._shift) / self._scale

        net_true = self._trunk_net(y_true)
        net_pred = self._trunk_net(y_pred)

        # Squared difference of each trunk feature output pair
        diffs = [(out_true - out_pred) ** 2
                 for out_true, out_pred in zip(net_true, net_pred)]

        dims = K.int_shape(y_true)[1:3]
        res = [self._process_output(diff, dims) for diff in self._process_diffs(diffs)]

        axis = 0 if self._spatial else None
        val = K.sum(res, axis=axis)

        # Reduce by factor of 10 'cos this loss is STRONG
        # NOTE: previously `(val, res) / 10.0` was returned when ret_per_layer was set,
        # which raised a TypeError (tuple division). Scale each element instead.
        if self._ret_per_layer:
            return val / 10.0, [layer_loss / 10.0 for layer_loss in res]
        return val / 10.0
|