faceswap/lib/model/losses/feature_loss.py

#!/usr/bin/env python3
""" Custom Feature Map Loss Functions for faceswap.py """
from __future__ import annotations
from dataclasses import dataclass, field
import logging
import typing as T

# Ignore linting errors from Tensorflow's thoroughly broken import system
import tensorflow as tf
from tensorflow.keras import applications as kapp  # pylint:disable=import-error
from tensorflow.keras.layers import Dropout, Conv2D, Input, Layer, Resizing  # noqa,pylint:disable=no-name-in-module,import-error
from tensorflow.keras.models import Model  # pylint:disable=no-name-in-module,import-error
import tensorflow.keras.backend as K  # pylint:disable=no-name-in-module,import-error
import numpy as np

from lib.model.nets import AlexNet, SqueezeNet
from lib.utils import GetModel

if T.TYPE_CHECKING:
    from collections.abc import Callable

logger = logging.getLogger(__name__)


@dataclass
class NetInfo:
    """ Data class for holding information about Trunk and Linear Layer nets.

    Parameters
    ----------
    model_id: int
        The model ID for the model stored in the deepfakes Model repo
    model_name: str
        The filename of the decompressed model/weights file
    net: callable, optional
        The net definition to load, if any. Default: ``None``
    init_kwargs: dict, optional
        Keyword arguments to initialize any :attr:`net`. Default: empty ``dict``
    needs_init: bool, optional
        ``True`` if the net needs initializing otherwise ``False``. Default: ``True``
    outputs: list, optional
        The names of the layers to tap for feature output. Default: empty ``list``
    """
    model_id: int = 0
    model_name: str = ""
    net: Callable | None = None
    init_kwargs: dict[str, T.Any] = field(default_factory=dict)
    needs_init: bool = True
    outputs: list[str] = field(default_factory=list)


class _LPIPSTrunkNet():  # pylint:disable=too-few-public-methods
    """ Trunk neural network loader for LPIPS Loss function.

    Parameters
    ----------
    net_name: str
        The name of the trunk network to load. One of "alex", "squeeze" or "vgg16"
    eval_mode: bool
        ``True`` for evaluation mode, ``False`` for training mode
    load_weights: bool
        ``True`` if pretrained trunk network weights should be loaded, otherwise ``False``
    """
    def __init__(self, net_name: str, eval_mode: bool, load_weights: bool) -> None:
        logger.debug("Initializing: %s (net_name '%s', eval_mode: %s, load_weights: %s)",
                     self.__class__.__name__, net_name, eval_mode, load_weights)
        self._eval_mode = eval_mode
        self._load_weights = load_weights
        self._net_name = net_name
        self._net = self._nets[net_name]
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def _nets(self) -> dict[str, NetInfo]:
        """ dict[str, :class:`NetInfo`]: Information about each selectable net, keyed by net
        name. """
        return {
            "alex": NetInfo(model_id=15,
                            model_name="alexnet_imagenet_no_top_v1.h5",
                            net=AlexNet,
                            outputs=[f"features.{idx}" for idx in (0, 3, 6, 8, 10)]),
            "squeeze": NetInfo(model_id=16,
                               model_name="squeezenet_imagenet_no_top_v1.h5",
                               net=SqueezeNet,
                               outputs=[f"features.{idx}" for idx in (0, 4, 7, 9, 10, 11, 12)]),
            "vgg16": NetInfo(model_id=17,
                             model_name="vgg16_imagenet_no_top_v1.h5",
                             net=kapp.vgg16.VGG16,
                             init_kwargs={"include_top": False, "weights": None},
                             outputs=[f"block{i + 1}_conv{2 if i < 2 else 3}"
                                      for i in range(5)])}

    @classmethod
    def _normalize_output(cls, inputs: tf.Tensor, epsilon: float = 1e-10) -> tf.Tensor:
        """ Normalize the output tensors from the trunk network.

        Parameters
        ----------
        inputs: :class:`tensorflow.Tensor`
            An output tensor from the trunk model
        epsilon: float, optional
            Epsilon to apply to the normalization operation. Default: `1e-10`

        Returns
        -------
        :class:`tensorflow.Tensor`
            The input tensor with the channel vector at each spatial position scaled to unit
            length
        """
        norm_factor = K.sqrt(K.sum(K.square(inputs), axis=-1, keepdims=True))
        return inputs / (norm_factor + epsilon)
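
    # A minimal NumPy sketch of the normalization above (illustrative only, not part of the
    # class): after scaling, the channel vector at every spatial position has unit length, so
    # the squared differences taken later compare feature directions rather than magnitudes:
    #   >>> feats = np.random.rand(1, 4, 4, 64).astype("float32")
    #   >>> norm = np.sqrt((feats ** 2).sum(axis=-1, keepdims=True))
    #   >>> unit = feats / (norm + 1e-10)
    #   >>> np.allclose((unit ** 2).sum(axis=-1), 1.0, atol=1e-4)
    #   True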

    def _process_weights(self, model: Model) -> Model:
        """ Load and lock weights if requested.

        Parameters
        ----------
        model: :class:`keras.models.Model`
            The loaded trunk or linear network

        Returns
        -------
        :class:`keras.models.Model`
            The network with weights loaded/not loaded and layers locked/unlocked
        """
        if self._load_weights:
            weights = GetModel(self._net.model_name, self._net.model_id).model_path
            model.load_weights(weights)

        if self._eval_mode:
            model.trainable = False
            for layer in model.layers:
                layer.trainable = False
        return model

    def __call__(self) -> Model:
        """ Load the Trunk net, add normalization to feature outputs, load weights and set
        trainable state.

        Returns
        -------
        :class:`tensorflow.keras.models.Model`
            The trunk net with normalized feature output layers
        """
        if self._net.net is None:
            raise ValueError("No net loaded")

        model = self._net.net(**self._net.init_kwargs)
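        # Note (inferred from the net definitions above, not stated in the original file): the
        # keras applications VGG16 call already returns an instantiated model, whereas the
        # custom AlexNet/SqueezeNet definitions return an object that must be called to build
        # its model, hence the branch below.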
        model = model if self._net_name == "vgg16" else model()
        out_layers = [self._normalize_output(model.get_layer(name).output)
                      for name in self._net.outputs]
        model = Model(inputs=model.input, outputs=out_layers)
        model = self._process_weights(model)
        return model


class _LPIPSLinearNet(_LPIPSTrunkNet):  # pylint:disable=too-few-public-methods
    """ The Linear Network to be applied to the difference between the true and predicted
    outputs of the trunk network.

    Parameters
    ----------
    net_name: str
        The name of the trunk network in use. One of "alex", "squeeze" or "vgg16"
    eval_mode: bool
        ``True`` for evaluation mode, ``False`` for training mode
    load_weights: bool
        ``True`` if pretrained linear network weights should be loaded, otherwise ``False``
    trunk_net: :class:`keras.models.Model`
        The trunk net to place the linear layers on.
    use_dropout: bool
        ``True`` if a dropout layer should be used in the Linear network otherwise ``False``
    """
    def __init__(self,
                 net_name: str,
                 eval_mode: bool,
                 load_weights: bool,
                 trunk_net: Model,
                 use_dropout: bool) -> None:
        logger.debug(
            "Initializing: %s (trunk_net: %s, use_dropout: %s)", self.__class__.__name__,
            trunk_net, use_dropout)
        super().__init__(net_name=net_name, eval_mode=eval_mode, load_weights=load_weights)
        self._trunk = trunk_net
        self._use_dropout = use_dropout
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def _nets(self) -> dict[str, NetInfo]:
        """ dict[str, :class:`NetInfo`]: Information about each selectable linear net, keyed by
        net name. """
        return {
            "alex": NetInfo(model_id=18,
                            model_name="alexnet_lpips_v1.h5"),
            "squeeze": NetInfo(model_id=19,
                               model_name="squeezenet_lpips_v1.h5"),
            "vgg16": NetInfo(model_id=20,
                             model_name="vgg16_lpips_v1.h5")}

    def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """ Build a linear block for a trunk network output.

        Parameters
        ----------
        net_output_layer: :class:`tensorflow.Tensor`
            An output from the selected trunk network

        Returns
        -------
        :class:`tensorflow.Tensor`
            The input to the linear block
        :class:`tensorflow.Tensor`
            The output from the linear block
        """
        in_shape = K.int_shape(net_output_layer)[1:]
        input_ = Input(in_shape)
        var_x = Dropout(rate=0.5)(input_) if self._use_dropout else input_
        var_x = Conv2D(1, 1, strides=1, padding="valid", use_bias=False)(var_x)
        return input_, var_x
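
    # Illustrative note (inferred from the layer maths, not part of the original file): a 1x1
    # ``Conv2D`` with a single filter and no bias computes, at each spatial position, a learned
    # weighted sum over the input channels, i.e. the per-layer channel weighting of the LPIPS
    # paper:
    #   >>> conv = Conv2D(1, 1, use_bias=False)
    #   >>> feats = tf.random.uniform((1, 8, 8, 64))
    #   >>> out = conv(feats)                        # shape (1, 8, 8, 1)
    #   >>> weights = conv.kernel[0, 0, :, 0]        # one weight per channel
    #   >>> manual = tf.reduce_sum(feats * weights, axis=-1, keepdims=True)
    #   >>> np.allclose(out.numpy(), manual.numpy(), atol=1e-5)
    #   True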

    def __call__(self) -> Model:
        """ Build the linear network for the given trunk network's outputs. Load in trained
        weights and set the model's trainable parameters.

        Returns
        -------
        :class:`tensorflow.keras.models.Model`
            The compiled Linear Net model
        """
        inputs = []
        outputs = []
        for input_ in self._trunk.outputs:
            in_, out = self._linear_block(input_)
            inputs.append(in_)
            outputs.append(out)
        model = Model(inputs=inputs, outputs=outputs)
        model = self._process_weights(model)
        return model


class LPIPSLoss():  # pylint:disable=too-few-public-methods
    """ LPIPS Loss Function.

    A perceptual loss function that uses linear outputs from pretrained CNN feature layers.

    Notes
    -----
    Channels Last implementation. All trunks implemented from the original paper.

    References
    ----------
    https://richzhang.github.io/PerceptualSimilarity/

    Parameters
    ----------
    trunk_network: str
        The name of the trunk network to use. One of "alex", "squeeze" or "vgg16"
    trunk_pretrained: bool, optional
        ``True`` to load the imagenet pretrained weights for the trunk network. ``False`` to
        randomly initialize the trunk network. Default: ``True``
    trunk_eval_mode: bool, optional
        ``True`` for running inference on the trunk network (standard mode), ``False`` for
        training the trunk network. Default: ``True``
    linear_pretrained: bool, optional
        ``True`` loads the pretrained weights for the linear network layers. ``False`` randomly
        initializes the layers. Default: ``True``
    linear_eval_mode: bool, optional
        ``True`` for running inference on the linear network (standard mode), ``False`` for
        training the linear network. Default: ``True``
    linear_use_dropout: bool, optional
        ``True`` if a dropout layer should be used in the Linear network otherwise ``False``.
        Default: ``True``
    lpips: bool, optional
        ``True`` to use the linear network on top of the trunk network. ``False`` to just
        average the output from the trunk network. Default: ``True``
    spatial: bool, optional
        ``True`` to output the loss in the spatial domain (i.e. as a grayscale tensor of the
        height and width of the input image). ``False`` to reduce the spatial dimensions for
        the loss calculation. Default: ``False``
    normalize: bool, optional
        ``True`` if the input Tensor needs to be normalized from the 0. to 1. range to the -1.
        to 1. range. Default: ``True``
    ret_per_layer: bool, optional
        ``True`` to return the loss value per feature output layer otherwise ``False``.
        Default: ``False``
    """
    def __init__(self,  # pylint:disable=too-many-arguments
                 trunk_network: str,
                 trunk_pretrained: bool = True,
                 trunk_eval_mode: bool = True,
                 linear_pretrained: bool = True,
                 linear_eval_mode: bool = True,
                 linear_use_dropout: bool = True,
                 lpips: bool = True,
                 spatial: bool = False,
                 normalize: bool = True,
                 ret_per_layer: bool = False) -> None:
        logger.debug(
            "Initializing: %s (trunk_network '%s', trunk_pretrained: %s, trunk_eval_mode: %s, "
            "linear_pretrained: %s, linear_eval_mode: %s, linear_use_dropout: %s, lpips: %s, "
            "spatial: %s, normalize: %s, ret_per_layer: %s)", self.__class__.__name__,
            trunk_network, trunk_pretrained, trunk_eval_mode, linear_pretrained,
            linear_eval_mode, linear_use_dropout, lpips, spatial, normalize, ret_per_layer)
        self._spatial = spatial
        self._use_lpips = lpips
        self._normalize = normalize
        self._ret_per_layer = ret_per_layer
        self._shift = K.constant(np.array([-.030, -.088, -.188],
                                          dtype="float32")[None, None, None, :])
        self._scale = K.constant(np.array([.458, .448, .450],
                                          dtype="float32")[None, None, None, :])

        # Loss needs to be done as fp32. We could cast at output, but better to update the model
        switch_mixed_precision = tf.keras.mixed_precision.global_policy().name == "mixed_float16"
        if switch_mixed_precision:
            logger.debug("Temporarily disabling mixed precision")
            tf.keras.mixed_precision.set_global_policy("float32")

        self._trunk_net = _LPIPSTrunkNet(trunk_network, trunk_eval_mode, trunk_pretrained)()
        self._linear_net = _LPIPSLinearNet(trunk_network,
                                           linear_eval_mode,
                                           linear_pretrained,
                                           self._trunk_net,
                                           linear_use_dropout)()
        if switch_mixed_precision:
            logger.debug("Re-enabling mixed precision")
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
        logger.debug("Initialized: %s", self.__class__.__name__)

    def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]:
        """ Perform processing on the Trunk Network outputs.

        If :attr:`_use_lpips` is enabled, process the diff values through the linear network,
        otherwise return the diff values summed on the channels axis.

        Parameters
        ----------
        inputs: list
            List of the squared difference of the true and predicted outputs from the trunk
            network

        Returns
        -------
        list
            List of either the linear network outputs (when using lpips) or summed network
            outputs
        """
        if self._use_lpips:
            return self._linear_net(inputs)
        return [K.sum(x, axis=-1) for x in inputs]

    def _process_output(self, inputs: tf.Tensor, output_dims: tuple) -> tf.Tensor:
        """ Process an individual output based on whether :attr:`_spatial` has been selected.

        When spatial output is selected, all outputs are resized to the shape of the original
        true input Tensor. When not selected, the mean across the spatial axes (h, w) is
        returned.

        Parameters
        ----------
        inputs: :class:`tensorflow.Tensor`
            An individual diff output tensor from the linear network or summed output
        output_dims: tuple
            The (height, width) of the original true image

        Returns
        -------
        :class:`tensorflow.Tensor`
            Either the original tensor resized to the true image dimensions, or the mean
            value across the height, width axes.
        """
        if self._spatial:
            return Resizing(*output_dims, interpolation="bilinear")(inputs)
        return K.mean(inputs, axis=(1, 2), keepdims=True)

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """ Perform the LPIPS Loss Function.

        Parameters
        ----------
        y_true: :class:`tensorflow.Tensor`
            The ground truth batch of images
        y_pred: :class:`tensorflow.Tensor`
            The predicted batch of images

        Returns
        -------
        :class:`tensorflow.Tensor`
            The final loss value
        """
        if self._normalize:
            y_true = (y_true * 2.0) - 1.0
            y_pred = (y_pred * 2.0) - 1.0
        y_true = (y_true - self._shift) / self._scale
        y_pred = (y_pred - self._shift) / self._scale

        net_true = self._trunk_net(y_true)
        net_pred = self._trunk_net(y_pred)

        diffs = [(out_true - out_pred) ** 2
                 for out_true, out_pred in zip(net_true, net_pred)]

        dims = K.int_shape(y_true)[1:3]
        res = [self._process_output(diff, dims) for diff in self._process_diffs(diffs)]

        axis = 0 if self._spatial else None
        val = K.sum(res, axis=axis)
        val = val / 10.0  # Reduce by factor of 10 'cos this loss is STRONG
        # Scale before branching: the original divided the (val, res) tuple, which would raise
        # a TypeError whenever ret_per_layer was enabled
        return (val, res) if self._ret_per_layer else val
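

# A minimal usage sketch (illustrative only; assumes the pretrained trunk and linear weight
# files can be fetched by ``GetModel``, and images are channels-last in the 0. to 1. range):
#   >>> loss_fn = LPIPSLoss("alex")
#   >>> y_true = tf.random.uniform((4, 64, 64, 3))
#   >>> y_pred = tf.random.uniform((4, 64, 64, 3))
#   >>> loss = loss_fn(y_true, y_pred)  # loss value, scaled down by a factor of 10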