mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* Update PixelShuffler.py Keras 2.2.2 moves normalize_data_format to keras.backend from conv_utils * update Keras to 2.2.2 in requirements.txt
88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
# PixelShuffler layer for Keras
|
|
# by t-ae
|
|
# https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981
|
|
|
|
from keras.utils import conv_utils
|
|
from keras.engine.topology import Layer
|
|
import keras.backend as K
|
|
|
|
|
|
class PixelShuffler(Layer):
    """Keras layer that moves blocks of channels into the spatial dimensions.

    For a 4D input and ``size = (rh, rw)`` the layer produces:

    * ``channels_last``:  ``(batch, h, w, c) -> (batch, h * rh, w * rw, c // (rh * rw))``
    * ``channels_first``: ``(batch, c, h, w) -> (batch, c // (rh * rw), h * rh, w * rw)``

    Arguments:
        size: int or tuple of 2 ints, the upscaling factors for height and
            width. Normalized with ``conv_utils.normalize_tuple``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``.
            Normalized with ``K.normalize_data_format``.
        **kwargs: forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        self.data_format = K.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    def call(self, inputs):
        """Rearrange the input tensor according to ``self.size``.

        Arguments:
            inputs: a rank-4 tensor laid out according to ``self.data_format``.

        Returns:
            The reshaped tensor (see the class docstring for the shapes).

        Raises:
            ValueError: if the static shape of ``inputs`` is not rank 4.
        """
        input_shape = K.int_shape(inputs)
        if len(input_shape) != 4:
            # Build one message string; passing the shape as a second
            # positional argument would make the error render as a tuple.
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape: ' + str(input_shape))

        if self.data_format == 'channels_first':
            batch_size, c, h, w = input_shape
            if batch_size is None:
                # Unknown batch dimension: let the backend infer it.
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            # Split channels into (rh, rw, oc), then interleave the rh/rw
            # factors with h/w: (b, rh, rw, oc, h, w) -> (b, oc, h, rh, w, rw).
            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            out = K.reshape(out, (batch_size, oc, oh, ow))
            return out

        elif self.data_format == 'channels_last':
            batch_size, h, w, c = input_shape
            if batch_size is None:
                # Unknown batch dimension: let the backend infer it.
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            # Split channels into (rh, rw, oc), then interleave the rh/rw
            # factors with h/w: (b, h, w, rh, rw, oc) -> (b, h, rh, w, rw, oc).
            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
            out = K.reshape(out, (batch_size, oh, ow, oc))
            return out

    def compute_output_shape(self, input_shape):
        """Return the output shape tuple for a given rank-4 ``input_shape``.

        Unknown (``None``) height or width stays ``None`` in the result.

        Raises:
            ValueError: if ``input_shape`` is not rank 4, or if the channel
                count is not a multiple of ``size[0] * size[1]``.
        """
        if len(input_shape) != 4:
            # Build one message string; passing the shape as a second
            # positional argument would make the error render as a tuple.
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape: ' + str(input_shape))

        if self.data_format == 'channels_first':
            height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
            width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
            channels = input_shape[1] // self.size[0] // self.size[1]

            # The floor divisions above must have been exact.
            if channels * self.size[0] * self.size[1] != input_shape[1]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    channels,
                    height,
                    width)

        elif self.data_format == 'channels_last':
            height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
            width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
            channels = input_shape[3] // self.size[0] // self.size[1]

            # The floor divisions above must have been exact.
            if channels * self.size[0] * self.size[1] != input_shape[3]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    height,
                    width,
                    channels)

    def get_config(self):
        """Return the base layer config extended with ``size`` and ``data_format``."""
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(PixelShuffler, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))
|