Mirror of https://github.com/deepfakes/faceswap
* Smart Masks - Training
  - Reinstate smart mask training code
  - Reinstate mask_type back to model.config
  - Change 'replicate_input_mask' to 'learn_mask'
  - Add learn_mask option
  - Add mask loading from alignments to plugins.train.trainer
  - Add mask_blur and mask_threshold options (sketched below)
  - _base.py - Pass mask options through the training_opts dict (sketched below)
  - plugins.train.model - Check for mask_type not None for learn_mask and penalized_mask_loss
  - Limit alignments loading to just those faces that appear in the training folder
  - Raise an error if an alignments file is required and not all training images have an alignment
  - lib.training_data - Mask generation code
  - lib.faces_detect - cv2 dimension stripping bugfix
  - Remove cv2 linting code
* Update mask helptext in cli.py
* Fix Warp to Landmarks; remove SHA1 hashing from training data
* Update mask training config
* Capture missing masks at training init
* lib.image.read_image_batch - Return filenames with the batch for ordering
* scripts.train - Documentation
* plugins.train.trainer - Documentation
* Ensure backward compatibility; fix convert for new predicted masks
* Update removed masks to 'components' for legacy models
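For reference, a minimal sketch of what the mask_blur / mask_threshold post-processing described above could look like. This is illustrative only: the helper name, the percentage interpretation of the threshold, and the defaults are assumptions, not the faceswap implementation.

import cv2
import numpy as np

def blur_and_threshold_mask(mask, blur_kernel=3, threshold=4):
    """ Soften a single-channel uint8 mask, then clip faint edge values.
        Assumed semantics: blur_kernel is a pixel kernel size, threshold a percentage. """
    if blur_kernel > 0:
        kernel = blur_kernel + (blur_kernel + 1) % 2   # Gaussian kernels must be odd
        mask = cv2.GaussianBlur(mask, (kernel, kernel), 0)
    if threshold > 0:
        cutoff = 255 * threshold // 100                # convert the percentage to the uint8 range
        mask[mask < cutoff] = 0                        # drop faint noise
        mask[mask > 255 - cutoff] = 255                # snap near-solid areas to fully opaque
    return mask

# Example: blur_and_threshold_mask(np.full((64, 64), 200, dtype="uint8"))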
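Similarly, a sketch of how the mask options might be gathered into a training_opts dict and gated on mask_type, per the notes above. The dict keys and the helper name are assumptions for illustration; they are not the repository's API.

def build_training_opts(config):
    """ Collect mask-related options; mask features are disabled when no mask_type is set. """
    mask_type = config.get("mask_type")
    return {
        "mask_type": mask_type,
        "learn_mask": bool(mask_type) and config.get("learn_mask", False),
        "penalized_mask_loss": bool(mask_type) and config.get("penalized_mask_loss", False),
        "mask_blur": config.get("mask_blur", False),
        "mask_threshold": config.get("mask_threshold", 0),
    }

# Example: build_training_opts({"mask_type": "components", "learn_mask": True})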
#!/usr/bin/env python3
""" Unbalanced Model
    Based on the original https://www.reddit.com/r/deepfakes/
    code sample + contribs """

from keras.initializers import RandomNormal
from keras.layers import Dense, Flatten, Input, Reshape, SpatialDropout2D
from keras.models import Model as KerasModel

from .original import logger, Model as OriginalModel


class Model(OriginalModel):
    """ Unbalanced Faceswap Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        self.configfile = kwargs.get("configfile", None)
        self.lowmem = self.config.get("lowmem", False)
        kwargs["input_shape"] = (self.config["input_size"], self.config["input_size"], 3)
        kwargs["encoder_dim"] = 512 if self.lowmem else self.config["nodes"]
        self.kernel_initializer = RandomNormal(0, 0.02)

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def add_networks(self):
        """ Add the unbalanced model networks """
        logger.debug("Adding networks")
        self.add_network("decoder", "a", self.decoder_a(), is_output=True)
        self.add_network("decoder", "b", self.decoder_b(), is_output=True)
        self.add_network("encoder", None, self.encoder())
        logger.debug("Added networks")

    def encoder(self):
        """ Unbalanced Encoder """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        encoder_complexity = 128 if self.lowmem else self.config["complexity_encoder"]
        dense_dim = 384 if self.lowmem else 512
        dense_shape = self.input_shape[0] // 16
        input_ = Input(shape=self.input_shape)

        var_x = input_
        var_x = self.blocks.conv(var_x, encoder_complexity, use_instance_norm=True, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 2, use_instance_norm=True, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 4, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 6, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 8, **kwargs)
        var_x = Dense(self.encoder_dim,
                      kernel_initializer=self.kernel_initializer)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * dense_dim,
                      kernel_initializer=self.kernel_initializer)(var_x)
        var_x = Reshape((dense_shape, dense_shape, dense_dim))(var_x)
        return KerasModel(input_, var_x)

    def decoder_a(self):
        """ Decoder for side A """
        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
        decoder_complexity = 320 if self.lowmem else self.config["complexity_decoder_a"]
        dense_dim = 384 if self.lowmem else 512
        decoder_shape = self.input_shape[0] // 16
        input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))

        var_x = input_

        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
        var_x = SpatialDropout2D(0.25)(var_x)
        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
        if self.lowmem:
            var_x = SpatialDropout2D(0.15)(var_x)
        else:
            var_x = SpatialDropout2D(0.25)(var_x)
        var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
        var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
        var_x = self.blocks.conv2d(var_x, 3,
                                   kernel_size=5,
                                   padding="same",
                                   activation="sigmoid",
                                   name="face_out")
        outputs = [var_x]

        if self.config.get("learn_mask", False):
            var_y = input_
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
            var_y = self.blocks.conv2d(var_y, 1,
                                       kernel_size=5,
                                       padding="same",
                                       activation="sigmoid",
                                       name="mask_out")
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

    def decoder_b(self):
        """ Decoder for side B """
        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
        dense_dim = 384 if self.lowmem else self.config["complexity_decoder_b"]
        decoder_complexity = 384 if self.lowmem else 512
        decoder_shape = self.input_shape[0] // 16
        input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))

        var_x = input_
        if self.lowmem:
            var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 8, **kwargs)
        else:
            var_x = self.blocks.upscale(var_x, decoder_complexity,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 2,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity // 2,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
        var_x = self.blocks.conv2d(var_x, 3,
                                   kernel_size=5,
                                   padding="same",
                                   activation="sigmoid",
                                   name="face_out")
        outputs = [var_x]

        if self.config.get("learn_mask", False):
            var_y = input_
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            if not self.lowmem:
                var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
            if self.lowmem:
                var_y = self.blocks.upscale(var_y, decoder_complexity // 8)
            var_y = self.blocks.conv2d(var_y, 1,
                                       kernel_size=5,
                                       padding="same",
                                       activation="sigmoid",
                                       name="mask_out")
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)