#!/usr/bin/env python3
""" Normalization methods for faceswap.py
    Code from:
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """

import sys
import inspect

from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects


def to_list(inp):
    """ Wrap a scalar in a single-item list; convert lists and tuples to lists """
    if not isinstance(inp, (list, tuple)):
        return [inp]
    return list(inp)
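# Illustrative usage of to_list (an added example, not part of the original
# module): it lets a layer accept either a single axis or a sequence of axes
# and handle both uniformly, e.g.
#     to_list(3)       # -> [3]
#     to_list((1, 2))  # -> [1, 2]
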
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al., 2016; Ulyanov et al., 2016).

    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension, so `axis` cannot be set to 0;
            doing so raises an error.
        epsilon: Small float added to the standard deviation to avoid
            dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (or, e.g., `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast
          Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        self.beta = None
        self.gamma = None
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        # NB: epsilon is added to the standard deviation, not the variance,
        # so very small activations remain safe to divide by
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
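# Usage sketch (an added example, not part of the original module): instance
# normalization is typically applied straight after a convolution, with `axis`
# pointing at the channels dimension, e.g. for channels_last data:
#     from keras.layers import Conv2D, Input
#     from keras.models import Model
#     inp = Input(shape=(64, 64, 3))
#     var_x = Conv2D(32, 3, padding='same')(inp)
#     var_x = InstanceNormalization(axis=-1)(var_x)
#     model = Model(inp, var_x)
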
class GroupNormalization(Layer):
    """ Group Normalization
        Code from:
            shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """
    def __init__(self,
                 axis=-1,
                 gamma_init='one',
                 beta_init='zero',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 epsilon=1e-6,
                 group=32,
                 data_format=None,
                 **kwargs):
        self.beta = None
        self.gamma = None
        super(GroupNormalization, self).__init__(**kwargs)
        self.axis = to_list(axis)
        self.gamma_init = initializers.get(gamma_init)
        self.beta_init = initializers.get(beta_init)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.epsilon = epsilon
        self.group = group
        self.data_format = K.normalize_data_format(data_format)

        self.supports_masking = True

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = [1 for _ in input_shape]
        if self.data_format == 'channels_last':
            channel_axis = -1
            shape[channel_axis] = input_shape[channel_axis]
        elif self.data_format == 'channels_first':
            channel_axis = 1
            shape[channel_axis] = input_shape[channel_axis]
        # for i in self.axis:
        #     shape[i] = input_shape[i]
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='gamma')
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='beta')
        self.built = True

    def call(self, inputs, mask=None):
        input_shape = K.int_shape(inputs)
        if len(input_shape) not in (2, 4):
            raise ValueError('Inputs should have rank 4 or 2; '
                             'Received input shape: ' + str(input_shape))

        if len(input_shape) == 4:
            if self.data_format == 'channels_last':
                batch_size, height, width, channels = input_shape
                if batch_size is None:
                    batch_size = -1

                if channels < self.group:
                    raise ValueError('Input channels should be larger than '
                                     'group size; Received input channels: ' +
                                     str(channels) + '; Group size: ' +
                                     str(self.group))

                # Split the channels into `group` groups and normalize over
                # the spatial axes plus the channels within each group
                var_x = K.reshape(inputs, (batch_size,
                                           height,
                                           width,
                                           self.group,
                                           channels // self.group))
                mean = K.mean(var_x, axis=[1, 2, 4], keepdims=True)
                std = K.sqrt(K.var(var_x, axis=[1, 2, 4], keepdims=True) +
                             self.epsilon)
                var_x = (var_x - mean) / std

                var_x = K.reshape(var_x, (batch_size, height, width, channels))
                retval = self.gamma * var_x + self.beta
            elif self.data_format == 'channels_first':
                batch_size, channels, height, width = input_shape
                if batch_size is None:
                    batch_size = -1

                if channels < self.group:
                    raise ValueError('Input channels should be larger than '
                                     'group size; Received input channels: ' +
                                     str(channels) + '; Group size: ' +
                                     str(self.group))

                var_x = K.reshape(inputs, (batch_size,
                                           self.group,
                                           channels // self.group,
                                           height,
                                           width))
                mean = K.mean(var_x, axis=[2, 3, 4], keepdims=True)
                std = K.sqrt(K.var(var_x, axis=[2, 3, 4], keepdims=True) +
                             self.epsilon)
                var_x = (var_x - mean) / std

                var_x = K.reshape(var_x, (batch_size, channels, height, width))
                retval = self.gamma * var_x + self.beta

        elif len(input_shape) == 2:
            reduction_axes = list(range(0, len(input_shape)))
            del reduction_axes[0]
            batch_size, _ = input_shape
            if batch_size is None:
                batch_size = -1

            # Reduce over the feature axis only; reducing over all axes
            # would mix statistics across the batch
            mean = K.mean(inputs, axis=reduction_axes, keepdims=True)
            std = K.sqrt(K.var(inputs, axis=reduction_axes, keepdims=True) +
                         self.epsilon)
            var_x = (inputs - mean) / std

            retval = self.gamma * var_x + self.beta
        return retval

    def get_config(self):
        config = {'epsilon': self.epsilon,
                  'axis': self.axis,
                  'gamma_init': initializers.serialize(self.gamma_init),
                  'beta_init': initializers.serialize(self.beta_init),
                  'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                  'beta_regularizer': regularizers.serialize(self.beta_regularizer),
                  'group': self.group}
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Update normalizations into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
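
if __name__ == "__main__":
    # Minimal smoke test: an added, hedged sketch, not part of the original
    # module. It assumes a channels_last backend and simply checks that a
    # random batch passes through both layers with its shape preserved.
    import numpy as np
    from keras.layers import Conv2D, Input
    from keras.models import Model

    INP = Input(shape=(8, 8, 3))
    VAR_X = Conv2D(64, 3, padding='same')(INP)
    VAR_X = InstanceNormalization(axis=-1)(VAR_X)
    VAR_X = GroupNormalization(group=32)(VAR_X)  # 64 channels, 32 groups
    MODEL = Model(INP, VAR_X)
    print(MODEL.predict(np.random.rand(4, 8, 8, 3)).shape)  # (4, 8, 8, 64)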