#!/usr/bin/env python3
""" Normalization methods for faceswap.py. """
import sys
import inspect
from plaidml.op import slice_tensor
from keras.layers import Layer
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils import get_custom_objects
class LayerNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016). Implementation adapted from
tensorflow.keras implementation and https://github.com/CyberZHG/keras-layer-normalization
Normalize the activations of the previous layer for each given example in a batch
independently, rather than across a batch like Batch Normalization. i.e. applies a
transformation that maintains the mean activation within each example close to 0 and the
activation standard deviation close to 1.
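    Concretely, for the normalized axes, the transformation applied in :meth:`call` is:
    ``outputs = gamma * (inputs - mean) / sqrt(variance + epsilon) + beta``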
Parameters
----------
axis: int or list/tuple
The axis or axes to normalize across. Typically this is the features axis/axes.
The left-out axes are typically the batch axis/axes. This argument defaults to `-1`, the
last dimension in the input.
epsilon: float, optional
Small float added to variance to avoid dividing by zero. Default: `1e-3`
center: bool, optional
If ``True``, add offset of `beta` to normalized tensor. If ``False``, `beta` is ignored.
Default: ``True``
scale: bool, optional
If ``True``, multiply by `gamma`. If ``False``, `gamma` is not used. When the next layer
is linear (also e.g. `relu`), this can be disabled since the scaling will be done by
the next layer. Default: ``True``
beta_initializer: str, optional
Initializer for the beta weight. Default: `"zeros"`
gamma_initializer: str, optional
Initializer for the gamma weight. Default: `"ones"`
beta_regularizer: str, optional
Optional regularizer for the beta weight. Default: ``None``
gamma_regularizer: str, optional
Optional regularizer for the gamma weight. Default: ``None``
    beta_constraint: str, optional
        Optional constraint for the beta weight. Default: ``None``
    gamma_constraint: str, optional
        Optional constraint for the gamma weight. Default: ``None``
kwargs: dict
Standard keras layer kwargs
References
----------
- Layer Normalization - https://arxiv.org/abs/1607.06450
- Keras implementation - https://github.com/CyberZHG/keras-layer-normalization
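
    Example
    -------
    A minimal usage sketch (the shapes used here are purely illustrative):

    >>> import numpy as np
    >>> from keras.layers import Input
    >>> from keras.models import Model
    >>> inputs = Input(shape=(8, 8, 16))
    >>> outputs = LayerNormalization(axis=-1)(inputs)
    >>> model = Model(inputs, outputs)
    >>> model.predict(np.random.rand(4, 8, 8, 16)).shape
    (4, 8, 8, 16)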
"""
def __init__(self,
axis=-1,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
self.gamma = None
self.beta = None
super().__init__(**kwargs)
if isinstance(axis, (list, tuple)):
self.axis = axis[:]
elif isinstance(axis, int):
self.axis = axis
else:
raise TypeError("Expected an int or a list/tuple of ints for the argument 'axis', "
f"but received: {axis}")
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
self.supports_masking = True
def build(self, input_shape):
"""Creates the layer weights.
Parameters
----------
input_shape: tensor
Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to
reference for weight shape computations.
"""
        if input_shape is None:
            raise ValueError(f"Input shape {input_shape} has undefined rank.")
        ndims = len(input_shape)
# Convert axis to list and resolve negatives
if isinstance(self.axis, int):
self.axis = [self.axis]
elif isinstance(self.axis, tuple):
self.axis = list(self.axis)
for idx, axs in enumerate(self.axis):
if axs < 0:
self.axis[idx] = ndims + axs
# Validate axes
for axs in self.axis:
if axs < 0 or axs >= ndims:
raise ValueError(f"Invalid axis: {axs}")
if len(self.axis) != len(set(self.axis)):
raise ValueError("Duplicate axis: {}".format(tuple(self.axis)))
param_shape = [input_shape[dim] for dim in self.axis]
if self.scale:
self.gamma = self.add_weight(
name="gamma",
shape=param_shape,
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
if self.center:
self.beta = self.add_weight(
                name="beta",
shape=param_shape,
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
self.built = True # pylint:disable=attribute-defined-outside-init
def call(self, inputs, **kwargs): # pylint:disable=unused-argument
"""This is where the layer's logic lives.
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
Returns
-------
tensor
A tensor or list/tuple of tensors
"""
# Compute the axes along which to reduce the mean / variance
input_shape = K.int_shape(inputs)
ndims = len(input_shape)
# Broadcasting only necessary for norm when the axis is not just the last dimension
broadcast_shape = [1] * ndims
for dim in self.axis:
broadcast_shape[dim] = input_shape[dim]
def _broadcast(var):
if (var is not None and len(var.shape) != ndims and self.axis != [ndims - 1]):
return K.reshape(var, broadcast_shape)
return var
# Calculate the moments on the last axis (layer activations).
mean = K.mean(inputs, self.axis, keepdims=True)
variance = K.mean(K.square(inputs - mean), axis=self.axis, keepdims=True)
std = K.sqrt(variance + self.epsilon)
outputs = (inputs - mean) / std
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
if self.scale:
outputs *= scale
        if self.center:
            outputs += offset  # beta is an additive offset, not a multiplier
return outputs
def compute_output_shape(self, input_shape): # pylint:disable=no-self-use
""" The output shape of the layer is the same as the input shape.
Parameters
----------
input_shape: tuple
The input shape to the layer
Returns
-------
tuple
The output shape to the layer
"""
return input_shape
def get_config(self):
"""Returns the config of the layer.
A layer config is a Python dictionary (serializable) containing the configuration of a
layer. The same layer can be reinstated later (without its trained weights) from this
configuration.
The configuration of a layer does not include connectivity information, nor the layer
class name. These are handled by `Network` (one layer of abstraction above).
Returns
--------
dict
A python dictionary containing the layer configuration
"""
base_config = super().get_config()
config = dict(axis=self.axis,
epsilon=self.epsilon,
center=self.center,
scale=self.scale,
beta_initializer=initializers.serialize(self.beta_initializer),
gamma_initializer=initializers.serialize(self.gamma_initializer),
beta_regularizer=regularizers.serialize(self.beta_regularizer),
gamma_regularizer=regularizers.serialize(self.gamma_regularizer),
beta_constraint=constraints.serialize(self.beta_constraint),
gamma_constraint=constraints.serialize(self.gamma_constraint))
return dict(list(base_config.items()) + list(config.items()))
class RMSNormalization(Layer):
""" Root Mean Square Layer Normalization (Biao Zhang, Rico Sennrich, 2019)
    RMSNorm is a simplification of the original layer normalization (LayerNorm). LayerNorm is a
    regularization technique that aims to handle the internal covariate shift issue, stabilizing
    the layer activations and improving model convergence. It has proven quite successful in
    NLP-based models, and in some cases has become an essential component of model optimization,
    such as in the SOTA NMT model Transformer.
    RMSNorm simplifies LayerNorm by removing the mean-centering operation, normalizing layer
    activations with the RMS statistic instead.
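    Concretely, the transformation applied in :meth:`call` is:
    ``outputs = scale * inputs / sqrt(mean(square(inputs)) + epsilon) + offset``
    where the mean is taken over :attr:`axis` (or, for pRMSNorm, over the leading `partial`
    fraction of that axis).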
Parameters
----------
axis: int
The axis to normalize across. Typically this is the features axis. The left-out axes are
typically the batch axis/axes. This argument defaults to `-1`, the last dimension in the
input.
epsilon: float, optional
Small float added to variance to avoid dividing by zero. Default: `1e-8`
partial: float, optional
Partial multiplier for calculating pRMSNorm. Valid values are between `0.0` and `1.0`.
Setting to `0.0` or `1.0` disables. Default: `0.0`
bias: bool, optional
Whether to use a bias term for RMSNorm. Disabled by default because RMSNorm does not
enforce re-centering invariance. Default ``False``
kwargs: dict
Standard keras layer kwargs
References
----------
- RMS Normalization - https://arxiv.org/abs/1910.07467
- Official implementation - https://github.com/bzhangGo/rmsnorm
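
    Example
    -------
    A minimal usage sketch (the shapes used here are purely illustrative):

    >>> import numpy as np
    >>> from keras.layers import Input
    >>> from keras.models import Model
    >>> inputs = Input(shape=(32, 64))
    >>> outputs = RMSNormalization(axis=-1)(inputs)
    >>> model = Model(inputs, outputs)
    >>> model.predict(np.random.rand(2, 32, 64)).shape
    (2, 32, 64)
    >>> partial_layer = RMSNormalization(partial=0.5)  # pRMSNorm over the first 50% of features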
"""
def __init__(self, axis=-1, epsilon=1e-8, partial=0.0, bias=False, **kwargs):
self.scale = None
        self.offset = 0.
super().__init__(**kwargs)
# Checks
if not isinstance(axis, int):
raise TypeError(f"Expected an int for the argument 'axis', but received: {axis}")
if not 0.0 <= partial <= 1.0:
raise ValueError(f"partial must be between 0.0 and 1.0, but received {partial}")
self.axis = axis
self.epsilon = epsilon
self.partial = partial
self.bias = bias
def build(self, input_shape):
""" Validate and populate :attr:`axis`
Parameters
----------
input_shape: tensor
Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to
reference for weight shape computations.
"""
        if input_shape is None:
            raise ValueError(f"Input shape {input_shape} has undefined rank.")
        ndims = len(input_shape)
# Resolve negative axis
if self.axis < 0:
self.axis += ndims
# Validate axes
if self.axis < 0 or self.axis >= ndims:
raise ValueError(f"Invalid axis: {self.axis}")
param_shape = [input_shape[self.axis]]
self.scale = self.add_weight(
name="scale",
shape=param_shape,
initializer="ones")
if self.bias:
self.offset = self.add_weight(
name="offset",
shape=param_shape,
initializer="zeros")
self.built = True # pylint:disable=attribute-defined-outside-init
def call(self, inputs, **kwargs): # pylint:disable=unused-argument
""" Call Root Mean Square Layer Normalization
Parameters
----------
inputs: tensor
Input tensor, or list/tuple of input tensors
Returns
-------
tensor
A tensor or list/tuple of tensors
"""
# Compute the axes along which to reduce the mean / variance
input_shape = K.int_shape(inputs)
layer_size = input_shape[self.axis]
if self.partial in (0.0, 1.0):
mean_square = K.mean(K.square(inputs), axis=self.axis, keepdims=True)
else:
partial_size = int(layer_size * self.partial)
partial_x = slice_tensor(inputs,
axes=[self.axis],
starts=[0],
ends=[partial_size])
mean_square = K.mean(K.square(partial_x), axis=self.axis, keepdims=True)
recip_square_root = 1. / K.sqrt(mean_square + self.epsilon)
output = self.scale * inputs * recip_square_root + self.offset
return output
def compute_output_shape(self, input_shape): # pylint:disable=no-self-use
""" The output shape of the layer is the same as the input shape.
Parameters
----------
input_shape: tuple
The input shape to the layer
Returns
-------
tuple
The output shape to the layer
"""
return input_shape
def get_config(self):
"""Returns the config of the layer.
A layer config is a Python dictionary (serializable) containing the configuration of a
layer. The same layer can be reinstated later (without its trained weights) from this
configuration.
The configuration of a layer does not include connectivity information, nor the layer
class name. These are handled by `Network` (one layer of abstraction above).
Returns
--------
dict
A python dictionary containing the layer configuration
"""
base_config = super().get_config()
config = dict(axis=self.axis,
epsilon=self.epsilon,
partial=self.partial,
bias=self.bias)
return dict(list(base_config.items()) + list(config.items()))
# Update normalization into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
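
# With the layers registered above, a saved model that uses them can be reloaded without
# passing `custom_objects` explicitly. A minimal sketch (the file name "model.h5" is
# purely illustrative):
#     from keras.models import load_model
#     model = load_model("model.h5")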