mirror of https://github.com/deepfakes/faceswap
faceswap/lib/model/normalization.py
torzdf d8557c1970
Faceswap 2.0 (#1045)
* Core Updates
    - Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
    - Document lib.gpu_stats and lib.sys_info
    - Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
    - lib.gui.menu - typofix

* Update Dependencies
Bump Tensorflow Version Check

* Port extraction to tf2

* Add custom import finder for loading Keras or tf.keras depending on backend (see the sketch after this changelog)

* Add `tensorflow` to KerasFinder search path

* Basic TF2 training running

* model.initializers - docstring fix

* Fix and pass tests for tf2

* Replace Keras backend tests with faceswap backend tests

* Initial optimizers update

* Monkey patch tf.keras optimizer

* Remove custom Adam Optimizers and Memory Saving Gradients

* Remove multi-gpu option. Add Distribution to cli

* plugins.train.model._base: Add Mirror, Central and Default distribution strategies

* Update tensorboard kwargs for tf2

* Penalized Loss - Fix for TF2 and AMD

* Fix syntax for tf2.1

* requirements typo fix

* Explicit None for clipnorm if using a distribution strategy

* Fix penalized loss for distribution strategies

* Update Dlight

* typo fix

* Pin to TF2.2

* setup.py - Install tensorflow from pip if not available in Conda

* Add reduction options and set default for mirrored distribution strategy

* Explicitly use default strategy rather than nullcontext

* lib.model.backup_restore documentation

* Remove mirrored strategy reduction method and default based on OS

* Initial restructure - training

* Remove PingPong
Start model.base refactor

* Model saving and resuming enabled

* More tidying up of model.base

* Enable backup and snapshotting

* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly

* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers

* Apply custom Conv2D layer

* Finalize NNBlock restructure
Update Dfaker blocks

* Fix reloading model under a different distribution strategy

* Pass command line arguments through to trainer

* Remove training_opts from model and reference params directly

* Tidy up model __init__

* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning

* Fix timelapse

* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original

* dfl-h128 ported

* DFL SAE ported

* IAE Ported

* dlight ported

* port lightweight

* realface ported

* unbalanced ported

* villain ported

* lib.cli.args - Update Batchsize + move allow_growth to config

* Remove output shape definition
Get image sizes per side rather than globally

* Strip mask input from encoder

* Fix learn mask and output learned mask to preview

* Trigger Allow Growth prior to setting strategy

* Fix GUI Graphing

* GUI - Display batchsize correctly + fix training graphs

* Fix penalized loss

* Enable mixed precision training

* Update analysis displayed batch to match input

* Penalized Loss - Multi-GPU Fix

* Fix all losses for TF2

* Fix Reflect Padding

* Allow different input size for each side of the model

* Fix conv-aware initialization on reload

* Switch allow_growth order

* Move mixed_precision to cli

* Remove distribution strategies

* Compile penalized loss sub-function into LossContainer

* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels

* Add ability to refresh preview on demand on pop-up window

* Enable refresh of training preview from GUI

* Fix Convert
Debug logging in Initializers

* Fix Preview Tool

* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs

* lib.gui.popup_configure - Make more responsive + document

* Multiple Outputs supported in trainer
Original Model - Mask output bugfix

* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)

* Fix inference model to work properly with all models

* Fix multi-scale output for convert

* Fix clipnorm issue with distribution strategies
Edit error message on OOM

* Update plaidml losses

* Add missing file

* Disable gmsd loss for plaidml

* PlaidML - Basic training working

* clipnorm rewriting for mixed-precision

* Inference model creation bugfixes

* Remove debug code

* Bugfix: Default clipnorm to 1.0

* Remove all mask inputs from training code

* Remove mask inputs from convert

* GUI - Analysis Tab - Docstrings

* Fix rate in totals row

* lib.gui - Only update display pages if they have focus

* Save the model on first iteration

* plaidml - Fix SSIM loss with penalized loss

* tools.alignments - Remove manual and fix jobs

* GUI - Remove case formatting on help text

* gui MultiSelect custom widget - Set default values on init

* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exclude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded

* Cascade excluded devices to GPU Stats

* Explicit GPU selection for Train and Convert

* Reduce Tensorflow Min GPU Multiprocessor Count to 4

* remove compat.v1 code from extract

* Force TF to skip mixed precision compatibility check if GPUs have been filtered

* Add notes to config for non-working AMD losses

* Raise error if forcing extract to CPU mode

* Fix loading of legacy dfl-sae weights + dfl-sae typo fix

* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations

* docs: lib.gui.display

* clipnorm amd condition check

* documentation - gui.display_analysis

* Documentation - gui.popup_configure

* Documentation - lib.logger

* Documentation - lib.model.initializers

* Documentation - lib.model.layers

* Documentation - lib.model.losses

* Documentation - lib.model.nn_blocks

* Documentation - lib.model.normalization

* Documentation - lib.model.session

* Documentation - lib.plaidml_stats

* Documentation: lib.training_data

* Documentation: lib.utils

* Documentation: plugins.train.model._base

* GUI Stats: prevent stats from using GPU

* Documentation - Original Model

* Documentation: plugins.model.trainer._base

* linting

* unit tests: initializers + losses

* unit tests: nn_blocks

* bugfix - Exclude gpu devices in train, not include

* Enable Exclude-Gpus in Extract

* Enable exclude gpus in tools

* Disallow multiple plugin types in a single model folder

* Automatically add exclude_gpus argument in for cpu backends

* Cpu backend fixes

* Relax optimizer test threshold

* Default Train settings - Set mask to Extended

* Update Extractor cli help text
Update to Python 3.8

* Fix FAN to run on CPU

* lib.plaidml_tools - typofix

* Linux installer - check for curl

* linux installer - typo fix
2020-08-12 10:36:41 +01:00
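
The custom import finder mentioned in the changelog (`KerasFinder`) is a `sys.meta_path` hook that serves `keras` imports from either standalone Keras or `tensorflow.keras`, depending on the configured backend. Below is a minimal sketch of that technique, assuming a TensorFlow backend; the class and loader names here are illustrative, not faceswap's actual implementation:

import importlib
import sys
from importlib.abc import Loader, MetaPathFinder
from importlib.machinery import ModuleSpec


class _AliasLoader(Loader):
    """ Hand an already-importable module to the import machinery under an alias. """

    def __init__(self, real_name):
        self._real_name = real_name

    def create_module(self, spec):
        # Import the real module; the machinery registers it under the alias
        return importlib.import_module(self._real_name)

    def exec_module(self, module):
        # The real module is already fully initialized, so nothing to execute
        pass


class KerasFinderSketch(MetaPathFinder):
    """ Serve `import keras` (and any `keras.*` submodule) from `tensorflow.keras`. """

    def find_spec(self, fullname, path=None, target=None):
        if fullname == "keras" or fullname.startswith("keras."):
            return ModuleSpec(fullname, _AliasLoader("tensorflow." + fullname))
        return None  # Defer to the normal import machinery


sys.meta_path.insert(0, KerasFinderSketch())

With a finder like this installed, module-level imports such as those at the top of normalization.py below resolve against the selected backend without any per-file changes.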

#!/usr/bin/env python3
""" Normalization methods for faceswap.py. """
import sys
import inspect

from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils import get_custom_objects


class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation standard deviation close to 1.
Parameters
----------
axis: int, optional
The axis that should be normalized (typically the features axis). For instance, after a
`Conv2D` layer with `data_format="channels_first"`, set `axis=1` in
:class:`InstanceNormalization`. Setting `axis=None` will normalize all values in each
instance of the batch. Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid
errors. Default: ``None``
epsilon: float, optional
Small float added to variance to avoid dividing by zero. Default: `1e-3`
center: bool, optional
If ``True``, add offset of `beta` to normalized tensor. If ``False``, `beta` is ignored.
Default: ``True``
scale: bool, optional
If ``True``, multiply by `gamma`. If ``False``, `gamma` is not used. When the next layer
is linear (also e.g. `relu`), this can be disabled since the scaling will be done by
the next layer. Default: ``True``
beta_initializer: str, optional
Initializer for the beta weight. Default: `"zeros"`
gamma_initializer: str, optional
Initializer for the gamma weight. Default: `"ones"`
beta_regularizer: str, optional
Optional regularizer for the beta weight. Default: ``None``
gamma_regularizer: str, optional
Optional regularizer for the gamma weight. Default: ``None``
beta_constraint: float, optional
Optional constraint for the beta weight. Default: ``None``
gamma_constraint: float, optional
Optional constraint for the gamma weight. Default: ``None``
References
----------
- Layer Normalization - https://arxiv.org/abs/1607.06450
- Instance Normalization: The Missing Ingredient for Fast Stylization - \
https://arxiv.org/abs/1607.08022
"""
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer="zeros",
                 gamma_initializer="ones",
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        self.beta = None
        self.gamma = None
        super().__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
    def build(self, input_shape):
        """Creates the layer weights.

        Must be implemented on all layers that have weights.

        Parameters
        ----------
        input_shape: tensor
            Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to
            reference for weight shape computations.
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError("Axis cannot be zero")
        if (self.axis is not None) and (ndim == 2):
            raise ValueError("Cannot specify axis for rank 1 tensor")
        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name="gamma",
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name="beta",
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True
    def call(self, inputs, training=None):  # pylint:disable=arguments-differ,unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: tensor
            Input tensor, or list/tuple of input tensors

        Returns
        -------
        tensor
            A tensor or list/tuple of tensors
        """
        input_shape = K.int_shape(inputs)
        # Reduce over every axis except the batch axis and, if given, the feature axis,
        # so that each instance in the batch is normalized independently
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        # Standardize: (x - mean) / (stddev + epsilon)
        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        # Reshape gamma/beta so they broadcast along the normalized axes
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed
    def get_config(self):
        """Returns the config of the layer.

        A layer config is a Python dictionary (serializable) containing the configuration of a
        layer. The same layer can be reinstantiated later (without its trained weights) from this
        configuration.

        The configuration of a layer does not include connectivity information, nor the layer
        class name. These are handled by `Network` (one layer of abstraction above).

        Returns
        -------
        dict
            A python dictionary containing the layer configuration
        """
        config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(self.gamma_initializer),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint)}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

# Update normalization methods into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
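
Because the loop above registers every class in this module via `get_custom_objects()`, models saved with these layers can later be reloaded without passing `custom_objects` by hand. A minimal usage sketch, with illustrative model and file names:

from keras.layers import Conv2D, Input
from keras.models import Model, load_model

# A toy model using the layer defined above: normalize each instance
# per-channel (axis=-1 is the features axis for "channels_last" data)
inputs = Input(shape=(64, 64, 3))
net = Conv2D(8, 3, padding="same")(inputs)
net = InstanceNormalization(axis=-1)(net)
model = Model(inputs, net)

model.save("example.h5")  # illustrative path
# The import-time registration above lets load_model resolve
# InstanceNormalization without an explicit custom_objects argument
reloaded = load_model("example.h5")

Registering at import time, rather than at each call site, is what lets the rest of faceswap reload saved models containing these layers without threading a `custom_objects` mapping through every loader.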