Mirror of https://github.com/deepfakes/faceswap, synced 2025-06-08 03:26:47 -04:00
* Clearer requirements for each platform
* Refactoring of old plugins (Model_Original + Extract_Align) + Cleanups
* Adding GAN128
* Update GAN to v2
* Create instance_normalization.py
* Fix decoder output
* Revert "Fix decoder output"
This reverts commit 3a8ecb8957.
* Fix convert
* Enable all options except perceptual_loss by default
* Disable instance norm
* Update Model.py
* Update Trainer.py
* Match GAN128 to shaoanlu's latest v2
* Add first_order to GAN128
* Disable `use_perceptual_loss`
* Fix call to `self.first_order`
* Switch to average loss in output
* Constrain average to last 100 iterations
* Fix math, constrain average to intervals of 100
* Fix math averaging again
* Remove math and simplify this damn averaging (see the rolling-average sketch after this list)
* Add gan128 conversion
* Update convert.py
* Use non-warped images in masked preview
* Add K.set_learning_phase(1) to gan64
* Add K.set_learning_phase(1) to gan128
* Add missing keras import
* Use non-warped images in masked preview for gan128
* Exclude deleted faces from conversion
* --input-aligned-dir defaults to "{input_dir}/aligned"
* Simplify map operation
* Port 'face_alignment' from PyTorch to Keras. It runs 2x faster, but initialization takes 20 seconds.
2DFAN-4.h5 and mmod_human_face_detector.dat are included in lib\FaceLandmarksExtractor.
Fixed the dlib vs. tensorflow conflict: dlib must run an op first, then the keras model is loaded; otherwise a CUDA OOM error occurs (see the init-order sketch after this list).
If the face location is not found by the CNN, it tries to find it with HOG.
Removed this:
- if face.landmarks == None:
-     print("Warning! landmarks not found. Switching to crop!")
-     return cv2.resize(face.image, (size, size))
because DetectedFace always has landmarks.
* Enabled masked converter for GAN models
* Histogram matching, cli option for perceptual loss
* Fix init() positional args error
* Add backwards compatibility for aligned filenames
* Fix masked converter
* Remove GAN converters
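
The loss-averaging commits above amount to reporting a mean of the most recent losses instead of the raw per-iteration value. A minimal sketch of that idea, assuming a 100-iteration window (the names below are illustrative, not the trainer's actual variables):

from collections import deque

# Rolling average of the last 100 loss values (illustrative sketch,
# not the trainer's actual code).
WINDOW = 100
recent_losses = deque(maxlen=WINDOW)  # old values fall off automatically

def log_loss(iteration, loss):
    recent_losses.append(loss)
    avg = sum(recent_losses) / len(recent_losses)
    print("[{:05d}] loss: {:.5f} (avg of last {}: {:.5f})".format(
        iteration, loss, len(recent_losses), avg))

Using a deque with maxlen avoids the interval bookkeeping the "Fix math" commits wrestled with: the window slides automatically and the average is always over at most the last 100 values.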
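The dlib/TensorFlow workaround described in the 'face_alignment' commit comes down to initialization order: let dlib's CNN detector run an op and claim its CUDA memory before Keras/TensorFlow loads the landmark model. A hedged sketch of that ordering, with HOG as the fallback detector (the paths and the warm-up call are assumptions, not the repo's exact code):

import dlib
import numpy as np
from keras.models import load_model

# 1) Run a dlib CNN op first so dlib claims its CUDA memory ...
cnn_detector = dlib.cnn_face_detection_model_v1(
    "lib/FaceLandmarksExtractor/mmod_human_face_detector.dat")
cnn_detector(np.zeros((64, 64, 3), dtype=np.uint8), 0)  # warm-up op

# 2) ... then load the Keras landmark model; loading it first risks CUDA OOM.
# (Assumes the .h5 is self-contained; the real loader may need custom_objects.)
landmark_model = load_model("lib/FaceLandmarksExtractor/2DFAN-4.h5")

# 3) Fall back to HOG detection when the CNN finds no face.
hog_detector = dlib.get_frontal_face_detector()

def detect_faces(image):
    faces = cnn_detector(image, 0)
    if len(faces) == 0:
        return hog_detector(image, 0)
    return [f.rect for f in faces]  # mmod results wrap a rectangle
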
145 lines · 6.2 KiB · Python
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).

    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0
            to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
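
Because the module registers the layer with Keras's custom-object registry at import time, models that use it can be re-loaded without passing custom_objects explicitly. A minimal usage sketch, assuming the file is importable as instance_normalization (the toy model and file name are illustrative):

from keras.models import Sequential, load_model
from keras.layers import Conv2D
from instance_normalization import InstanceNormalization  # import registers the layer

model = Sequential([
    Conv2D(32, 3, padding='same', input_shape=(64, 64, 3)),
    InstanceNormalization(axis=-1),  # per-channel mean/std over H and W of each sample
])
model.save('in_demo.h5')

# The get_custom_objects() registration lets load_model resolve the layer
# without an explicit custom_objects argument.
restored = load_model('in_demo.h5')

With axis=-1 the statistics are computed over the spatial axes of each sample independently (per channel), whereas axis=None reduces over all non-batch axes, as described in the docstring.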