faceswap/lib/model/initializers.py
torzdf d8557c1970
Faceswap 2.0 (#1045)
* Core Updates
    - Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
    - Document lib.gpu_stats and lib.sys_info
    - Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
    - lib.gui.menu - typofix

* Update Dependencies
Bump Tensorflow Version Check

* Port extraction to tf2

* Add custom import finder for loading Keras or tf.keras depending on backend

* Add `tensorflow` to KerasFinder search path

* Basic TF2 training running

* model.initializers - docstring fix

* Fix and pass tests for tf2

* Replace Keras backend tests with faceswap backend tests

* Initial optimizers update

* Monkey patch tf.keras optimizer

* Remove custom Adam Optimizers and Memory Saving Gradients

* Remove multi-gpu option. Add Distribution to cli

* plugins.train.model._base: Add Mirror, Central and Default distribution strategies

* Update tensorboard kwargs for tf2

* Penalized Loss - Fix for TF2 and AMD

* Fix syntax for tf2.1

* requirements typo fix

* Explicit None for clipnorm if using a distribution strategy

* Fix penalized loss for distribution strategies

* Update Dlight

* typo fix

* Pin to TF2.2

* setup.py - Install tensorflow from pip if not available in Conda

* Add reduction options and set default for mirrored distribution strategy

* Explicitly use default strategy rather than nullcontext

* lib.model.backup_restore documentation

* Remove mirrored strategy reduction method and default based on OS

* Initial restructure - training

* Remove PingPong
Start model.base refactor

* Model saving and resuming enabled

* More tidying up of model.base

* Enable backup and snapshotting

* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly

* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers

* Apply custom Conv2D layer

* Finalize NNBlock restructure
Update Dfaker blocks

* Fix reloading model under a different distribution strategy

* Pass command line arguments through to trainer

* Remove training_opts from model and reference params directly

* Tidy up model __init__

* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning

* Fix timelapse

* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original

* dfl-h128 ported

* DFL SAE ported

* IAE Ported

* dlight ported

* port lightweight

* realface ported

* unbalanced ported

* villain ported

* lib.cli.args - Update Batchsize + move allow_growth to config

* Remove output shape definition
Get image sizes per side rather than globally

* Strip mask input from encoder

* Fix learn mask and output learned mask to preview

* Trigger Allow Growth prior to setting strategy

* Fix GUI Graphing

* GUI - Display batchsize correctly + fix training graphs

* Fix penalized loss

* Enable mixed precision training

* Update analysis displayed batch to match input

* Penalized Loss - Multi-GPU Fix

* Fix all losses for TF2

* Fix Reflect Padding

* Allow different input size for each side of the model

* Fix conv-aware initialization on reload

* Switch allow_growth order

* Move mixed_precision to cli

* Remove distribution strategies

* Compile penalized loss sub-function into LossContainer

* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels

* Add ability to refresh preview on demand on pop-up window

* Enable refresh of training preview from GUI

* Fix Convert
Debug logging in Initializers

* Fix Preview Tool

* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs

* lib.gui.popup_configure - Make more responsive + document

* Multiple Outputs supported in trainer
Original Model - Mask output bugfix

* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)

* Fix inference model to work properly with all models

* Fix multi-scale output for convert

* Fix clipnorm issue with distribution strategies
Edit error message on OOM

* Update plaidml losses

* Add missing file

* Disable gmsd loss for plaidml

* PlaidML - Basic training working

* clipnorm rewriting for mixed-precision

* Inference model creation bugfixes

* Remove debug code

* Bugfix: Default clipnorm to 1.0

* Remove all mask inputs from training code

* Remove mask inputs from convert

* GUI - Analysis Tab - Docstrings

* Fix rate in totals row

* lib.gui - Only update display pages if they have focus

* Save the model on first iteration

* plaidml - Fix SSIM loss with penalized loss

* tools.alignments - Remove manual and fix jobs

* GUI - Remove case formatting on help text

* gui MultiSelect custom widget - Set default values on init

* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exclude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded

* Cascade excluded devices to GPU Stats

* Explicit GPU selection for Train and Convert

* Reduce Tensorflow Min GPU Multiprocessor Count to 4

* remove compat.v1 code from extract

* Force TF to skip mixed precision compatibility check if GPUs have been filtered

* Add notes to config for non-working AMD losses

* Raise error if forcing extract to CPU mode

* Fix loading of legacy dfl-sae weights + dfl-sae typo fix

* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations

* docs: lib.gui.display

* clipnorm amd condition check

* documentation - gui.display_analysis

* Documentation - gui.popup_configure

* Documentation - lib.logger

* Documentation - lib.model.initializers

* Documentation - lib.model.layers

* Documentation - lib.model.losses

* Documentation - lib.model.nn_blocks

* Documentation - lib.model.normalization

* Documentation - lib.model.session

* Documentation - lib.plaidml_stats

* Documentation: lib.training_data

* Documentation: lib.utils

* Documentation: plugins.train.model._base

* GUI Stats: prevent stats from using GPU

* Documentation - Original Model

* Documentation: plugins.model.trainer._base

* linting

* unit tests: initializers + losses

* unit tests: nn_blocks

* bugfix - Exclude gpu devices in train, not include

* Enable Exclude-Gpus in Extract

* Enable exclude gpus in tools

* Disallow multiple plugin types in a single model folder

* Automatically add exclude_gpus argument in for cpu backends

* Cpu backend fixes

* Relax optimizer test threshold

* Default Train settings - Set mask to Extended

* Update Extractor cli help text
Update to Python 3.8

* Fix FAN to run on CPU

* lib.plaidml_tools - typofix

* Linux installer - check for curl

* linux installer - typo fix
2020-08-12 10:36:41 +01:00


#!/usr/bin/env python3
""" Custom Initializers for faceswap.py """
import logging
import sys
import inspect

import numpy as np
import tensorflow as tf

from keras import backend as K
from keras import initializers
from keras.utils import get_custom_objects

from lib.utils import get_backend

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

def compute_fans(shape, data_format='channels_last'):
    """Computes the number of input and output units for a weight shape.

    Ported directly from Keras as the location moves between keras and tensorflow-keras.

    Parameters
    ----------
    shape: tuple
        shape tuple of integers
    data_format: str
        Image data format to use for convolution kernels. Note that all kernels in Keras are
        standardized on the `"channels_last"` ordering (even when inputs are set to
        `"channels_first"`).

    Returns
    -------
    tuple
        A tuple of scalars, `(fan_in, fan_out)`.

    Raises
    ------
    ValueError
        In case of invalid `data_format` argument.
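
    Example
    -------
    For illustration, a 3x3 `"channels_last"` convolution kernel with 64 input channels and
    128 output channels has a receptive field size of 9, giving:

    >>> compute_fans((3, 3, 64, 128))
    (576, 1152)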
"""
if len(shape) == 2:
fan_in = shape[0]
fan_out = shape[1]
elif len(shape) in {3, 4, 5}:
# Assuming convolution kernels (1D, 2D or 3D).
# Theano kernel shape: (depth, input_depth, ...)
# Tensorflow kernel shape: (..., input_depth, depth)
if data_format == 'channels_first':
receptive_field_size = np.prod(shape[2:])
fan_in = shape[1] * receptive_field_size
fan_out = shape[0] * receptive_field_size
elif data_format == 'channels_last':
receptive_field_size = np.prod(shape[:-2])
fan_in = shape[-2] * receptive_field_size
fan_out = shape[-1] * receptive_field_size
else:
raise ValueError('Invalid data_format: ' + data_format)
else:
# No specific assumptions.
fan_in = np.sqrt(np.prod(shape))
fan_out = np.sqrt(np.prod(shape))
return fan_in, fan_out

class ICNR(initializers.Initializer):  # pylint: disable=invalid-name
    """ ICNR initializer for checkerboard artifact free sub pixel convolution.

    Parameters
    ----------
    initializer: :class:`keras.initializers.Initializer`
        The initializer used for sub kernels (orthogonal, glorot uniform, etc.)
    scale: int, optional
        scaling factor of sub pixel convolution (up sampling from 8x8 to 16x16 is scale 2).
        Default: `2`

    Returns
    -------
    tensor
        The modified kernel weights

    Example
    -------
    >>> x = conv2d(... weights_initializer=ICNR(initializer=he_uniform(), scale=2))

    References
    ----------
    Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/pdf/1707.02937.pdf, https://distill.pub/2016/deconv-checkerboard/
    """
    def __init__(self, initializer, scale=2):
        self.scale = scale
        self.initializer = initializer

    def __call__(self, shape, dtype="float32"):
        """ Call function for the ICNR initializer.

        Parameters
        ----------
        shape: tuple or list
            The required resized shape for the output tensor
        dtype: str
            The data type for the tensor

        Returns
        -------
        tensor
            The modified kernel weights
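
        Example
        -------
        A worked shape example, for illustration only: with ``scale=2`` and a requested
        kernel shape of ``(3, 3, 64, 256)``, a sub kernel of shape ``(3, 3, 64, 64)`` is
        initialized, resized by a factor of 2 with nearest-neighbour interpolation and
        rearranged with space-to-depth back to ``(3, 3, 64, 256)``, so the sub-pixel
        convolution initially behaves like a nearest-neighbour upscale (the
        checkerboard-free property).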
"""
shape = list(shape)
if self.scale == 1:
return self.initializer(shape)
new_shape = shape[:3] + [shape[3] // (self.scale ** 2)]
if isinstance(self.initializer, dict):
self.initializer = initializers.deserialize(self.initializer)
var_x = self.initializer(new_shape, dtype)
var_x = K.permute_dimensions(var_x, [2, 0, 1, 3])
var_x = K.resize_images(var_x,
self.scale,
self.scale,
"channels_last",
interpolation="nearest")
var_x = self._space_to_depth(var_x)
var_x = K.permute_dimensions(var_x, [1, 2, 0, 3])
logger.debug("Output shape: %s", var_x.shape)
return var_x
    def _space_to_depth(self, input_tensor):
        """ Space to depth implementation.

        PlaidML does not have a space to depth operation, so this is calculated manually when
        the backend is "amd", otherwise the :func:`tensorflow.space_to_depth` operation is
        returned.

        Parameters
        ----------
        input_tensor: tensor
            The tensor to be manipulated

        Returns
        -------
        tensor
            The manipulated input tensor
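
        Example
        -------
        For illustration, with ``scale=2`` an ``NHWC`` tensor of shape ``(1, 8, 8, 4)`` is
        rearranged to ``(1, 4, 4, 16)``: each 2x2 spatial block is folded into the channel
        dimension.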
"""
if get_backend() == "amd":
batch, height, width, depth = input_tensor.shape.dims
new_height = height // self.scale
new_width = width // self.scale
reshaped = K.reshape(input_tensor,
(batch, new_height, self.scale, new_width, self.scale, depth))
retval = K.reshape(K.permute_dimensions(reshaped, [0, 1, 3, 2, 4, 5]),
(batch, new_height, new_width, -1))
else:
retval = tf.nn.space_to_depth(input_tensor, block_size=self.scale, data_format="NHWC")
logger.debug("Input shape: %s, Output shape: %s", input_tensor.shape, retval.shape)
return retval
    def get_config(self):
        """ Return the ICNR Initializer configuration.

        Returns
        -------
        dict
            The configuration for ICNR Initialization
        """
        config = {"scale": self.scale,
                  "initializer": self.initializer}
        base_config = super(ICNR, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class ConvolutionAware(initializers.Initializer):
    """ Initializer that generates orthogonal convolution filters in the Fourier space. If this
    initializer is passed a shape that is not 3D or 4D, orthogonal initialization will be used.

    Adapted, fixed and optimized from:
    https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/initializers/convaware.py

    Parameters
    ----------
    eps_std: float, optional
        The Standard deviation for the random normal noise used to break symmetry in the inverse
        Fourier transform. Default: 0.05
    seed: int, optional
        Used to seed the random generator. Default: ``None``
    initialized: bool, optional
        This should always be set to ``False``. To avoid Keras re-calculating the values every
        time the model is loaded, this parameter is internally set on first time initialization.
        Default: ``False``

    Returns
    -------
    tensor
        The modified kernel weights

    References
    ----------
    Armen Aghajanyan, https://arxiv.org/abs/1702.06295
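
    Example
    -------
    Illustrative usage, mirroring the :class:`ICNR` example above (``conv2d`` stands in for
    any convolutional layer that accepts a kernel initializer):

    >>> x = conv2d(... weights_initializer=ConvolutionAware())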
"""
def __init__(self, eps_std=0.05, seed=None, initialized=False):
self.eps_std = eps_std
self.seed = seed
self.orthogonal = initializers.Orthogonal()
self.he_uniform = initializers.he_uniform()
self.initialized = initialized
    def __call__(self, shape, dtype=None):
        """ Call function for the Convolution Aware initializer.

        Parameters
        ----------
        shape: tuple or list
            The required shape for the output tensor
        dtype: str
            The data type for the tensor

        Returns
        -------
        tensor
            The modified kernel weights
        """
        # TODO Tensorflow appears to pass in a :class:`tensorflow.python.framework.dtypes.DType`
        # object which causes this to error, so currently just reverts to the default dtype if a
        # string is not passed in.
        if self.initialized:  # Avoid re-calculating initializer when loading a saved model
            return self.he_uniform(shape, dtype=dtype)
        dtype = K.floatx() if not isinstance(dtype, str) else dtype
        logger.info("Calculating Convolution Aware Initializer for shape: %s", shape)
        rank = len(shape)
        if self.seed is not None:
            np.random.seed(self.seed)
        fan_in, _ = compute_fans(shape)
        variance = 2 / fan_in

        if rank == 3:
            row, stack_size, filters_size = shape
            transpose_dimensions = (2, 1, 0)
            kernel_shape = (row,)
            correct_ifft = lambda shape, s=[None]: np.fft.irfft(shape, s[0])  # noqa
            correct_fft = np.fft.rfft
        elif rank == 4:
            row, column, stack_size, filters_size = shape
            transpose_dimensions = (2, 3, 1, 0)
            kernel_shape = (row, column)
            correct_ifft = np.fft.irfft2
            correct_fft = np.fft.rfft2
        elif rank == 5:
            var_x, var_y, var_z, stack_size, filters_size = shape
            transpose_dimensions = (3, 4, 0, 1, 2)
            kernel_shape = (var_x, var_y, var_z)
            correct_fft = np.fft.rfftn
            correct_ifft = np.fft.irfftn
        else:
            self.initialized = True
            return K.variable(self.orthogonal(shape), dtype=dtype)

        kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape

        basis = self._create_basis(filters_size, stack_size, np.prod(kernel_fourier_shape), dtype)
        basis = basis.reshape((filters_size, stack_size,) + kernel_fourier_shape)

        randoms = np.random.normal(0, self.eps_std, basis.shape[:-2] + kernel_shape)
        init = correct_ifft(basis, kernel_shape) + randoms
        init = self._scale_filters(init, variance)
        self.initialized = True
        return K.variable(init.transpose(transpose_dimensions), dtype=dtype, name="conv_aware")
    def _create_basis(self, filters_size, filters, size, dtype):
        """ Create the basis for convolutional aware initialization """
        logger.debug("filters_size: %s, filters: %s, size: %s, dtype: %s",
                     filters_size, filters, size, dtype)
        if size == 1:
            return np.random.normal(0.0, self.eps_std, (filters_size, filters, size))
        nbb = filters // size + 1
        var_a = np.random.normal(0.0, 1.0, (filters_size, nbb, size, size))
        var_a = self._symmetrize(var_a)
        var_u = np.linalg.svd(var_a)[0].transpose(0, 1, 3, 2)
        var_p = np.reshape(var_u, (filters_size, nbb * size, size))[:, :filters, :].astype(dtype)
        return var_p
    @staticmethod
    def _symmetrize(var_a):
        """ Make the given tensor symmetrical. """
        var_b = np.transpose(var_a, axes=(0, 1, 3, 2))
        diag = var_a.diagonal(axis1=2, axis2=3)
        var_c = np.array([[np.diag(arr) for arr in batch] for batch in diag])
        return var_a + var_b - var_c

    @staticmethod
    def _scale_filters(filters, variance):
        """ Scale the given filters. """
        c_var = np.var(filters)
        var_p = np.sqrt(variance / c_var)
        return filters * var_p
    def get_config(self):
        """ Return the Convolution Aware Initializer configuration.

        Returns
        -------
        dict
            The configuration for Convolution Aware Initialization
        """
        return dict(eps_std=self.eps_std,
                    seed=self.seed,
                    initialized=self.initialized)

# Update initializers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
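
# With the classes above registered as Keras custom objects, a saved model whose config
# references them by name (for example a kernel initialized with ConvolutionAware) can be
# deserialized without passing `custom_objects` explicitly. Illustrative sketch only;
# "model.h5" is a hypothetical file name:
#   >>> from keras.models import load_model
#   >>> model = load_model("model.h5")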