1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/train/model/lightweight.py
torzdf 43a4d06540
Smart Masks - Training Implementation (#914)
* Smart Masks - Training

- Reinstate smart mask training code
- Reinstate mask_type back to model.config
- Change 'replicate_input_mask' to 'learn_mask'
- Add learn mask option
- Add mask loading from alignments to plugins.train.trainer
- Add mask_blur and mask threshold options
- _base.py - Pass mask options through training_opts dict
- plugins.train.model - check for mask_type not None for learn_mask and penalized_mask_loss
- Limit alignments loading to just those faces that appear in the training folder
- Raise error if not all training images have an alignment, and alignment file is required
- lib.training_data - Mask generation code
- lib.faces_detect - cv2 dimension stripping bugfix
- Remove cv2 linting code

* Update mask helptext in cli.py

* Fix Warp to Landmarks
Remove SHA1 hashing from training data

* Update mask training config

* Capture missing masks at training init

* lib.image.read_image_batch - Return filenames with batch for ordering

* scripts.train - Documentation

* plugins.train.trainer - documentation

* Ensure backward compatibility.
Fix convert for new predicted masks

* Update removed masks to components for legacy models.
2019-12-05 16:02:01 +00:00

61 lines
2.4 KiB
Python

#!/usr/bin/env python3
""" Lightweight Model
A reduced-size variant of the Original model, based on the original
https://www.reddit.com/r/deepfakes/ code sample + contribs """
from keras.layers import Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
    """ Lightweight Model for ~2GB Graphics Cards

    A reduced-size variant of the Original model. The input shape is fixed at
    ``(64, 64, 3)`` and the encoder bottleneck at 512 units; any values the
    caller passes for ``input_shape`` / ``encoder_dim`` are overridden.
    """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)
        # This model's geometry is fixed: overwrite whatever the caller supplied
        kwargs["input_shape"] = (64, 64, 3)
        kwargs["encoder_dim"] = 512
        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def encoder(self):
        """ Encoder Network

        Three conv blocks (128, 256, 512 filters), a dense bottleneck of
        ``self.encoder_dim`` units, a dense projection reshaped to
        ``(4, 4, 512)``, and a final 256-filter upscale block.

        Returns
        -------
        keras.models.Model
            Model taking an input of ``self.input_shape`` and producing the
            feature map fed to :func:`decoder`.
        """
        input_ = Input(shape=self.input_shape)
        var_x = input_
        var_x = self.blocks.conv(var_x, 128)
        var_x = self.blocks.conv(var_x, 256)
        var_x = self.blocks.conv(var_x, 512)
        var_x = Dense(self.encoder_dim)(Flatten()(var_x))
        var_x = Dense(4 * 4 * 512)(var_x)
        var_x = Reshape((4, 4, 512))(var_x)
        # NOTE(review): assumes upscale doubles the spatial dims, yielding
        # (8, 8, 256) to match the decoder's Input shape - confirm in nn_blocks
        var_x = self.blocks.upscale(var_x, 256)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network

        Always builds a 3 channel sigmoid face output named ``face_out``. When
        the ``learn_mask`` config option is enabled, a second, independent
        branch builds a 1 channel sigmoid mask output named ``mask_out``.

        Returns
        -------
        keras.models.Model
            Model taking an ``(8, 8, 256)`` input and returning a list of
            outputs: ``[face]`` or ``[face, mask]``.
        """
        input_ = Input(shape=(8, 8, 256))
        # Face branch is built first so layer creation order matches the
        # previous implementation (keeps auto-generated layer names stable)
        outputs = [self._decoder_branch(input_, 3, "face_out")]
        if self.config.get("learn_mask", False):
            outputs.append(self._decoder_branch(input_, 1, "mask_out"))
        return KerasModel(input_, outputs=outputs)

    def _decoder_branch(self, input_, channels, name):
        """ Build one decoder branch: three upscale blocks then an output conv.

        Parameters
        ----------
        input_: tensor
            The decoder input tensor that this branch starts from.
        channels: int
            Number of filters in the final output convolution.
        name: str
            Name given to the final output layer.

        Returns
        -------
        tensor
            The sigmoid-activated output tensor of the branch.
        """
        var_x = input_
        for filters in (512, 256, 128):
            var_x = self.blocks.upscale(var_x, filters)
        return self.blocks.conv2d(var_x, channels,
                                  kernel_size=5,
                                  padding="same",
                                  activation="sigmoid",
                                  name=name)