mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-09 04:36:50 -04:00
92 lines
4.1 KiB
Python
92 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
""" Original - VillainGuy model
|
|
Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
|
|
Adapted from a model by VillainGuy (https://github.com/VillainGuy) """
|
|
|
|
from keras.initializers import RandomNormal
|
|
from keras.layers import add, Dense, Flatten, Input, Reshape
|
|
from keras.models import Model as KerasModel
|
|
|
|
from lib.model.layers import PixelShuffler
|
|
from .original import logger, Model as OriginalModel
|
|
|
|
|
|
class Model(OriginalModel):
    """ Villain Faceswap Model

    A variant of the Original model with a residual-chain encoder and a residual-block
    decoder. The input shape is forced to 128x128x3, and the ``lowmem`` config option
    shrinks the encoder (``encoder_dim`` of 512 instead of 1024, 8 residual blocks
    instead of 16, and no final 1024-filter separable convolution).
    """
    def __init__(self, *args, **kwargs):
        """ Apply the Villain-specific settings, then defer to the Original model.

        Parameters
        ----------
        args: tuple
            Positional arguments passed through unchanged to ``OriginalModel.__init__``.
        kwargs: dict
            Keyword arguments passed through to ``OriginalModel.__init__``. ``configfile``
            is read if present; ``input_shape`` and ``encoder_dim`` are overwritten here.
        """
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        # Set before ``self.config`` is read on the next lines.
        # NOTE(review): presumably the inherited ``config`` lookup depends on
        # ``configfile`` - confirm against the base class.
        self.configfile = kwargs.get("configfile", None)
        kwargs["input_shape"] = (128, 128, 3)
        # ``.get`` with a default matches how ``encoder`` reads the same option.
        kwargs["encoder_dim"] = 512 if self.config.get("lowmem", False) else 1024
        self.kernel_initializer = RandomNormal(0, 0.02)

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def encoder(self):
        """ Encoder Network

        Returns
        -------
        keras.models.Model
            Maps a ``self.input_shape`` image to the output of the final
            ``upscale(..., 512)``. NOTE(review): presumably a
            ``(input_size // 8, input_size // 8, 512)`` tensor, which is what
            :func:`decoder` expects - confirm against ``self.blocks.upscale``.
        """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        lowmem = self.config.get("lowmem", False)
        input_ = Input(shape=self.input_shape)
        input_size = self.input_shape[0]

        # Entry filter count tracks the input size but grows more slowly past 128px
        # (only relevant if ``input_shape`` differs from the 128px set in ``__init__``).
        in_conv_filters = input_size
        if input_size > 128:
            in_conv_filters = 128 + (input_size - 128) // 4
        dense_shape = input_size // 16

        var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs)
        tmp_x = var_x  # kept for the skip connection around the whole residual chain
        res_cycles = 8 if lowmem else 16
        for _ in range(res_cycles):
            var_x = self.blocks.res_block(var_x, in_conv_filters, **kwargs)
        # consider adding scale before this layer to scale the residual chain
        var_x = add([var_x, tmp_x])

        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = self.blocks.conv_sep(var_x, 256, **kwargs)
        var_x = self.blocks.conv(var_x, 512, **kwargs)
        if not lowmem:
            var_x = self.blocks.conv_sep(var_x, 1024, **kwargs)

        # Bottleneck: flatten to ``encoder_dim`` features, then project back out and
        # reshape to a (dense_shape, dense_shape, 1024) map.
        var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x)
        var_x = Reshape((dense_shape, dense_shape, 1024))(var_x)
        var_x = self.blocks.upscale(var_x, 512, **kwargs)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network

        Returns
        -------
        keras.models.Model
            Takes a ``(input_size // 8, input_size // 8, 512)`` tensor and outputs the
            3-channel sigmoid ``face_out`` layer, plus a 1-channel sigmoid ``mask_out``
            layer when the ``learn_mask`` config option is set.
        """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        input_size = self.input_shape[0]
        decoder_shape = input_size // 8
        input_ = Input(shape=(decoder_shape, decoder_shape, 512))

        # Face branch: three upscale + residual-block stages, then a 5x5 output conv.
        var_x = input_
        var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 512, **kwargs)
        var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 256, **kwargs)
        var_x = self.blocks.upscale(var_x, input_size, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, input_size, **kwargs)
        var_x = self.blocks.conv2d(var_x, 3,
                                   kernel_size=5,
                                   padding="same",
                                   activation="sigmoid",
                                   name="face_out")
        outputs = [var_x]

        if self.config.get("learn_mask", False):
            # Mask branch: plain upscales from the shared input, 1-channel output.
            # NOTE(review): unlike the face branch, these upscales are not given
            # ``kernel_initializer`` - confirm whether that is intentional.
            var_y = input_
            var_y = self.blocks.upscale(var_y, 512)
            var_y = self.blocks.upscale(var_y, 256)
            var_y = self.blocks.upscale(var_y, input_size)
            var_y = self.blocks.conv2d(var_y, 1,
                                       kernel_size=5,
                                       padding="same",
                                       activation="sigmoid",
                                       name="mask_out")
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)
|