1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/train/model/villain.py

92 lines
4.1 KiB
Python

#!/usr/bin/env python3
""" Original - VillainGuy model
Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
Adapted from a model by VillainGuy (https://github.com/VillainGuy) """
from keras.initializers import RandomNormal
from keras.layers import add, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from lib.model.layers import PixelShuffler
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
    """ Villain Faceswap Model

    Extends the original model with a fixed 128px input, an encoder that opens with a chain
    of residual blocks wrapped in a skip connection, and a decoder that pairs every upscale
    with a residual block. The ``lowmem`` config option selects a smaller bottleneck and a
    shorter residual chain.
    """

    def __init__(self, *args, **kwargs):
        """ Set the Villain specific defaults and hand over to the parent model.

        ``input_shape`` and ``encoder_dim`` are forced into ``kwargs`` before they are passed
        on to the parent constructor, overriding any value supplied by the caller.
        """
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        # NOTE(review): ``self.config`` is read below, before the parent constructor has run.
        # Presumably the base class resolves it from ``self.configfile``, which is why that
        # attribute is assigned first - confirm against the base class.
        self.configfile = kwargs.get("configfile", None)
        kwargs["input_shape"] = (128, 128, 3)
        # Use the same defaulting lookup as ``encoder`` so a missing "lowmem" option means
        # "full size model" everywhere rather than a KeyError here only.
        kwargs["encoder_dim"] = 512 if self.config.get("lowmem", False) else 1024
        # Keras RandomNormal(mean, stddev), shared by the encoder and the face decoder
        self.kernel_initializer = RandomNormal(0, 0.02)

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def encoder(self):
        """ Encoder Network

        Returns
        -------
        A Keras model mapping an image of ``self.input_shape`` to the encoded feature map
        that is fed to the decoders.
        """
        kwargs = {"kernel_initializer": self.kernel_initializer}
        lowmem = self.config.get("lowmem", False)
        size = self.input_shape[0]

        input_ = Input(shape=self.input_shape)
        # Filter count follows the input size up to 128px, then grows at a quarter of the rate
        in_conv_filters = size
        if size > 128:
            in_conv_filters = 128 + (size - 128) // 4
        dense_shape = size // 16

        var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs)
        tmp_x = var_x
        # Residual chain: half the depth when running in low memory mode
        res_cycles = 8 if lowmem else 16
        for _ in range(res_cycles):
            var_x = self.blocks.res_block(var_x, in_conv_filters, **kwargs)
        # Skip connection around the whole residual chain
        # consider adding scale before this layer to scale the residual chain
        var_x = add([var_x, tmp_x])

        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = self.blocks.conv_sep(var_x, 256, **kwargs)
        var_x = self.blocks.conv(var_x, 512, **kwargs)
        # The widest separable convolution is skipped in low memory mode
        if not lowmem:
            var_x = self.blocks.conv_sep(var_x, 1024, **kwargs)

        # Dense bottleneck, then expand back out to a (dense_shape, dense_shape, 1024) map
        var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x)
        var_x = Reshape((dense_shape, dense_shape, 1024))(var_x)
        var_x = self.blocks.upscale(var_x, 512, **kwargs)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network

        Returns
        -------
        A Keras model taking a ``(size // 8, size // 8, 512)`` feature map, where ``size`` is
        ``self.input_shape[0]``, and producing the ``face_out`` image. When the
        ``learn_mask`` config option is set a second ``mask_out`` output is appended.
        """
        kwargs = {"kernel_initializer": self.kernel_initializer}
        size = self.input_shape[0]
        decoder_shape = size // 8
        input_ = Input(shape=(decoder_shape, decoder_shape, 512))

        # Face branch: each upscale is followed by a residual block of the same width
        var_x = input_
        var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 512, **kwargs)
        var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 256, **kwargs)
        var_x = self.blocks.upscale(var_x, size, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, size, **kwargs)
        var_x = self.blocks.conv2d(var_x, 3,
                                   kernel_size=5,
                                   padding="same",
                                   activation="sigmoid",
                                   name="face_out")
        outputs = [var_x]

        if self.config.get("learn_mask", False):
            # Mask branch: plain upscales from the same input, no residual blocks. The
            # kernel initializer kwargs are not forwarded to this branch.
            var_y = input_
            var_y = self.blocks.upscale(var_y, 512)
            var_y = self.blocks.upscale(var_y, 256)
            var_y = self.blocks.upscale(var_y, size)
            var_y = self.blocks.conv2d(var_y, 1,
                                       kernel_size=5,
                                       padding="same",
                                       activation="sigmoid",
                                       name="mask_out")
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)