mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00

model_refactor (#571) (#572)

* model_refactor (#571)

* original model to new structure

* IAE model to new structure

* OriginalHiRes to new structure

* Fix trainer for different resolutions

* Initial config implementation

* Configparse library added

* improved training data loader

* dfaker model working

* Add logging to training functions

* Non blocking input for cli training

* Add error handling to threads. Add non-mp queues to queue_handler

* Improved Model Building and NNMeta

* refactor lib/models

* training refactor. DFL H128 model Implementation

* Dfaker - use hashes

* Move timelapse. Remove perceptual loss arg

* Update INSTALL.md. Add logger formatting. Update Dfaker training

* DFL h128 partially ported

* Add mask to dfaker (#573)

* Remove old models. Add mask to dfaker

* dfl mask. Make masks selectable in config (#575)

* DFL H128 Mask. Mask type selectable in config.

* remove gan_v2_2

* Creating Input Size config for models

Creating Input Size config for models

Will be used downstream in converters.

Also renames image_shape to input_shape for clarity (future models may have different output_shapes)

* Add mask loss options to config

* MTCNN options to config.ini. Remove GAN config. Update USAGE.md

* Add sliders for numerical values in GUI

* Add config plugins menu to gui. Validate config

* Only backup model if loss has dropped. Get training working again

* bugfixes

* Standardise loss printing

* GUI idle cpu fixes. Graph loss fix.

* multi-gpu logging bugfix

* Merge branch 'staging' into train_refactor

* backup state file

* Crash protection: Only backup if both total losses have dropped

* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)

* Load and save model structure with weights

* Slight code update

* Improve config loader. Add subpixel opt to all models. Config to state

* Show samples... wrong input

* Remove AE topology. Add input/output shapes to State

* Port original_villain (birb/VillainGuy) model to faceswap

* Add plugin info to GUI config pages

* Load input shape from state. IAE Config options.

* Fix transform_kwargs.
Coverage to ratio.
Bugfix mask detection

* Suppress keras userwarnings.
Automate zoom.
Coverage_ratio to model def.

* Consolidation of converters & refactor (#574)

* Consolidation of converters & refactor

Initial Upload of alpha

Items
- consolidate convert_masked & convert_adjust into one converter
- add average color adjust to convert_masked
- allow mask transition blur size to be a fixed integer of pixels and a fraction of the facial mask size
- allow erosion/dilation size to be a fixed integer of pixels and a fraction of the facial mask size
- eliminate redundant type conversions to avoid multiple round-off errors
- refactor loops for vectorization/speed
- reorganize for clarity & style changes

TODO
- bug/issues with warping the new face onto a transparent old image...use a cleanup mask for now
- issues with mask border giving black ring at zero erosion .. investigate
- remove GAN ??
- test enlargement factors of umeyama standard face .. match to coverage factor
- make enlargement factor a model parameter
- remove convert_adjusted and referencing code when finished

* Update Convert_Masked.py

default blur size of 2 to match original...
description of enlargement tests
break out matrix scaling into def

* Enlargement scale as a cli parameter

* Update cli.py

* dynamic interpolation algorithm

Compute x & y scale factors from the affine matrix on the fly by QR decomp.
Choose interpolation algorithm for the affine warp based on an upsample or downsample for each image

* input size
input size from config

* fix issues with <1.0 erosion

* Update convert.py

* Update Convert_Adjust.py

more work on the way to merging

* Clean up help note on sharpen

* cleanup seamless

* Delete Convert_Adjust.py

* Update umeyama.py

* Update training_data.py

* swapping

* segmentation stub

* changes to convert.str

* Update masked.py

* Backwards compatibility fix for models
Get converter running

* Convert:
Move masks to class.
bugfix blur_size
some linting

* mask fix

* convert fixes

- missing facehull_rect re-added
- coverage to %
- corrected coverage logic
- cleanup of gui option ordering

* Update cli.py

* default for blur

* Update masked.py

* added preliminary low_mem version of OriginalHighRes model plugin

* Code cleanup, minor fixes

* Update masked.py

* Update masked.py

* Add dfl mask to convert

* histogram fix & seamless location

* update

* revert

* bugfix: Load actual configuration in gui

* Standardize nn_blocks

* Update cli.py

* Minor code amends

* Fix Original HiRes model

* Add masks to preview output for mask trainers
refactor trainer.__base.py

* Masked trainers converter support

* convert bugfix

* Bugfix: Converter for masked (dfl/dfaker) trainers

* Additional Losses (#592)

* initial upload

* Delete blur.py

* default initializer = He instead of Glorot (#588)

* Allow kernel_initializer to be overridable

* Add ICNR Initializer option for upscale on all models.

* Hopefully fixes RSoDs with original-highres model plugin

* remove debug line

* Original-HighRes model plugin Red Screen of Death fix, take #2

* Move global options to _base. Rename Villain model

* clipnorm and res block biases

* scale the end of res block

* res block

* dfaker pre-activation res

* OHRES pre-activation

* villain pre-activation

* tabs/space in nn_blocks

* fix for histogram with mask all set to zero

* fix to prevent two networks with same name

* GUI: Wider tooltips. Improve TQDM capture

* Fix regex bug

* Convert padding=48 to ratio of image size

* Add size option to alignments tool extract

* Pass through training image size to convert from model

* Convert: Pull training coverage from model

* convert: coverage, blur and erode to percent

* simplify matrix scaling

* ordering of sliders in train

* Add matrix scaling to utils. Use interpolation in lib.aligner transform

* masked.py Import get_matrix_scaling from utils

* fix circular import

* Update masked.py

* quick fix for matrix scaling

* testing thus for now

* tqdm regex capture bugfix

* Minor amends

* blur size cleanup

* Remove coverage option from convert (Now cascades from model)

* Implement convert for all model types

* Add mask option and coverage option to all existing models

* bugfix for model loading on convert

* debug print removal

* Bugfix for masks in dfl_h128 and iae

* Update preview display. Add preview scaling to cli

* mask notes

* Delete training_data_v2.py

errant file

* training data variables

* Fix timelapse function

* Add new config items to state file for legacy purposes

* Slight GUI tweak

* Raise exception if problem with loaded model

* Add Tensorboard support (Logs stored in model directory)

* ICNR fix

* loss bugfix

* convert bugfix

* Move ini files to config folder. Make TensorBoard optional

* Fix training data for unbalanced inputs/outputs

* Fix config "none" test

* Keep helptext in .ini files when saving config from GUI

* Remove frame_dims from alignments

* Add no-flip and warp-to-landmarks cli options

* Revert OHR to RC4_fix version

* Fix lowmem mode on OHR model

* padding to variable

* Save models in parallel threads

* Speed-up of res_block stability

* Automated Reflection Padding

* Reflect Padding as a training option

Includes auto-calculation of proper padding shapes, input_shapes, output_shapes

Flag included in config now

* rest of reflect padding

* Move TB logging to cli. Session info to state file

* Add session iterations to state file

* Add recent files to menu. GUI code tidy up

* [GUI] Fix recent file list update issue

* Add correct loss names to TensorBoard logs

* Update live graph to use TensorBoard and remove animation

* Fix analysis tab. GUI optimizations

* Analysis Graph popup to Tensorboard Logs

* [GUI] Bug fix for graphing for models with hyphens in name

* [GUI] Correctly split loss to tabs during training

* [GUI] Add loss type selection to analysis graph

* Fix store command name in recent files. Switch to correct tab on open

* [GUI] Disable training graph when 'no-logs' is selected

* Fix graphing race condition

* rename original_hires model to unbalanced
This commit is contained in:
torzdf 2019-02-09 18:35:12 +00:00 committed by GitHub
parent 584c41e005
commit cd00859c40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
94 changed files with 7435 additions and 4251 deletions

View file

@ -1,5 +1,7 @@
**Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need for setup help and other issues in using the faceswap tool, please refer to the [faceswap-playground](https://github.com/deepfakes/faceswap-playground/issues) instead. The faceswap-playground is also an excellent place to ask questions and submit feedback.** **Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need for setup help and other issues in using the faceswap tool, please refer to the [faceswap-playground](https://github.com/deepfakes/faceswap-playground/issues) instead. The faceswap-playground is also an excellent place to ask questions and submit feedback.**
**Please always attach your generated crash_report.log to any bug report**
## Expected behavior ## Expected behavior
*Describe, in some detail, what you are trying to do and what the output is that you expect from the program.* *Describe, in some detail, what you are trying to do and what the output is that you expect from the program.*

9
.gitignore vendored
View file

@ -12,8 +12,9 @@
!Dockerfile* !Dockerfile*
!requirements* !requirements*
!.cache !.cache
!lib !config/
!lib/face_alignment !lib/
!lib/*
!lib/gui !lib/gui
!lib/gui/.cache/preview !lib/gui/.cache/preview
!lib/gui/.cache/icons !lib/gui/.cache/icons
@ -21,9 +22,9 @@
!plugins/ !plugins/
!plugins/* !plugins/*
!plugins/extract/* !plugins/extract/*
!plugins/model/* !plugins/train/*
!tools !tools
!tools/lib* !tools/lib*
*.ini
*.pyc *.pyc
__pycache__/ __pycache__/

View file

@ -1,37 +1,37 @@
# Installing Faceswap # Installing Faceswap
- [Installing Faceswap](#installing-faceswap) - [Installing Faceswap](#installing-faceswap)
- [Prerequisites](#prerequisites) - [Prerequisites](#prerequisites)
- [Hardware Requirements](#hardware-requirements) - [Hardware Requirements](#hardware-requirements)
- [Supported operating systems](#supported-operating-systems) - [Supported operating systems](#supported-operating-systems)
- [Important before you proceed](#important-before-you-proceed) - [Important before you proceed](#important-before-you-proceed)
- [General Install Guide](#general-install-guide) - [General Install Guide](#general-install-guide)
- [Installing dependencies](#installing-dependencies) - [Installing dependencies](#installing-dependencies)
- [Getting the faceswap code](#getting-the-faceswap-code) - [Getting the faceswap code](#getting-the-faceswap-code)
- [Setup](#setup) - [Setup](#setup)
- [About some of the options](#about-some-of-the-options) - [About some of the options](#about-some-of-the-options)
- [Run the project](#run-the-project) - [Run the project](#run-the-project)
- [Notes](#notes) - [Notes](#notes)
- [Windows Install Guide](#windows-install-guide) - [Windows Install Guide](#windows-install-guide)
- [Prerequisites](#prerequisites-1) - [Prerequisites](#prerequisites-1)
- [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015) - [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015)
- [Cuda](#cuda) - [Cuda](#cuda)
- [cuDNN](#cudnn) - [cuDNN](#cudnn)
- [CMake](#cmake) - [CMake](#cmake)
- [Anaconda](#anaconda) - [Anaconda](#anaconda)
- [Git](#git) - [Git](#git)
- [Setup](#setup-1) - [Setup](#setup-1)
- [Anaconda](#anaconda-1) - [Anaconda](#anaconda-1)
- [Set up a virtual environment](#set-up-a-virtual-environment) - [Set up a virtual environment](#set-up-a-virtual-environment)
- [Entering your virtual environment](#entering-your-virtual-environment) - [Entering your virtual environment](#entering-your-virtual-environment)
- [Faceswap](#faceswap) - [Faceswap](#faceswap)
- [Easy install](#easy-install) - [Easy install](#easy-install)
- [Manual install](#manual-install) - [Manual install](#manual-install)
- [Running Faceswap](#running-faceswap) - [Running Faceswap](#running-faceswap)
- [Create a desktop shortcut](#create-a-desktop-shortcut) - [Create a desktop shortcut](#create-a-desktop-shortcut)
- [Updating faceswap](#updating-faceswap) - [Updating faceswap](#updating-faceswap)
- [Dlib](#dlib) - [Dlib](#dlib)
- [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support) - [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support)
- [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support) - [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support)
# Prerequisites # Prerequisites
Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up. Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up.

View file

@ -34,6 +34,8 @@ You can see the full list of arguments for extracting via help flag. i.e.
python faceswap.py extract -h python faceswap.py extract -h
``` ```
Some of the plugins have configurable options. You can find the config options in: `<faceswap_folder>\plugins\extract\config.ini`. Extract needs to have been run at least once to generate the config file
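As a rough illustration of inspecting such a file once it has been generated, the sketch below reads an `.ini` file with Python's standard `configparser` and returns its options grouped by section. The helper name `load_plugin_options` is hypothetical, and the section and option names used with it are placeholders, not the names faceswap actually writes.

```python
import configparser


def load_plugin_options(path):
    """Return {section: {option: value}} for an .ini file.

    Hypothetical helper for illustration only; it is not part of faceswap.
    Values are returned as strings, exactly as stored in the file.
    """
    parser = configparser.ConfigParser(allow_no_value=True)
    # ConfigParser.read() silently skips missing files and returns the
    # list of files it managed to parse, so check that list explicitly.
    if not parser.read(path):
        raise FileNotFoundError("Config file not found: {}".format(path))
    return {section: dict(parser[section]) for section in parser.sections()}
```

Calling `load_plugin_options(r"<faceswap_folder>\plugins\extract\config.ini")` would then list whatever sections the generated file contains.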
## TRAIN ## TRAIN
The training process will take the longest, especially on CPU. We specify the folders where the two faces are, and where we will save our training model. It will start hammering the training data once you run the command. I personally really like to go by the preview and quit the processing once I'm happy with the results. The training process will take the longest, especially on CPU. We specify the folders where the two faces are, and where we will save our training model. It will start hammering the training data once you run the command. I personally really like to go by the preview and quit the processing once I'm happy with the results.
@ -51,6 +53,9 @@ You can see the full list of arguments for training via help flag. i.e.
python faceswap.py train -h python faceswap.py train -h
``` ```
Some of the plugins have configurable options. You can find the config options in: `<faceswap_folder>\plugins\traom\config.ini`. Train needs to have been run at least once to generate the config file
Note: `traom` in the path above appears to be a typo in the committed text; the directory is presumably `plugins\train`, which matches the `!plugins/train/*` entry added to .gitignore in this commit.
## CONVERT ## CONVERT
Now that we're happy with our trained model, we can convert our video. How does it work? Similarly to the extraction script, actually! The conversion script basically detects a face in a picture using the same algorithm, quickly crops the image to the right size, runs our bot on this cropped image of the face it has found, and then (crudely) pastes the processed face back into the picture. Now that we're happy with our trained model, we can convert our video. How does it work? Similarly to the extraction script, actually! The conversion script basically detects a face in a picture using the same algorithm, quickly crops the image to the right size, runs our bot on this cropped image of the face it has found, and then (crudely) pastes the processed face back into the picture.
@ -86,7 +91,7 @@ python tools.py effmpeg -h
``` ```
## Extracting video frames with FFMPEG ## Extracting video frames with FFMPEG
Alternatively you can split a video into seperate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to seperate frames. Alternatively you can split a video into separate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to separate frames.
```bash ```bash
ffmpeg -i /path/to/my/video.mp4 /path/to/output/video-frame-%d.png ffmpeg -i /path/to/my/video.mp4 /path/to/output/video-frame-%d.png

View file

@ -1,88 +0,0 @@
# PixelShuffler layer for Keras
# by t-ae
# https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981

from keras.utils import conv_utils
from keras.engine.topology import Layer
import keras.backend as K


class PixelShuffler(Layer):
    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        self.data_format = K.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    def call(self, inputs):

        input_shape = K.int_shape(inputs)
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape:', str(input_shape))

        if self.data_format == 'channels_first':
            batch_size, c, h, w = input_shape
            if batch_size is None:
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            out = K.reshape(out, (batch_size, oc, oh, ow))
            return out

        elif self.data_format == 'channels_last':
            batch_size, h, w, c = input_shape
            if batch_size is None:
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
            out = K.reshape(out, (batch_size, oh, ow, oc))
            return out

    def compute_output_shape(self, input_shape):

        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape:', str(input_shape))

        if self.data_format == 'channels_first':
            height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
            width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
            channels = input_shape[1] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[1]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    channels,
                    height,
                    width)

        elif self.data_format == 'channels_last':
            height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
            width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
            channels = input_shape[3] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[3]:
                raise ValueError('channels of input and size are incompatible')

            return (input_shape[0],
                    height,
                    width,
                    channels)

    def get_config(self):
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(PixelShuffler, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))
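The `channels_last` branch of the removed layer is a reshape, an axis permutation, and a second reshape. The sketch below repeats those three steps in plain NumPy so the rearrangement can be checked without Keras. The function name `pixel_shuffle_channels_last` is an assumption made for this illustration and does not exist in the repository.

```python
import numpy as np


def pixel_shuffle_channels_last(inputs, size=(2, 2)):
    """NumPy sketch of the layer's 'channels_last' path (illustrative only).

    inputs: array of shape (batch, h, w, c), with c divisible by rh * rw.
    Returns an array of shape (batch, h * rh, w * rw, c // (rh * rw)).
    """
    batch, height, width, channels = inputs.shape
    r_h, r_w = size
    if channels % (r_h * r_w) != 0:
        raise ValueError("channels of input and size are incompatible")
    out_channels = channels // (r_h * r_w)

    # Split the channel axis into (rh, rw, oc), as the Keras code does.
    out = inputs.reshape(batch, height, width, r_h, r_w, out_channels)
    # Interleave the rh axis with h and the rw axis with w.
    out = out.transpose(0, 1, 3, 2, 4, 5)
    # Merge (h, rh) and (w, rw) into the upscaled spatial axes.
    return out.reshape(batch, height * r_h, width * r_w, out_channels)
```

With `size=(2, 2)`, each group of four input channels at one pixel becomes a 2x2 block of output pixels, which is why the channel count must be divisible by `rh * rw`.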

View file

@ -12,29 +12,6 @@ from lib.align_eyes import align_eyes as func_align_eyes, FACIAL_LANDMARKS_IDXS
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
MEAN_FACE_X = np.array([
0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483,
0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127,
0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122,
0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132,
0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127,
.551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532,
0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364,
0.490127, 0.42689])
MEAN_FACE_Y = np.array([
0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891,
0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625,
0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758,
0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758,
0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578,
0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192,
0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182,
0.831803, 0.824182])
LANDMARKS_2D = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1)
class Extract(): class Extract():
""" Based on the original https://www.reddit.com/r/deepfakes/ """ Based on the original https://www.reddit.com/r/deepfakes/
code sample + contribs """ code sample + contribs """
@ -42,8 +19,9 @@ class Extract():
def extract(self, image, face, size, align_eyes): def extract(self, image, face, size, align_eyes):
""" Extract a face from an image """ """ Extract a face from an image """
logger.trace("size: %s. align_eyes: %s", size, align_eyes) logger.trace("size: %s. align_eyes: %s", size, align_eyes)
padding = int(size * 0.1875)
alignment = get_align_mat(face, size, align_eyes) alignment = get_align_mat(face, size, align_eyes)
extracted = self.transform(image, alignment, size, 48) extracted = self.transform(image, alignment, size, padding)
logger.trace("Returning face and alignment matrix: (alignment_matrix: %s)", alignment) logger.trace("Returning face and alignment matrix: (alignment_matrix: %s)", alignment)
return extracted, alignment return extracted, alignment
@ -60,8 +38,9 @@ class Extract():
""" Transform Image """ """ Transform Image """
logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding) logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding)
matrix = self.transform_matrix(mat, size, padding) matrix = self.transform_matrix(mat, size, padding)
interpolators = get_matrix_scaling(matrix)
return cv2.warpAffine( # pylint: disable=no-member return cv2.warpAffine( # pylint: disable=no-member
image, matrix, (size, size)) image, matrix, (size, size), flags=interpolators[0])
def transform_points(self, points, mat, size, padding=0): def transform_points(self, points, mat, size, padding=0):
""" Transform points along matrix """ """ Transform points along matrix """
@ -144,12 +123,23 @@ class Extract():
return mask return mask
def get_matrix_scaling(mat):
""" Get the correct interpolator """
x_scale = np.sqrt(mat[0, 0] * mat[0, 0] + mat[0, 1] * mat[0, 1])
y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale
avg_scale = (x_scale + y_scale) * 0.5
if avg_scale >= 1.0:
interpolators = cv2.INTER_CUBIC, cv2.INTER_AREA # pylint: disable=no-member
else:
interpolators = cv2.INTER_AREA, cv2.INTER_CUBIC # pylint: disable=no-member
logger.trace("interpolator: %s, inverse interpolator: %s", interpolators[0], interpolators[1])
return interpolators
def get_align_mat(face, size, should_align_eyes): def get_align_mat(face, size, should_align_eyes):
""" Return the alignment Matrix """ """ Return the alignment Matrix """
logger.trace("size: %s, should_align_eyes: %s", size, should_align_eyes) logger.trace("size: %s, should_align_eyes: %s", size, should_align_eyes)
mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), True)[0:2]
LANDMARKS_2D,
True)[0:2]
if should_align_eyes is False: if should_align_eyes is False:
return mat_umeyama return mat_umeyama
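The scale computation in `get_matrix_scaling` can be tried in isolation. The sketch below reproduces the same arithmetic with NumPy only and returns the OpenCV flag names as strings, so it runs without `cv2` installed. The helper names `matrix_scales`, `pick_interpolators` and `similarity_matrix` are assumptions made for this illustration.

```python
import numpy as np


def matrix_scales(mat):
    """Return (x_scale, y_scale) for a 2x3 affine matrix.

    Same arithmetic as get_matrix_scaling in the diff: the norm of the
    first row of the linear part, and the determinant divided by that norm.
    """
    x_scale = float(np.hypot(mat[0, 0], mat[0, 1]))
    determinant = mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]
    return x_scale, float(determinant / x_scale)


def pick_interpolators(mat):
    """Return (forward, inverse) interpolation names as plain strings.

    Upsampling (average scale >= 1) uses cubic for the forward warp and
    area for the inverse warp; downsampling swaps the two.
    """
    x_scale, y_scale = matrix_scales(mat)
    if (x_scale + y_scale) * 0.5 >= 1.0:
        return "INTER_CUBIC", "INTER_AREA"
    return "INTER_AREA", "INTER_CUBIC"


def similarity_matrix(scale, degrees):
    """Build a 2x3 rotate-and-scale matrix to test with (hypothetical helper)."""
    theta = np.deg2rad(degrees)
    cos_t, sin_t = scale * np.cos(theta), scale * np.sin(theta)
    return np.array([[cos_t, -sin_t, 0.0],
                     [sin_t, cos_t, 0.0]])
```

For a pure rotate-and-scale matrix both scales equal the scale factor regardless of the rotation angle, so the choice depends only on whether the warp enlarges or shrinks the face.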

View file

@ -270,33 +270,6 @@ class Alignments():
# << LEGACY FUNCTIONS >> # # << LEGACY FUNCTIONS >> #
# < Original Frame Dimensions > #
# For dfaker and convert-adjust the original dimensions of a frame are
# required to calculate the transposed landmarks. As transposed landmarks
# will change on face size, we store original frame dimensions
# These were not previously required, so this adds the dimensions
# to the landmarks file
def get_legacy_no_dims(self):
""" Return a list of frames that do not contain the original frame
height and width attributes """
logger.debug("Getting alignments without frame_dims")
keys = list()
for key, val in self.data.items():
for alignment in val:
if "frame_dims" not in alignment.keys():
keys.append(key)
break
logger.debug("Got alignments without frame_dims: %s", len(keys))
return keys
def add_dimensions(self, frame_name, dimensions):
""" Backward compatability fix. Add frame dimensions
to alignments """
logger.trace("Adding dimensions: (frame: '%s', dimensions: %s)", frame_name, dimensions)
for face in self.get_faces_in_frame(frame_name):
face["frame_dims"] = dimensions
# < Rotation > # # < Rotation > #
# The old rotation method would rotate the image to find a face, then # The old rotation method would rotate the image to find a face, then
# store the rotated landmarks along with a rotation value to tell the # store the rotated landmarks along with a rotation value to tell the
@ -319,20 +292,20 @@ class Alignments():
logger.debug("Got alignments containing legacy rotations: %s", len(keys)) logger.debug("Got alignments containing legacy rotations: %s", len(keys))
return keys return keys
def rotate_existing_landmarks(self, frame_name): def rotate_existing_landmarks(self, frame_name, frame):
""" Backwards compatability fix. Rotates the landmarks to """ Backwards compatability fix. Rotates the landmarks to
their correct position and deletes r their correct position and deletes r
NB: The original frame dimensions must be passed in otherwise NB: The original frame must be passed in otherwise
the transformation cannot be performed """ the transformation cannot be performed """
logger.trace("Rotating existing landmarks for frame: '%s'", frame_name) logger.trace("Rotating existing landmarks for frame: '%s'", frame_name)
dims = frame.shape[:2]
for face in self.get_faces_in_frame(frame_name): for face in self.get_faces_in_frame(frame_name):
angle = face.get("r", 0) angle = face.get("r", 0)
if not angle: if not angle:
logger.trace("Landmarks do not require rotation: '%s'", frame_name) logger.trace("Landmarks do not require rotation: '%s'", frame_name)
return return
logger.trace("Rotating landmarks: (frame: '%s', angle: %s)", frame_name, angle) logger.trace("Rotating landmarks: (frame: '%s', angle: %s)", frame_name, angle)
dims = face["frame_dims"]
r_mat = self.get_original_rotation_matrix(dims, angle) r_mat = self.get_original_rotation_matrix(dims, angle)
rotate_landmarks(face, r_mat) rotate_landmarks(face, r_mat)
del face["r"] del face["r"]

View file

@ -103,9 +103,41 @@ class ScriptExecutor():
safe_shutdown() safe_shutdown()
class FullPaths(argparse.Action): class Slider(argparse.Action): # pylint: disable=too-few-public-methods
""" Adds support for the GUI slider
An additional option 'min_max' must be provided containing tuple of min and max accepted
values.
'rounding' sets the decimal places for floats or the step interval for ints.
"""
def __init__(self, option_strings, dest, nargs=None, min_max=None, rounding=None, **kwargs):
if nargs is not None:
raise ValueError("nargs not allowed")
super().__init__(option_strings, dest, **kwargs)
self.min_max = min_max
self.rounding = rounding
def _get_kwargs(self):
names = ["option_strings",
"dest",
"nargs",
"const",
"default",
"type",
"choices",
"help",
"metavar",
"min_max", # Tuple containing min and max values of scale
"rounding"] # Decimal places to round floats to or step interval for ints
return [(name, getattr(self, name)) for name in names]
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)
class FullPaths(argparse.Action): # pylint: disable=too-few-public-methods
""" Expand user- and relative-paths """ """ Expand user- and relative-paths """
# pylint: disable=too-few-public-methods
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, os.path.abspath( setattr(namespace, self.dest, os.path.abspath(
os.path.expanduser(values))) os.path.expanduser(values)))
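The `Slider` action added in this hunk relies on `argparse` forwarding unknown keyword arguments from `add_argument` to the action's constructor. The sketch below is a minimal standalone version showing that `min_max` and `rounding` end up as attributes on the action while parsing behaves like a plain store. The `build_parser` function and the `slider-demo` program name are assumptions made for this illustration.

```python
import argparse


class Slider(argparse.Action):
    """Store action carrying slider metadata for a GUI (trimmed from the diff)."""

    def __init__(self, option_strings, dest, nargs=None,
                 min_max=None, rounding=None, **kwargs):
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)
        self.min_max = min_max    # (min, max) accepted by the slider
        self.rounding = rounding  # decimal places, or step interval for ints

    def __call__(self, parser, namespace, values, option_string=None):
        # No range check happens here: the value is stored as parsed.
        setattr(namespace, self.dest, values)


def build_parser():
    """Return (parser, action) for a single slider-backed option (hypothetical)."""
    parser = argparse.ArgumentParser(prog="slider-demo")
    action = parser.add_argument("-l", "--ref_threshold",
                                 action=Slider,
                                 min_max=(0.01, 0.99),
                                 rounding=2,
                                 type=float,
                                 dest="ref_threshold",
                                 default=0.6)
    return parser, action
```

Because `__call__` only stores the value, the range in `min_max` constrains the GUI widget but not values typed on the command line; enforcing it there would need an explicit check.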
@ -124,26 +156,23 @@ class FileFullPaths(FullPaths):
see lib/gui/utils.py FileHandler for current GUI filetypes see lib/gui/utils.py FileHandler for current GUI filetypes
""" """
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
def __init__(self, option_strings, dest, nargs=None, filetypes=None, def __init__(self, option_strings, dest, nargs=None, filetypes=None, **kwargs):
**kwargs):
super(FileFullPaths, self).__init__(option_strings, dest, **kwargs) super(FileFullPaths, self).__init__(option_strings, dest, **kwargs)
if nargs is not None: if nargs is not None:
raise ValueError("nargs not allowed") raise ValueError("nargs not allowed")
self.filetypes = filetypes self.filetypes = filetypes
def _get_kwargs(self): def _get_kwargs(self):
names = [ names = ["option_strings",
"option_strings", "dest",
"dest", "nargs",
"nargs", "const",
"const", "default",
"default", "type",
"type", "choices",
"choices", "help",
"help", "metavar",
"metavar", "filetypes"]
"filetypes"
]
return [(name, getattr(self, name)) for name in names] return [(name, getattr(self, name)) for name in names]
@ -185,19 +214,17 @@ class ContextFullPaths(FileFullPaths):
self.filetypes = filetypes self.filetypes = filetypes
def _get_kwargs(self): def _get_kwargs(self):
names = [ names = ["option_strings",
"option_strings", "dest",
"dest", "nargs",
"nargs", "const",
"const", "default",
"default", "type",
"type", "choices",
"choices", "help",
"help", "metavar",
"metavar", "filetypes",
"filetypes", "action_option"]
"action_option"
]
return [(name, getattr(self, name)) for name in names] return [(name, getattr(self, name)) for name in names]
@ -282,6 +309,13 @@ class FaceSwapArgs():
"help": "Path to store the logfile. Leave blank to store in the " "help": "Path to store the logfile. Leave blank to store in the "
"faceswap folder", "faceswap folder",
"default": None}) "default": None})
# This is a hidden argument to indicate that the GUI is being used,
# so the preview window should be redirected Accordingly
global_args.append({"opts": ("-gui", "--gui"),
"action": "store_true",
"dest": "redirect_gui",
"default": False,
"help": argparse.SUPPRESS})
return global_args return global_args
@staticmethod @staticmethod
@ -342,11 +376,14 @@ class ExtractConvertArgs(FaceSwapArgs):
"dest": "alignments_path", "dest": "alignments_path",
"help": "Optional path to an alignments file."}) "help": "Optional path to an alignments file."})
argument_list.append({"opts": ("-l", "--ref_threshold"), argument_list.append({"opts": ("-l", "--ref_threshold"),
"action": Slider,
"min_max": (0.01, 0.99),
"rounding": 2,
"type": float, "type": float,
"dest": "ref_threshold", "dest": "ref_threshold",
"default": 0.6, "default": 0.6,
"help": "Threshold for positive face " "help": "Threshold for positive face recognition. For use with "
"recognition"}) "nfilter or filter. Lower values are stricter."})
argument_list.append({"opts": ("-n", "--nfilter"), argument_list.append({"opts": ("-n", "--nfilter"),
"type": str, "type": str,
"dest": "nfilter", "dest": "nfilter",
@ -389,7 +426,7 @@ class ExtractArgs(ExtractConvertArgs):
"fallback."}) "fallback."})
argument_list.append({ argument_list.append({
"opts": ("-D", "--detector"), "opts": ("-D", "--detector"),
"type": str, "type": str.lower,
"choices": PluginLoader.get_available_extractors( "choices": PluginLoader.get_available_extractors(
"detect"), "detect"),
"default": "mtcnn", "default": "mtcnn",
@ -404,7 +441,7 @@ class ExtractArgs(ExtractConvertArgs):
"\n\talignment to dlib"}) "\n\talignment to dlib"})
argument_list.append({ argument_list.append({
"opts": ("-A", "--aligner"), "opts": ("-A", "--aligner"),
"type": str, "type": str.lower,
"choices": PluginLoader.get_available_extractors( "choices": PluginLoader.get_available_extractors(
"align"), "align"),
"default": "fan", "default": "fan",
@ -413,38 +450,6 @@ class ExtractArgs(ExtractConvertArgs):
"\n\tresource intensive, but less accurate." "\n\tresource intensive, but less accurate."
"\n'fan': Face Alignment Network. Best aligner." "\n'fan': Face Alignment Network. Best aligner."
"\n\tGPU heavy."}) "\n\tGPU heavy."})
argument_list.append({"opts": ("-mtms", "--mtcnn-minsize"),
"type": int,
"dest": "mtcnn_minsize",
"default": 20,
"help": "The minimum size of a face to be "
"accepted. Lower values use "
"significantly more VRAM. Minimum "
"value is 10. Default is 20 "
"(MTCNN detector only)"})
argument_list.append({"opts": ("-mtth", "--mtcnn-threshold"),
"nargs": "+",
"type": str,
"dest": "mtcnn_threshold",
"default": ["0.6", "0.7", "0.7"],
"help": "R|Three step threshold for face "
"detection. Should be\nthree decimal "
"numbers each less than 1. Eg:\n"
"'--mtcnn-threshold 0.6 0.7 0.7'.\n"
"1st stage: obtains face candidates.\n"
"2nd stage: refinement of face "
"candidates.\n3rd stage: further "
"refinement of face candidates.\n"
"Default is 0.6 0.7 0.7 "
"(MTCNN detector only)"})
argument_list.append({"opts": ("-mtsc", "--mtcnn-scalefactor"),
"type": float,
"dest": "mtcnn_scalefactor",
"default": 0.709,
"help": "The scale factor for the image "
"pyramid. Should be a decimal number "
"less than one. Default is 0.709 "
"(MTCNN detector only)"})
argument_list.append({"opts": ("-r", "--rotate-images"), argument_list.append({"opts": ("-r", "--rotate-images"),
"type": str, "type": str,
"dest": "rotate_images", "dest": "rotate_images",
@ -458,13 +463,15 @@ class ExtractArgs(ExtractConvertArgs):
"exactly what angles to check"}) "exactly what angles to check"})
argument_list.append({"opts": ("-bt", "--blur-threshold"), argument_list.append({"opts": ("-bt", "--blur-threshold"),
"type": float, "type": float,
"action": Slider,
"min_max": (0.0, 100.0),
"rounding": 1,
"dest": "blur_thresh", "dest": "blur_thresh",
"default": None, "default": 0.0,
"help": "Automatically discard images blurrier " "help": "Automatically discard images blurrier than the specified "
"than the specified threshold. " "threshold. Discarded images are moved into a \"blurry\" "
"Discarded images are moved into a " "sub-folder. Lower values allow more blur. Set to 0.0 to "
"\"blurry\" sub-folder. Lower values " "turn off."})
"allow more blur"})
argument_list.append({"opts": ("-mp", "--multiprocess"), argument_list.append({"opts": ("-mp", "--multiprocess"),
"action": "store_true", "action": "store_true",
"default": False, "default": False,
@ -476,12 +483,13 @@ class ExtractArgs(ExtractConvertArgs):
"otherwise this is automatic."}) "otherwise this is automatic."})
argument_list.append({"opts": ("-sz", "--size"), argument_list.append({"opts": ("-sz", "--size"),
"type": int, "type": int,
"action": Slider,
"min_max": (128, 512),
"default": 256, "default": 256,
"help": "The output size of extracted faces. " "rounding": 64,
"Make sure that the model you intend " "help": "The output size of extracted faces. Make sure that the "
"to train supports your required " "model you intend to train supports your required size. "
"size. This will only need to be " "This will only need to be changed for hi-res models."})
"changed for hi-res models."})
argument_list.append({"opts": ("-s", "--skip-existing"), argument_list.append({"opts": ("-s", "--skip-existing"),
"action": "store_true", "action": "store_true",
"dest": "skip_existing", "dest": "skip_existing",
@ -512,13 +520,15 @@ class ExtractArgs(ExtractConvertArgs):
argument_list.append({"opts": ("-si", "--save-interval"), argument_list.append({"opts": ("-si", "--save-interval"),
"dest": "save_interval", "dest": "save_interval",
"type": int, "type": int,
"default": None, "action": Slider,
"help": "Automatically save the alignments file " "min_max": (0, 1000),
"after a set amount of frames. Will " "rounding": 10,
"only save at the end of extracting by " "default": 0,
"default. WARNING: Don't interrupt the " "help": "Automatically save the alignments file after a set amount "
"script when writing the file because " "of frames. Will only save at the end of extracting by "
"it might get corrupted."}) "default. WARNING: Don't interrupt the script when writing "
"the file because it might get corrupted. Set to 0 to turn "
"off"})
return argument_list return argument_list
@ -552,57 +562,73 @@ class ConvertArgs(ExtractConvertArgs):
"specified, all faces will be " "specified, all faces will be "
"converted"}) "converted"})
argument_list.append({"opts": ("-t", "--trainer"), argument_list.append({"opts": ("-t", "--trainer"),
"type": str, "type": str.lower,
# case sensitive because this is used to
# load a plug-in.
"choices": PluginLoader.get_available_models(), "choices": PluginLoader.get_available_models(),
"default": PluginLoader.get_default_model(), "default": PluginLoader.get_default_model(),
"help": "Select the trainer that was used to " "help": "Select the trainer that was used to "
"create the model"}) "create the model"})
argument_list.append({"opts": ("-c", "--converter"), argument_list.append({"opts": ("-c", "--converter"),
"type": str,
# case sensitive because this is used
# to load a plugin.
"choices": ("Masked", "Adjust"),
"default": "Masked",
"help": "Converter to use"})
argument_list.append({"opts": ("-b", "--blur-size"),
"type": int,
"default": 2,
"help": "Blur size. (Masked converter only)"})
argument_list.append({"opts": ("-e", "--erosion-kernel-size"),
"dest": "erosion_kernel_size",
"type": int,
"default": None,
"help": "Erosion kernel size. Positive values "
"apply erosion which reduces the edge "
"of the swapped face. Negative values "
"apply dilation which allows the "
"swapped face to cover more space. "
"(Masked converter only)"})
argument_list.append({"opts": ("-M", "--mask-type"),
# lowercase this, because it's just a
# string later on.
"type": str.lower, "type": str.lower,
"dest": "mask_type", "choices": PluginLoader.get_available_converters(),
"choices": ["rect", "default": "masked",
"facehull", "help": "Converter to use"})
"facehullandrect"], argument_list.append({
"default": "facehullandrect", "opts": ("-M", "--mask-type"),
"help": "Mask to use to replace faces. " "type": str.lower,
"(Masked converter only)"}) "dest": "mask_type",
"choices": ["rect",
"ellipse",
"smoothed",
"facehull",
"facehull_rect",
"dfl",
"cnn"],
"default": "facehull_rect",
"help": "R|Mask to use to replace faces."
"\nrect: Rectangle around face."
"\nellipse: Oval around face."
"\nsmoothed: Rectangle around face with smoothing."
"\nfacehull: Face cutout based on landmarks."
"\nfacehull_rect: Rectangle around faces with facehull"
"\n\tbetween the edges of the face and the background."
"\ndfl: A Face Hull mask from DeepFaceLabs."
"\ncnn: Not yet implemented"})
argument_list.append({"opts": ("-b", "--blur-size"),
"type": float,
"action": Slider,
"min_max": (0.0, 100.0),
"rounding": 2,
"default": 5.0,
"help": "Blur kernel size as a percentage of the swap area. Smooths "
"the transition between the swapped face and the background "
"image."})
argument_list.append({"opts": ("-e", "--erosion-size"),
"dest": "erosion_size",
"type": float,
"action": Slider,
"min_max": (-100.0, 100.0),
"rounding": 2,
"default": 0.0,
"help": "Erosion kernel size as a percentage of the mask radius "
"area. Positive values apply erosion which reduces the size "
"of the swapped area. Negative values apply dilation which "
"increases the swapped area"})
argument_list.append({"opts": ("-g", "--gpus"),
"type": int,
"action": Slider,
"min_max": (1, 10),
"rounding": 1,
"default": 1,
"help": "Number of GPUs to use for conversion"})
argument_list.append({"opts": ("-sh", "--sharpen"), argument_list.append({"opts": ("-sh", "--sharpen"),
"type": str.lower, "type": str.lower,
"dest": "sharpen_image", "dest": "sharpen_image",
"choices": ["bsharpen", "gsharpen"], "choices": ["box_filter", "gaussian_filter"],
"default": None, "default": None,
"help": "Use Sharpen Image. bsharpen for Box " "help": "Sharpen the masked facial region of "
"Blur, gsharpen for Gaussian Blur " "the converted images. Choice of filter "
"(Masked converter only)"}) "to use in sharpening process -- box"
argument_list.append({"opts": ("-g", "--gpus"), "filter or gaussian filter."})
"type": int,
"default": 1,
"help": "Number of GPUs to use for conversion"})
argument_list.append({"opts": ("-fr", "--frame-ranges"), argument_list.append({"opts": ("-fr", "--frame-ranges"),
"nargs": "+", "nargs": "+",
"type": str, "type": str,
@ -628,25 +654,25 @@ class ConvertArgs(ExtractConvertArgs):
"action": "store_true", "action": "store_true",
"dest": "seamless_clone", "dest": "seamless_clone",
"default": False, "default": False,
"help": "Use cv2's seamless clone. " "help": "Use cv2's seamless clone function to "
"(Masked converter only)"}) "remove extreme gradients at the mask "
"seam by smoothing colors."})
argument_list.append({"opts": ("-mh", "--match-histogram"), argument_list.append({"opts": ("-mh", "--match-histogram"),
"action": "store_true", "action": "store_true",
"dest": "match_histogram", "dest": "match_histogram",
"default": False, "default": False,
"help": "Use histogram matching. " "help": "Adjust the histogram of each color "
"(Masked converter only)"}) "channel in the swapped reconstruction "
argument_list.append({"opts": ("-sm", "--smooth-mask"), "to equal the histogram of the masked "
"action": "store_true", "area in the original image"})
"dest": "smooth_mask",
"default": False,
"help": "Smooth mask (Adjust converter only)"})
argument_list.append({"opts": ("-aca", "--avg-color-adjust"), argument_list.append({"opts": ("-aca", "--avg-color-adjust"),
"action": "store_true", "action": "store_true",
"dest": "avg_color_adjust", "dest": "avg_color_adjust",
"default": False, "default": False,
"help": "Average color adjust. " "help": "Adjust the mean of each color channel "
"(Adjust converter only)"}) "in the swapped reconstruction to "
"equal the mean of the masked area in "
"the original image"})
argument_list.append({"opts": ("-dt", "--draw-transparent"), argument_list.append({"opts": ("-dt", "--draw-transparent"),
"action": "store_true", "action": "store_true",
"dest": "draw_transparent", "dest": "draw_transparent",
@ -667,18 +693,38 @@ class TrainArgs(FaceSwapArgs):
argument_list = list() argument_list = list()
argument_list.append({"opts": ("-A", "--input-A"), argument_list.append({"opts": ("-A", "--input-A"),
"action": DirFullPaths, "action": DirFullPaths,
"dest": "input_A", "dest": "input_a",
"default": "input_A", "default": "input_a",
"help": "Input directory. A directory " "help": "Input directory. A directory "
"containing training images for face A. " "containing training images for face A. "
"Defaults to 'input'"}) "Defaults to 'input'"})
argument_list.append({"opts": ("-B", "--input-B"), argument_list.append({"opts": ("-B", "--input-B"),
"action": DirFullPaths, "action": DirFullPaths,
"dest": "input_B", "dest": "input_b",
"default": "input_B", "default": "input_b",
"help": "Input directory. A directory " "help": "Input directory. A directory "
"containing training images for face B. " "containing training images for face B. "
"Defaults to 'input'"}) "Defaults to 'input'"})
argument_list.append({"opts": ("-ala", "--alignments-A"),
"action": FileFullPaths,
"filetypes": 'alignments',
"type": str,
"dest": "alignments_path_a",
"default": None,
"help": "Path to alignments file for training set A. Only required "
"if you are using a masked model or warp-to-landmarks is "
"enabled. Defaults to <input-A>/alignments.json if not "
"provided."})
argument_list.append({"opts": ("-alb", "--alignments-B"),
"action": FileFullPaths,
"filetypes": 'alignments',
"type": str,
"dest": "alignments_path_b",
"default": None,
"help": "Path to alignments file for training set B. Only required "
"if you are using a masked model or warp-to-landmarks is "
"enabled. Defaults to <input-B>/alignments.json if not "
"provided."})
argument_list.append({"opts": ("-m", "--model-dir"), argument_list.append({"opts": ("-m", "--model-dir"),
"action": DirFullPaths, "action": DirFullPaths,
"dest": "model_dir", "dest": "model_dir",
@ -686,32 +732,51 @@ class TrainArgs(FaceSwapArgs):
"help": "Model directory. This is where the " "help": "Model directory. This is where the "
"training data will be stored. " "training data will be stored. "
"Defaults to 'model'"}) "Defaults to 'model'"})
argument_list.append({"opts": ("-s", "--save-interval"),
"type": int,
"dest": "save_interval",
"default": 100,
"help": "Sets the number of iterations before "
"saving the model"})
argument_list.append({"opts": ("-t", "--trainer"), argument_list.append({"opts": ("-t", "--trainer"),
"type": str, "type": str.lower,
"choices": PluginLoader.get_available_models(), "choices": PluginLoader.get_available_models(),
"default": PluginLoader.get_default_model(), "default": PluginLoader.get_default_model(),
"help": "Select which trainer to use, Use " "help": "Select which trainer to use, Use "
"LowMem for cards with less than 2GB of " "LowMem for cards with less than 2GB of "
"VRAM"}) "VRAM"})
argument_list.append({"opts": ("-s", "--save-interval"),
"type": int,
"action": Slider,
"min_max": (10, 1000),
"rounding": 10,
"dest": "save_interval",
"default": 100,
"help": "Sets the number of iterations before saving the model"})
argument_list.append({"opts": ("-bs", "--batch-size"), argument_list.append({"opts": ("-bs", "--batch-size"),
"type": int, "type": int,
"action": Slider,
"min_max": (2, 256),
"rounding": 2,
"dest": "batch_size",
"default": 64, "default": 64,
"help": "Batch size, as a power of 2 " "help": "Batch size, as a power of 2 (64, 128, 256, etc)"})
"(64, 128, 256, etc)"})
argument_list.append({"opts": ("-it", "--iterations"), argument_list.append({"opts": ("-it", "--iterations"),
"type": int, "type": int,
"action": Slider,
"min_max": (0, 5000000),
"rounding": 20000,
"default": 1000000, "default": 1000000,
"help": "Length of training in iterations"}) "help": "Length of training in iterations."})
argument_list.append({"opts": ("-g", "--gpus"), argument_list.append({"opts": ("-g", "--gpus"),
"type": int, "type": int,
"action": Slider,
"min_max": (1, 10),
"rounding": 1,
"default": 1, "default": 1,
"help": "Number of GPUs to use for training"}) "help": "Number of GPUs to use for training"})
argument_list.append({"opts": ("-ps", "--preview-scale"),
"type": int,
"action": Slider,
"dest": "preview_scale",
"min_max": (25, 200),
"rounding": 25,
"default": 100,
"help": "Percentage amount to scale the preview by."})
argument_list.append({"opts": ("-p", "--preview"), argument_list.append({"opts": ("-p", "--preview"),
"action": "store_true", "action": "store_true",
"dest": "preview", "dest": "preview",
@ -724,20 +789,39 @@ class TrainArgs(FaceSwapArgs):
"default": False, "default": False,
"help": "Writes the training result to a file " "help": "Writes the training result to a file "
"even on preview mode"}) "even on preview mode"})
argument_list.append({"opts": ("-pl", "--use-perceptual-loss"),
"action": "store_true",
"dest": "perceptual_loss",
"default": False,
"help": "Use perceptual loss while training"})
argument_list.append({"opts": ("-ag", "--allow-growth"), argument_list.append({"opts": ("-ag", "--allow-growth"),
"action": "store_true", "action": "store_true",
"dest": "allow_growth", "dest": "allow_growth",
"default": False, "default": False,
"help": "Sets allow_growth option of Tensorflow " "help": "Sets allow_growth option of Tensorflow "
"to spare memory on some configs"}) "to spare memory on some configs"})
argument_list.append({"opts": ("-nl", "--no-logs"),
"action": "store_true",
"dest": "no_logs",
"default": False,
"help": "Disables TensorBoard logging. NB: Disabling logs means "
"that you will not be able to use the graph or analysis "
"for this session in the GUI."})
argument_list.append({"opts": ("-wl", "--warp-to-landmarks"),
"action": "store_true",
"dest": "warp_to_landmarks",
"default": False,
"help": "Warps training faces to closely matched Landmarks from the "
"opposite face-set rather than randomly warping the face. "
"This is the 'dfaker' way of doing warping. Alignments "
"files for both sets of faces must be provided if using "
"this option."})
argument_list.append({"opts": ("-nf", "--no-flip"),
"action": "store_true",
"dest": "no_flip",
"default": False,
"help": "To effectively learn, a random set of images are flipped "
"horizontally. Sometimes it is desirable for this not to "
"occur. Generally this should be left off except for "
"during 'fit training'."})
argument_list.append({"opts": ("-tia", "--timelapse-input-A"), argument_list.append({"opts": ("-tia", "--timelapse-input-A"),
"action": DirFullPaths, "action": DirFullPaths,
"dest": "timelapse_input_A", "dest": "timelapse_input_a",
"default": None, "default": None,
"help": "For if you want a timelapse: " "help": "For if you want a timelapse: "
"The input folder for the timelapse. " "The input folder for the timelapse. "
@ -748,7 +832,7 @@ class TrainArgs(FaceSwapArgs):
"--timelapse-input-B parameter."}) "--timelapse-input-B parameter."})
argument_list.append({"opts": ("-tib", "--timelapse-input-B"), argument_list.append({"opts": ("-tib", "--timelapse-input-B"),
"action": DirFullPaths, "action": DirFullPaths,
"dest": "timelapse_input_B", "dest": "timelapse_input_b",
"default": None, "default": None,
"help": "For if you want a timelapse: " "help": "For if you want a timelapse: "
"The input folder for the timelapse. " "The input folder for the timelapse. "
@ -765,13 +849,6 @@ class TrainArgs(FaceSwapArgs):
"If the input folders are supplied but " "If the input folders are supplied but "
"no output folder, it will default to " "no output folder, it will default to "
"your model folder /timelapse/"}) "your model folder /timelapse/"})
# This is a hidden argument to indicate that the GUI is being used,
# so the preview window should be redirected Accordingly
argument_list.append({"opts": ("-gui", "--gui"),
"action": "store_true",
"dest": "redirect_gui",
"default": False,
"help": argparse.SUPPRESS})
return argument_list return argument_list
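
The option dictionaries above pass extra keys (`min_max`, `rounding`) alongside a custom `Slider` action, and use `str.lower` as a `type` so that choices match case-insensitively. A minimal sketch of how such dictionaries could be fed to `argparse` follows. The `Slider` class and `build_parser` helper here are hypothetical stand-ins written for illustration, not the project's implementation, and the trainer choices are made-up values.

```python
import argparse


class Slider(argparse.Action):
    """Hypothetical stand-in for a GUI-aware slider action.

    Accepts the extra ``min_max`` and ``rounding`` keys and otherwise
    behaves like a plain "store" action.
    """

    def __init__(self, option_strings, dest, min_max=None, rounding=None, **kwargs):
        super().__init__(option_strings, dest, **kwargs)
        self.min_max = min_max      # (minimum, maximum) hint for a GUI slider
        self.rounding = rounding    # decimal places or step interval hint

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)


def build_parser(argument_list):
    """Turn a list of option dictionaries into an ArgumentParser."""
    parser = argparse.ArgumentParser()
    for option in argument_list:
        # Everything except "opts" is forwarded as keyword arguments
        kwargs = {key: val for key, val in option.items() if key != "opts"}
        parser.add_argument(*option["opts"], **kwargs)
    return parser


ARGUMENTS = [
    {"opts": ("-bs", "--batch-size"),
     "type": int,
     "action": Slider,
     "min_max": (2, 256),
     "rounding": 2,
     "dest": "batch_size",
     "default": 64,
     "help": "Batch size"},
    {"opts": ("-t", "--trainer"),
     "type": str.lower,                    # lower-cased before the choices check
     "choices": ["original", "dfaker"],    # illustrative values only
     "default": "original",
     "help": "Trainer to use"},
]

parser = build_parser(ARGUMENTS)
args = parser.parse_args(["-bs", "32", "-t", "DFaker"])
```

Because `argparse` applies `type` before validating `choices`, the mixed-case `DFaker` is accepted and stored in lower case. The slider limits in this sketch are only carried as metadata and are not enforced on the command line.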

lib/config.py Normal file

@ -0,0 +1,301 @@
#!/usr/bin/env python3
""" Default configurations for faceswap
Extends configparser functionality
by checking for default config updates
and returning data in its correct format """
import logging
import os
import sys
from collections import OrderedDict
from configparser import ConfigParser
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class FaceswapConfig():
""" Config Items """
def __init__(self, section):
""" Init Configuration """
logger.debug("Initializing: %s", self.__class__.__name__)
self.configfile = self.get_config_file()
self.config = ConfigParser(allow_no_value=True)
self.defaults = OrderedDict()
self.config.optionxform = str
self.section = section
self.set_defaults()
self.handle_config()
logger.debug("Initialized: %s", self.__class__.__name__)
def set_defaults(self):
""" Override for plugin specific config defaults
Should be a series of self.add_section() and self.add_item() calls
e.g:
section = "sect_1"
self.add_section(title=section,
info="Section 1 Information")
self.add_item(section=section,
title="option_1",
datatype=bool,
default=False,
info="sect_1 option_1 information")
"""
raise NotImplementedError
@property
def config_dict(self):
""" Collate global options and requested section into a dictionary
with the correct datatypes """
conf = dict()
for sect in ("global", self.section):
if sect not in self.config.sections():
continue
for key in self.config[sect]:
if key.startswith(("#", "\n")): # Skip comments
continue
conf[key] = self.get(sect, key)
return conf
def get(self, section, option):
""" Return a config item in its correct format """
logger.debug("Getting config item: (section: '%s', option: '%s')", section, option)
datatype = self.defaults[section][option]["type"]
if datatype == bool:
func = self.config.getboolean
elif datatype == int:
func = self.config.getint
elif datatype == float:
func = self.config.getfloat
else:
func = self.config.get
retval = func(section, option)
if isinstance(retval, str) and retval.lower() == "none":
retval = None
logger.debug("Returning item: (type: %s, value: %s)", datatype, retval)
return retval
def get_config_file(self):
""" Return the config file from the calling folder """
dirname = os.path.dirname(sys.modules[self.__module__].__file__)
folder, fname = os.path.split(dirname)
retval = os.path.join(os.path.dirname(folder), "config", "{}.ini".format(fname))
logger.debug("Config File location: '%s'", retval)
return retval
def add_section(self, title=None, info=None):
""" Add a default section to config file """
logger.debug("Add section: (title: '%s', info: '%s')", title, info)
if None in (title, info):
raise ValueError("Default config sections must have a title and "
"information text")
self.defaults[title] = OrderedDict()
self.defaults[title]["helptext"] = info
def add_item(self, section=None, title=None, datatype=str,
default=None, info=None, rounding=None, min_max=None, choices=None):
""" Add a default item to a config section
For int or float values, rounding and min_max must be set
This is for the slider in the GUI. The min/max values are not enforced:
rounding: sets the decimal places for floats or the step interval for ints.
min_max: tuple of min and max accepted values
For str values choices can be set to validate input and create a combo box
in the GUI
"""
logger.debug("Add item: (section: '%s', title: '%s', datatype: '%s', default: '%s', "
"info: '%s', rounding: '%s', min_max: %s, choices: %s)",
section, title, datatype, default, info, rounding, min_max, choices)
choices = list() if not choices else choices
if None in (section, title, default, info):
raise ValueError("Default config items must have a section, "
"title, default and "
"information text")
if not self.defaults.get(section, None):
raise ValueError("Section does not exist: {}".format(section))
if datatype not in (str, bool, float, int):
raise ValueError("'datatype' must be one of str, bool, float or "
"int: {} - {}".format(section, title))
if datatype in (float, int) and (rounding is None or min_max is None):
raise ValueError("'rounding' and 'min_max' must be set for numerical options")
if not isinstance(choices, (list, tuple)):
raise ValueError("'choices' must be a list or tuple")
self.defaults[section][title] = {"default": default,
"helptext": info,
"type": datatype,
"rounding": rounding,
"min_max": min_max,
"choices": choices}
def check_exists(self):
""" Check that a config file exists """
if not os.path.isfile(self.configfile):
logger.debug("Config file does not exist: '%s'", self.configfile)
return False
logger.debug("Config file exists: '%s'", self.configfile)
return True
def create_default(self):
""" Generate a default config if it does not exist """
logger.debug("Creating default Config")
for section, items in self.defaults.items():
logger.debug("Adding section: '%s'", section)
self.insert_config_section(section, items["helptext"])
for item, opt in items.items():
logger.debug("Adding option: (item: '%s', opt: '%s')", item, opt)
if item == "helptext":
continue
self.insert_config_item(section,
item,
opt["default"],
opt)
self.save_config()
def insert_config_section(self, section, helptext, config=None):
""" Insert a section into the config """
logger.debug("Inserting section: (section: '%s', helptext: '%s', config: '%s')",
section, helptext, config)
config = self.config if config is None else config
helptext = self.format_help(helptext, is_section=True)
config.add_section(section)
config.set(section, helptext)
logger.debug("Inserted section: '%s'", section)
def insert_config_item(self, section, item, default, option,
config=None):
""" Insert an item into a config section """
logger.debug("Inserting item: (section: '%s', item: '%s', default: '%s', helptext: '%s', "
"config: '%s')", section, item, default, option["helptext"], config)
config = self.config if config is None else config
helptext = option["helptext"]
helptext += self.set_helptext_choices(option)
helptext += "\n[Default: {}]".format(default)
helptext = self.format_help(helptext, is_section=False)
config.set(section, helptext)
config.set(section, item, str(default))
logger.debug("Inserted item: '%s'", item)
@staticmethod
def set_helptext_choices(option):
""" Set the helptext choices """
choices = ""
if option["choices"]:
choices = "\nChoose from: {}".format(option["choices"])
elif option["type"] == bool:
choices = "\nChoose from: True, False"
elif option["type"] == int:
cmin, cmax = option["min_max"]
choices = "\nSelect an integer between {} and {}".format(cmin, cmax)
elif option["type"] == float:
cmin, cmax = option["min_max"]
choices = "\nSelect a decimal number between {} and {}".format(cmin, cmax)
return choices
@staticmethod
def format_help(helptext, is_section=False):
""" Format comments for default ini file """
logger.debug("Formatting help: (helptext: '%s', is_section: '%s')", helptext, is_section)
helptext = '# {}'.format(helptext.replace("\n", "\n# "))
if is_section:
helptext = helptext.upper()
else:
helptext = "\n{}".format(helptext)
logger.debug("formatted help: '%s'", helptext)
return helptext
def load_config(self):
""" Load values from config """
logger.info("Loading config: '%s'", self.configfile)
self.config.read(self.configfile)
def save_config(self):
""" Save a config file """
logger.info("Updating config at: '%s'", self.configfile)
f_cfgfile = open(self.configfile, "w")
self.config.write(f_cfgfile)
f_cfgfile.close()
def validate_config(self):
""" Check for options in default config against saved config
and add/remove as appropriate """
logger.debug("Validating config")
if self.check_config_change():
self.add_new_config_items()
self.check_config_choices()
logger.debug("Validated config")
def add_new_config_items(self):
""" Add new items to the config file """
logger.debug("Updating config")
new_config = ConfigParser(allow_no_value=True)
for section, items in self.defaults.items():
self.insert_config_section(section, items["helptext"], new_config)
for item, opt in items.items():
if item == "helptext":
continue
if section not in self.config.sections():
logger.debug("Adding new config section: '%s'", section)
opt_value = opt["default"]
else:
opt_value = self.config[section].get(item, opt["default"])
self.insert_config_item(section,
item,
opt_value,
opt,
new_config)
self.config = new_config
self.config.optionxform = str
self.save_config()
logger.debug("Updated config")
def check_config_choices(self):
""" Check that config items are valid choices """
logger.debug("Checking config choices")
for section, items in self.defaults.items():
for item, opt in items.items():
if item == "helptext" or not opt["choices"]:
continue
opt_value = self.config.get(section, item)
if opt_value.lower() == "none" and any(choice.lower() == "none"
for choice in opt["choices"]):
continue
if opt_value not in opt["choices"]:
default = str(opt["default"])
logger.warning("'%s' is not a valid config choice for '%s': '%s'. Defaulting "
"to: '%s'", opt_value, section, item, default)
self.config.set(section, item, default)
logger.debug("Checked config choices")
def check_config_change(self):
""" Check whether new default items have been added or removed
from the config file compared to saved version """
if set(self.config.sections()) != set(self.defaults.keys()):
logger.debug("Default config has new section(s)")
return True
for section, items in self.defaults.items():
opts = [opt for opt in items.keys() if opt != "helptext"]
exists = [opt for opt in self.config[section].keys()
if not opt.startswith(("# ", "\n# "))]
if set(exists) != set(opts):
logger.debug("Default config has new item(s)")
return True
logger.debug("Default config has not changed")
return False
def handle_config(self):
""" Handle the config """
logger.debug("Handling config")
if not self.check_exists():
self.create_default()
self.load_config()
self.validate_config()
logger.debug("Handled config")
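
The `get` method above picks a `ConfigParser` getter based on the datatype recorded for each default, then maps the literal string "none" to `None`. A minimal standalone sketch of that dispatch is shown below. The option names, the `DEFAULTS` table and the `get_typed` helper are assumptions made up for the example, not items from the project's config.

```python
from configparser import ConfigParser

# Hypothetical defaults, in the same shape that add_item() stores them
DEFAULTS = {"global": {"coverage": {"type": float},
                       "subpixel": {"type": bool},
                       "input_size": {"type": int},
                       "mask_type": {"type": str}}}

INI_TEXT = """
[global]
# Comment lines are skipped when the file is read back
coverage = 62.5
subpixel = False
input_size = 128
mask_type = none
"""


def get_typed(config, defaults, section, option):
    """Return an option converted to the datatype recorded in defaults."""
    datatype = defaults[section][option]["type"]
    getter = {bool: config.getboolean,
              int: config.getint,
              float: config.getfloat}.get(datatype, config.get)
    value = getter(section, option)
    # A string value of "none" (any case) stands for a missing value
    if isinstance(value, str) and value.lower() == "none":
        value = None
    return value


config = ConfigParser(allow_no_value=True)
config.optionxform = str  # keep option names case sensitive
config.read_string(INI_TEXT)
values = {opt: get_typed(config, DEFAULTS, "global", opt)
          for opt in DEFAULTS["global"]}
```

Keeping the datatype next to the default means a caller never has to remember which getter applies to which option.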


@ -3,7 +3,7 @@
import logging import logging
from dlib import rectangle as d_rectangle # pylint: disable=no-name-in-module from dlib import rectangle as d_rectangle # pylint: disable=no-name-in-module
from lib.aligner import Extract as AlignerExtract, get_align_mat from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -12,14 +12,13 @@ class DetectedFace():
""" Detected face and landmark information """ """ Detected face and landmark information """
def __init__( # pylint: disable=invalid-name def __init__( # pylint: disable=invalid-name
self, image=None, x=None, w=None, y=None, h=None, self, image=None, x=None, w=None, y=None, h=None,
frame_dims=None, landmarksXY=None): landmarksXY=None):
logger.trace("Initializing %s", self.__class__.__name__) logger.trace("Initializing %s", self.__class__.__name__)
self.image = image self.image = image
self.x = x self.x = x
self.w = w self.w = w
self.y = y self.y = y
self.h = h self.h = h
self.frame_dims = frame_dims
self.landmarksXY = landmarksXY self.landmarksXY = landmarksXY
self.hash = None # Hash must be set when the file is saved due to image compression self.hash = None # Hash must be set when the file is saved due to image compression
@ -63,17 +62,12 @@ class DetectedFace():
self.x: self.x + self.w] self.x: self.x + self.w]
def to_alignment(self): def to_alignment(self):
""" Convert a detected face to alignment dict """ Convert a detected face to alignment dict """
NB: frame_dims should be the height and width
of the original frame. """
alignment = dict() alignment = dict()
alignment["x"] = self.x alignment["x"] = self.x
alignment["w"] = self.w alignment["w"] = self.w
alignment["y"] = self.y alignment["y"] = self.y
alignment["h"] = self.h alignment["h"] = self.h
alignment["frame_dims"] = self.frame_dims
alignment["landmarksXY"] = self.landmarksXY alignment["landmarksXY"] = self.landmarksXY
alignment["hash"] = self.hash alignment["hash"] = self.hash
logger.trace("Returning: %s", alignment) logger.trace("Returning: %s", alignment)
@ -87,23 +81,22 @@ class DetectedFace():
self.w = alignment["w"] self.w = alignment["w"]
self.y = alignment["y"] self.y = alignment["y"]
self.h = alignment["h"] self.h = alignment["h"]
self.frame_dims = alignment["frame_dims"]
self.landmarksXY = alignment["landmarksXY"] self.landmarksXY = alignment["landmarksXY"]
# Manual tool does not know the final hash so default to None # Manual tool does not know the final hash so default to None
self.hash = alignment.get("hash", None) self.hash = alignment.get("hash", None)
if image is not None and image.any(): if image is not None and image.any():
self.image_to_face(image) self.image_to_face(image)
logger.trace("Created from alignment: (x: %s, w: %s, y: %s. h: %s, " logger.trace("Created from alignment: (x: %s, w: %s, y: %s. h: %s, "
"frame_dims: %s, landmarks: %s)", "landmarks: %s)",
self.x, self.w, self.y, self.h, self.frame_dims, self.landmarksXY) self.x, self.w, self.y, self.h, self.landmarksXY)
# <<< Aligned Face methods and properties >>> # # <<< Aligned Face methods and properties >>> #
def load_aligned(self, image, size=256, padding=48, align_eyes=False): def load_aligned(self, image, size=256, align_eyes=False):
""" No need to load aligned information for all uses of this """ No need to load aligned information for all uses of this
class, so only call this to load the information for easy class, so only call this to load the information for easy
reference to aligned properties for this face """ reference to aligned properties for this face """
logger.trace("Loading aligned face: (size: %s, padding: %s, align_eyes: %s)", logger.trace("Loading aligned face: (size: %s, align_eyes: %s)", size, align_eyes)
size, padding, align_eyes) padding = int(size * 0.1875)
self.aligned["size"] = size self.aligned["size"] = size
self.aligned["padding"] = padding self.aligned["padding"] = padding
self.aligned["align_eyes"] = align_eyes self.aligned["align_eyes"] = align_eyes
@ -153,3 +146,8 @@ class DetectedFace():
self.aligned["padding"]) self.aligned["padding"])
logger.trace("Returning: %s", mat) logger.trace("Returning: %s", mat)
return mat return mat
@property
def adjusted_interpolators(self):
""" Return the interpolator and reverse interpolator for the adjusted matrix """
return get_matrix_scaling(self.adjusted_matrix)
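
The `load_aligned` hunk above drops the `padding` argument and derives it from `size` instead. A minimal sketch of that arithmetic, assuming nothing beyond the `int(size * 0.1875)` expression shown in the diff (the helper name `derived_padding` is hypothetical and not part of the commit):

```python
def derived_padding(size):
    """Padding derived from the aligned face size, using the diff's expression."""
    return int(size * 0.1875)


# Print the padding for a few common aligned-face sizes.
for size in (64, 128, 256):
    print(size, derived_padding(size))
```

For `size=256` this gives 48, the same value as the old default `padding=48`, so callers that relied on the defaults should see no change.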


@ -1,7 +1,9 @@
from lib.gui.command import CommandNotebook from lib.gui.command import CommandNotebook
from lib.gui.display import DisplayNotebook from lib.gui.display import DisplayNotebook
from lib.gui.options import CliOptions, Config from lib.gui.options import CliOptions
from lib.gui.stats import CurrentSession from lib.gui.menu import MainMenuBar
from lib.gui.popup_configure import popup_config
from lib.gui.stats import Session
from lib.gui.statusbar import StatusBar from lib.gui.statusbar import StatusBar
from lib.gui.utils import ConsoleOut, Images from lib.gui.utils import ConsoleOut, get_config, get_images, initialize_config, initialize_images
from lib.gui.wrapper import ProcessWrapper from lib.gui.wrapper import ProcessWrapper


@ -5,44 +5,42 @@ import logging
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from .options import Config
from .tooltip import Tooltip from .tooltip import Tooltip
from .utils import ContextMenu, Images, FileHandler from .utils import ContextMenu, FileHandler, get_images, get_config, set_slider_rounding
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class CommandNotebook(ttk.Notebook): class CommandNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors
""" Frame to hold each individual tab of the command notebook """ """ Frame to hold each individual tab of the command notebook """
def __init__(self, parent, cli_options, tk_vars, scaling_factor): def __init__(self, parent):
logger.debug("Initializing %s: (parent: %s, cli_options: %s, tk_vars: %s, " logger.debug("Initializing %s: (parent: %s)", self.__class__.__name__, parent)
"scaling_factor: %s", self.__class__.__name__, parent, cli_options, scaling_factor = get_config().scaling_factor
tk_vars, scaling_factor)
width = int(420 * scaling_factor) width = int(420 * scaling_factor)
height = int(500 * scaling_factor) height = int(500 * scaling_factor)
ttk.Notebook.__init__(self, parent, width=width, height=height) ttk.Notebook.__init__(self, parent, width=width, height=height)
parent.add(self) parent.add(self)
self.cli_opts = cli_options
self.tk_vars = tk_vars
self.actionbtns = dict() self.actionbtns = dict()
self.set_running_task_trace() self.set_running_task_trace()
self.build_tabs() self.build_tabs()
get_config().command_notebook = self
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def set_running_task_trace(self): def set_running_task_trace(self):
""" Set trigger action for the running task """ Set trigger action for the running task
to change the action buttons text and command """ to change the action buttons text and command """
logger.debug("Set running trace") logger.debug("Set running trace")
self.tk_vars["runningtask"].trace("w", self.change_action_button) tk_vars = get_config().tk_vars
tk_vars["runningtask"].trace("w", self.change_action_button)
def build_tabs(self): def build_tabs(self):
""" Build the tabs for the relevant command """ """ Build the tabs for the relevant command """
logger.debug("Build Tabs") logger.debug("Build Tabs")
for category in self.cli_opts.categories: cli_opts = get_config().cli_opts
cmdlist = self.cli_opts.commands[category] for category in cli_opts.categories:
cmdlist = cli_opts.commands[category]
for command in cmdlist: for command in cmdlist:
title = command.title() title = command.title()
commandtab = CommandTab(self, category, command) commandtab = CommandTab(self, category, command)
@ -52,9 +50,11 @@ class CommandNotebook(ttk.Notebook):
def change_action_button(self, *args): def change_action_button(self, *args):
""" Change the action button to relevant control """ """ Change the action button to relevant control """
logger.debug("Update Action Buttons: (args: %s", args) logger.debug("Update Action Buttons: (args: %s", args)
tk_vars = get_config().tk_vars
for cmd in self.actionbtns.keys(): for cmd in self.actionbtns.keys():
btnact = self.actionbtns[cmd] btnact = self.actionbtns[cmd]
if self.tk_vars["runningtask"].get(): if tk_vars["runningtask"].get():
ttl = "Terminate" ttl = "Terminate"
hlp = "Exit the running process" hlp = "Exit the running process"
else: else:
@ -65,7 +65,7 @@ class CommandNotebook(ttk.Notebook):
Tooltip(btnact, text=hlp, wraplength=200) Tooltip(btnact, text=hlp, wraplength=200)
class CommandTab(ttk.Frame): class CommandTab(ttk.Frame): # pylint: disable=too-many-ancestors
""" Frame to hold each individual tab of the command notebook """ """ Frame to hold each individual tab of the command notebook """
def __init__(self, parent, category, command): def __init__(self, parent, category, command):
@ -74,9 +74,7 @@ class CommandTab(ttk.Frame):
ttk.Frame.__init__(self, parent) ttk.Frame.__init__(self, parent)
self.category = category self.category = category
self.cli_opts = parent.cli_opts
self.actionbtns = parent.actionbtns self.actionbtns = parent.actionbtns
self.tk_vars = parent.tk_vars
self.command = command self.command = command
self.build_tab() self.build_tab()
@ -100,7 +98,7 @@ class CommandTab(ttk.Frame):
logger.debug("Added frame separator") logger.debug("Added frame separator")
class OptionsFrame(ttk.Frame): class OptionsFrame(ttk.Frame): # pylint: disable=too-many-ancestors
""" Options Frame - Holds the Options for each command """ """ Options Frame - Holds the Options for each command """
def __init__(self, parent): def __init__(self, parent):
@ -108,7 +106,6 @@ class OptionsFrame(ttk.Frame):
ttk.Frame.__init__(self, parent) ttk.Frame.__init__(self, parent)
self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
self.opts = parent.cli_opts
self.command = parent.command self.command = parent.command
self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
@ -121,7 +118,8 @@ class OptionsFrame(ttk.Frame):
self.chkbtns = self.checkbuttons_frame() self.chkbtns = self.checkbuttons_frame()
self.build_frame() self.build_frame()
self.opts.set_context_option(self.command) cli_opts = get_config().cli_opts
cli_opts.set_context_option(self.command)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def checkbuttons_frame(self): def checkbuttons_frame(self):
@ -150,7 +148,8 @@ class OptionsFrame(ttk.Frame):
self.add_scrollbar() self.add_scrollbar()
self.canvas.bind("<Configure>", self.resize_frame) self.canvas.bind("<Configure>", self.resize_frame)
for option in self.opts.gen_command_options(self.command): cli_opts = get_config().cli_opts
for option in cli_opts.gen_command_options(self.command):
optioncontrol = OptionControl(self.command, optioncontrol = OptionControl(self.command,
option, option,
self.optsframe, self.optsframe,
@ -170,7 +169,7 @@ class OptionsFrame(ttk.Frame):
self.optsframe.bind("<Configure>", self.update_scrollbar) self.optsframe.bind("<Configure>", self.update_scrollbar)
logger.debug("Added Options Scrollbar") logger.debug("Added Options Scrollbar")
def update_scrollbar(self, event): def update_scrollbar(self, event): # pylint: disable=unused-argument
""" Update the options frame scrollbar """ """ Update the options frame scrollbar """
self.canvas.configure(scrollregion=self.canvas.bbox("all")) self.canvas.configure(scrollregion=self.canvas.bbox("all"))
@ -207,6 +206,7 @@ class OptionControl():
if ctl == ttk.Checkbutton: if ctl == ttk.Checkbutton:
dflt = self.option.get("default", False) dflt = self.option.get("default", False)
choices = self.option["choices"] if ctl == ttk.Combobox else None choices = self.option["choices"] if ctl == ttk.Combobox else None
min_max = self.option["min_max"] if ctl == ttk.Scale else None
ctlframe = self.build_one_control_frame() ctlframe = self.build_one_control_frame()
@ -217,6 +217,7 @@ class OptionControl():
self.option["value"] = self.build_one_control(ctlframe, self.option["value"] = self.build_one_control(ctlframe,
ctlvars, ctlvars,
choices, choices,
min_max,
sysbrowser) sysbrowser)
logger.debug("Built option control") logger.debug("Built option control")
@ -228,6 +229,7 @@ class OptionControl():
ctlhelp = ctlhelp[2:].replace("\n\t", " ").replace("\n'", "\n\n'") ctlhelp = ctlhelp[2:].replace("\n\t", " ").replace("\n'", "\n\n'")
else: else:
ctlhelp = " ".join(ctlhelp.split()) ctlhelp = " ".join(ctlhelp.split())
ctlhelp = ctlhelp.replace("%%", "%")
ctlhelp = ". ".join(i.capitalize() for i in ctlhelp.split(". ")) ctlhelp = ". ".join(i.capitalize() for i in ctlhelp.split(". "))
ctlhelp = ctltitle + " - " + ctlhelp ctlhelp = ctltitle + " - " + ctlhelp
logger.debug("Formatted control help: (title: '%s', help: '%s'", ctltitle, ctlhelp) logger.debug("Formatted control help: (title: '%s', help: '%s'", ctltitle, ctlhelp)
@ -249,15 +251,14 @@ class OptionControl():
lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
logger.debug("Built control label: '%s'", control_title) logger.debug("Built control label: '%s'", control_title)
def build_one_control(self, frame, controlvars, choices, sysbrowser): def build_one_control(self, frame, controlvars, choices, min_max, sysbrowser):
""" Build and place the option controls """ """ Build and place the option controls """
logger.debug("Build control: (controlvars: %s, choices: %s, sysbrowser: %s", logger.debug("Build control: (controlvars: %s, choices: %s, min_max: %s, sysbrowser: %s",
controlvars, choices, sysbrowser) controlvars, choices, min_max, sysbrowser)
control, control_title, default, helptext = controlvars control, control_title, default, helptext = controlvars
default = default if default is not None else "" default = default if default is not None else ""
var = tk.BooleanVar( var = tk.BooleanVar(frame) if control == ttk.Checkbutton else tk.StringVar(frame)
frame) if control == ttk.Checkbutton else tk.StringVar(frame)
var.set(default) var.set(default)
if sysbrowser: if sysbrowser:
@ -268,6 +269,12 @@ class OptionControl():
control_title, control_title,
var, var,
helptext) helptext)
elif control == ttk.Scale:
self.slider_control(control,
frame,
var,
min_max,
helptext)
else: else:
self.control_to_optionsframe(control, self.control_to_optionsframe(control,
frame, frame,
@ -292,6 +299,29 @@ class OptionControl():
Tooltip(ctl, text=helptext, wraplength=200) Tooltip(ctl, text=helptext, wraplength=200)
logger.debug("Added control checkframe: '%s'", control_title) logger.debug("Added control checkframe: '%s'", control_title)
def slider_control(self, control, frame, tk_var, min_max, helptext):
""" A slider control with corresponding Entry box """
logger.debug("Add slider control to Options Frame: %s", control)
d_type = self.option.get("type", float)
rnd = self.option.get("rounding", 2) if d_type == float else self.option.get("rounding", 1)
tbox = ttk.Entry(frame, width=8, textvariable=tk_var, justify=tk.RIGHT)
tbox.pack(padx=(0, 5), side=tk.RIGHT)
ctl = control(
frame,
variable=tk_var,
command=lambda val, var=tk_var, dt=d_type, rn=rnd, mm=min_max:
set_slider_rounding(val, var, dt, rn, mm))
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
rc_menu = ContextMenu(ctl)
rc_menu.cm_bind()
ctl["from_"] = min_max[0]
ctl["to"] = min_max[1]
Tooltip(ctl, text=helptext, wraplength=720)
Tooltip(tbox, text=helptext, wraplength=720)
logger.debug("Added slider control to Options Frame: %s", control)
@staticmethod @staticmethod
def control_to_optionsframe(control, frame, var, choices, helptext): def control_to_optionsframe(control, frame, var, choices, helptext):
""" Standard non-check buttons sit in the main options frame """ """ Standard non-check buttons sit in the main options frame """
@ -303,8 +333,7 @@ class OptionControl():
if control == ttk.Combobox: if control == ttk.Combobox:
logger.debug("Adding combo choices: %s", choices) logger.debug("Adding combo choices: %s", choices)
ctl["values"] = [choice for choice in choices] ctl["values"] = [choice for choice in choices]
Tooltip(ctl, text=helptext, wraplength=920)
Tooltip(ctl, text=helptext, wraplength=720)
logger.debug("Added control to Options Frame: %s", control) logger.debug("Added control to Options Frame: %s", control)
def add_browser_buttons(self, frame, sysbrowser, filepath): def add_browser_buttons(self, frame, sysbrowser, filepath):
@ -312,7 +341,7 @@ class OptionControl():
logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'", logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'",
sysbrowser, filepath) sysbrowser, filepath)
for browser in sysbrowser: for browser in sysbrowser:
img = Images().icons[browser] img = get_images().icons[browser]
action = getattr(self, "ask_" + browser) action = getattr(self, "ask_" + browser)
filetypes = self.option.get("filetypes", "default") filetypes = self.option.get("filetypes", "default")
fileopn = ttk.Button(frame, fileopn = ttk.Button(frame,
@ -351,7 +380,7 @@ class OptionControl():
filepath.set(filename) filepath.set(filename)
@staticmethod @staticmethod
def ask_nothing(filepath, filetypes=None): def ask_nothing(filepath, filetypes=None): # pylint: disable=unused-argument
""" Method that does nothing, used for disabling open/save pop up """ """ Method that does nothing, used for disabling open/save pop up """
return return
@ -370,7 +399,7 @@ class OptionControl():
filepath.set(filename) filepath.set(filename)
class ActionFrame(ttk.Frame): class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors
"""Action Frame - Displays action controls for the command tab """ """Action Frame - Displays action controls for the command tab """
def __init__(self, parent): def __init__(self, parent):
@ -382,16 +411,16 @@ class ActionFrame(ttk.Frame):
self.title = self.command.title() self.title = self.command.title()
self.add_action_button(parent.category, self.add_action_button(parent.category,
parent.actionbtns, parent.actionbtns)
parent.tk_vars) self.add_util_buttons()
self.add_util_buttons(parent.cli_opts, parent.tk_vars)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def add_action_button(self, category, actionbtns, tk_vars): def add_action_button(self, category, actionbtns):
""" Add the action buttons for page """ """ Add the action buttons for page """
logger.debug("Add action buttons: '%s'", self.title) logger.debug("Add action buttons: '%s'", self.title)
actframe = ttk.Frame(self) actframe = ttk.Frame(self)
actframe.pack(fill=tk.X, side=tk.LEFT) actframe.pack(fill=tk.X, side=tk.LEFT)
tk_vars = get_config().tk_vars
var_value = "{},{}".format(category, self.command) var_value = "{},{}".format(category, self.command)
@ -415,17 +444,17 @@ class ActionFrame(ttk.Frame):
wraplength=200) wraplength=200)
logger.debug("Added action buttons: '%s'", self.title) logger.debug("Added action buttons: '%s'", self.title)
def add_util_buttons(self, cli_options, tk_vars): def add_util_buttons(self):
""" Add the section utility buttons """ """ Add the section utility buttons """
logger.debug("Add util buttons") logger.debug("Add util buttons")
utlframe = ttk.Frame(self) utlframe = ttk.Frame(self)
utlframe.pack(side=tk.RIGHT) utlframe.pack(side=tk.RIGHT)
config = Config(cli_options, tk_vars) config = get_config()
for utl in ("load", "save", "clear", "reset"): for utl in ("load", "save", "clear", "reset"):
logger.debug("Adding button: '%s'", utl) logger.debug("Adding button: '%s'", utl)
img = Images().icons[utl] img = get_images().icons[utl]
action_cls = config if utl in (("save", "load")) else cli_options action_cls = config if utl in (("save", "load")) else config.cli_opts
action = getattr(action_cls, utl) action = getattr(action_cls, utl)
btnutl = ttk.Button(utlframe, btnutl = ttk.Button(utlframe,
image=img, image=img,
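
The new `slider_control` method in this file passes the slider position, the option's `type`, its `rounding` value and its `min_max` pair to `set_slider_rounding`, whose body is not part of this diff. The sketch below illustrates what such a helper might do. It is not the commit's implementation: the name `snap_slider_value` is hypothetical, it returns the value instead of setting a Tk variable, and the reading of `rounding` (decimal places for floats, step size for ints) is an assumption.

```python
def snap_slider_value(value, d_type, rounding, min_max):
    """Convert a raw slider position into a tidy value of type d_type.

    Assumption: for floats `rounding` is a number of decimal places and
    for ints it is a step size. The diff does not confirm either reading.
    """
    number = float(value)  # Tk typically hands the position over as a string
    low, high = min_max
    if d_type is float:
        snapped = round(number, rounding)
    else:
        # Snap to the nearest multiple of the step, counted from the minimum.
        snapped = low + round((number - low) / rounding) * rounding
    # Keep the result inside the slider's configured range.
    return d_type(min(max(snapped, low), high))


print(snap_slider_value("0.12345", float, 2, (0.0, 1.0)))
print(snap_slider_value("37.6", int, 1, (0, 100)))
print(snap_slider_value("37.6", int, 8, (0, 100)))
```

Binding the slider and the `ttk.Entry` box to the same Tk variable, as the diff does, means a snapped value written by the callback shows up in the text box immediately.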


@ -4,48 +4,56 @@
What is displayed in the Display Frame varies What is displayed in the Display Frame varies
depending on what task is being run """ depending on what task is being run """
import logging
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from .display_analysis import Analysis from .display_analysis import Analysis
from .display_command import GraphDisplay, PreviewExtract, PreviewTrain from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
from .utils import get_config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class DisplayNotebook(ttk.Notebook): class DisplayNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors
""" The display tabs """ """ The display tabs """
def __init__(self, parent, session, tk_vars, scaling_factor): def __init__(self, parent):
logger.debug("Initializing %s", self.__class__.__name__)
ttk.Notebook.__init__(self, parent, width=780) ttk.Notebook.__init__(self, parent, width=780)
parent.add(self) parent.add(self)
tk_vars = get_config().tk_vars
self.wrapper_var = tk_vars["display"] self.wrapper_var = tk_vars["display"]
self.runningtask = tk_vars["runningtask"] self.runningtask = tk_vars["runningtask"]
self.session = session
self.set_wrapper_var_trace() self.set_wrapper_var_trace()
self.add_static_tabs(scaling_factor) self.add_static_tabs()
self.static_tabs = [child for child in self.tabs()] self.static_tabs = [child for child in self.tabs()]
logger.debug("Initialized %s", self.__class__.__name__)
def set_wrapper_var_trace(self): def set_wrapper_var_trace(self):
""" Set the trigger actions for the display vars """ Set the trigger actions for the display vars
when they have been triggered in the Process Wrapper """ when they have been triggered in the Process Wrapper """
logger.debug("Setting wrapper var trace")
self.wrapper_var.trace("w", self.update_displaybook) self.wrapper_var.trace("w", self.update_displaybook)
def add_static_tabs(self, scaling_factor): def add_static_tabs(self):
""" Add tabs that are permanently available """ """ Add tabs that are permanently available """
logger.debug("Adding static tabs")
for tab in ("job queue", "analysis"): for tab in ("job queue", "analysis"):
if tab == "job queue": if tab == "job queue":
continue # Not yet implemented continue # Not yet implemented
if tab == "analysis": if tab == "analysis":
helptext = {"stats": helptext = {"stats":
"Summary statistics for each training session"} "Summary statistics for each training session"}
frame = Analysis(self, tab, helptext, scaling_factor) frame = Analysis(self, tab, helptext)
else: else:
frame = self.add_frame() frame = self.add_frame()
self.add(frame, text=tab.title()) self.add(frame, text=tab.title())
def add_frame(self): def add_frame(self):
""" Add a single frame for holding tab's contents """ """ Add a single frame for holding tab's contents """
logger.debug("Adding frame")
frame = ttk.Frame(self) frame = ttk.Frame(self)
frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5) frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
return frame return frame
@ -58,32 +66,53 @@ class DisplayNotebook(ttk.Notebook):
def extract_tabs(self): def extract_tabs(self):
""" Build the extract tabs """ """ Build the extract tabs """
logger.debug("Build extract tabs")
helptext = ("Updates preview from output every 5 " helptext = ("Updates preview from output every 5 "
"seconds to limit disk contention") "seconds to limit disk contention")
PreviewExtract(self, "preview", helptext, 5000) PreviewExtract(self, "preview", helptext, 5000)
logger.debug("Built extract tabs")
def train_tabs(self): def train_tabs(self):
""" Build the train tabs """ """ Build the train tabs """
logger.debug("Build train tabs")
for tab in ("graph", "preview"): for tab in ("graph", "preview"):
if tab == "graph": if tab == "graph":
helptext = "Graph showing Loss vs Iterations" helptext = "Graph showing Loss vs Iterations"
GraphDisplay(self, "graph", helptext, 5000) GraphDisplay(self, "graph", helptext, 5000)
elif tab == "preview": elif tab == "preview":
helptext = "Training preview. Updated on every save iteration" helptext = "Training preview. Updated on every save iteration"
PreviewTrain(self, "preview", helptext, 5000) PreviewTrain(self, "preview", helptext, 1000)
logger.debug("Built train tabs")
def convert_tabs(self): def convert_tabs(self):
""" Build the convert tabs """ Build the convert tabs
Currently identical to Extract, so just call that """ Currently identical to Extract, so just call that """
logger.debug("Build convert tabs")
self.extract_tabs() self.extract_tabs()
logger.debug("Built convert tabs")
def remove_tabs(self): def remove_tabs(self):
""" Remove all command specific tabs """ """ Remove all command specific tabs """
for child in self.tabs(): for child in self.tabs():
if child not in self.static_tabs: if child in self.static_tabs:
self.forget(child) continue
logger.debug("removing child: %s", child)
child_name = child.split(".")[-1]
child_object = self.children[child_name]
self.destroy_tabs_children(child_object)
self.forget(child)
def update_displaybook(self, *args): @staticmethod
def destroy_tabs_children(tab):
""" Destroy all tabs children
Children must be destroyed as forget only hides display
"""
logger.debug("Destroying children for tab: %s", tab)
for child in tab.winfo_children():
logger.debug("Destroying child: %s", child)
child.destroy()
def update_displaybook(self, *args): # pylint: disable=unused-argument
""" Set the display tabs based on executing task """ """ Set the display tabs based on executing task """
command = self.wrapper_var.get() command = self.wrapper_var.get()
self.remove_tabs() self.remove_tabs()


@ -2,104 +2,138 @@
""" Analysis tab of Display Frame of the Faceswap GUI """ """ Analysis tab of Display Frame of the Faceswap GUI """
import csv import csv
import logging
import os
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from .display_graph import SessionGraph from .display_graph import SessionGraph
from .display_page import DisplayPage from .display_page import DisplayPage
from .stats import Calculations, SavedSessions, SessionsSummary, SessionsTotals from .stats import Calculations, Session
from .tooltip import Tooltip from .tooltip import Tooltip
from .utils import Images, FileHandler from .utils import FileHandler, get_config, get_images
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Analysis(DisplayPage): # pylint: disable=too-many-ancestors class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
""" Session analysis tab """ """ Session analysis tab """
def __init__(self, parent, tabname, helptext, scaling_factor): def __init__(self, parent, tabname, helptext):
DisplayPage.__init__(self, parent, tabname, helptext) logger.debug("Initializing: %s: (parent, %s, tabname: '%s', helptext: '%s')",
self.__class__.__name__, parent, tabname, helptext)
super().__init__(parent, tabname, helptext)
self.summary = None self.summary = None
self.session = None
self.add_options() self.add_options()
self.add_main_frame(scaling_factor) self.add_main_frame()
logger.debug("Initialized: %s", self.__class__.__name__)
def set_vars(self): def set_vars(self):
""" Analysis specific vars """ """ Analysis specific vars """
selected_id = tk.StringVar() selected_id = tk.StringVar()
filename = tk.StringVar() return {"selected_id": selected_id}
return {"selected_id": selected_id,
"filename": filename}
def add_main_frame(self, scaling_factor): def add_main_frame(self):
""" Add the main frame to the subnotebook """ Add the main frame to the subnotebook
to hold stats and session data """ to hold stats and session data """
logger.debug("Adding main frame")
mainframe = self.subnotebook_add_page("stats") mainframe = self.subnotebook_add_page("stats")
self.stats = StatsData(mainframe, self.stats = StatsData(mainframe,
self.vars["filename"],
self.vars["selected_id"], self.vars["selected_id"],
self.helptext["stats"], self.helptext["stats"])
scaling_factor) logger.debug("Added main frame")
def add_options(self): def add_options(self):
""" Add the options bar """ """ Add the options bar """
logger.debug("Adding options")
self.reset_session_info() self.reset_session_info()
options = Options(self) options = Options(self)
options.add_options() options.add_options()
logger.debug("Added options")
def reset_session_info(self): def reset_session_info(self):
""" Reset the session info status to default """ """ Reset the session info status to default """
self.vars["filename"].set(None) logger.debug("Resetting session info")
self.set_info("No session data loaded") self.set_info("No session data loaded")
def load_session(self): def load_session(self):
""" Load previously saved sessions """ """ Load previously saved sessions """
logger.debug("Loading session")
self.clear_session() self.clear_session()
filename = FileHandler("open", "session").retfile fullpath = FileHandler("filename", "state").retfile
if not filename: if not fullpath:
return return
filename = filename.name logger.debug("state_file: '%s'", fullpath)
loaded_data = SavedSessions(filename).sessions model_dir, state_file = os.path.split(fullpath)
msg = filename logger.debug("model_dir: '%s'", model_dir)
if len(filename) > 70: model_name = self.get_model_name(model_dir, state_file)
msg = "...{}".format(filename[-70:]) if not model_name:
self.set_session_summary(loaded_data, msg) return
self.vars["filename"].set(filename) self.session = Session(model_dir=model_dir, model_name=model_name)
self.session.initialize_session(is_training=False)
msg = os.path.split(state_file)[0]
if len(msg) > 70:
msg = "...{}".format(msg[-70:])
self.set_session_summary(msg)
@staticmethod
def get_model_name(model_dir, state_file):
""" Get the state file from the model directory """
logger.debug("Getting model name")
model_name = state_file.replace("_state.json", "")
logger.debug("model_name: %s", model_name)
logs_dir = os.path.join(model_dir, "{}_logs".format(model_name))
if not os.path.isdir(logs_dir):
logger.warning("No logs folder found in folder: '%s'", logs_dir)
return None
return model_name
def reset_session(self): def reset_session(self):
""" Load previously saved sessions """ """ Reset currently training sessions """
logger.debug("Reset current training session")
self.clear_session() self.clear_session()
if self.session.stats["iterations"] == 0: session = get_config().session
if not session.initialized:
logger.debug("Training not running")
print("Training not running") print("Training not running")
return return
loaded_data = self.session.historical.sessions
msg = "Currently running training session" msg = "Currently running training session"
self.set_session_summary(loaded_data, msg) self.session = session
self.vars["filename"].set("Currently running training session") self.set_session_summary(msg)
def set_session_summary(self, data, message): def set_session_summary(self, message):
""" Set the summary data and info message """ """ Set the summary data and info message """
self.summary = SessionsSummary(data).summary logger.debug("Setting session summary. (message: '%s')", message)
self.summary = self.session.full_summary
self.set_info("Session: {}".format(message)) self.set_info("Session: {}".format(message))
self.stats.loaded_data = data self.stats.session = self.session
self.stats.tree_insert_data(self.summary) self.stats.tree_insert_data(self.summary)
def clear_session(self): def clear_session(self):
""" Clear sessions stats """ """ Clear sessions stats """
logger.debug("Clearing session")
self.summary = None self.summary = None
self.stats.loaded_data = None self.stats.session = None
self.stats.tree_clear() self.stats.tree_clear()
self.reset_session_info() self.reset_session_info()
def save_session(self): def save_session(self):
""" Save sessions stats to csv """ """ Save sessions stats to csv """
logger.debug("Saving session")
if not self.summary: if not self.summary:
logger.debug("No summary data loaded. Nothing to save")
print("No summary data loaded. Nothing to save") print("No summary data loaded. Nothing to save")
return return
savefile = FileHandler("save", "csv").retfile savefile = FileHandler("save", "csv").retfile
if not savefile: if not savefile:
logger.debug("No save file. Returning")
return return
write_dicts = [val for val in self.summary.values()] write_dicts = [val for val in self.summary.values()]
fieldnames = sorted(key for key in write_dicts[0].keys()) fieldnames = sorted(key for key in write_dicts[0].keys())
logger.debug("Saving to: '%s'", savefile)
with savefile as outfile: with savefile as outfile:
csvout = csv.DictWriter(outfile, fieldnames) csvout = csv.DictWriter(outfile, fieldnames)
csvout.writeheader() csvout.writeheader()
@ -110,8 +144,10 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
class Options(): class Options():
""" Options bar of Analysis tab """ """ Options bar of Analysis tab """
def __init__(self, parent): def __init__(self, parent):
logger.debug("Initializing: %s", self.__class__.__name__)
self.optsframe = parent.optsframe self.optsframe = parent.optsframe
self.parent = parent self.parent = parent
logger.debug("Initialized: %s", self.__class__.__name__)
def add_options(self): def add_options(self):
""" Add the display tab options """ """ Add the display tab options """
@ -120,9 +156,10 @@ class Options():
def add_buttons(self): def add_buttons(self):
""" Add the option buttons """ """ Add the option buttons """
for btntype in ("reset", "clear", "save", "load"): for btntype in ("reset", "clear", "save", "load"):
logger.debug("Adding button: '%s'", btntype)
cmd = getattr(self.parent, "{}_session".format(btntype)) cmd = getattr(self.parent, "{}_session".format(btntype))
btn = ttk.Button(self.optsframe, btn = ttk.Button(self.optsframe,
image=Images().icons[btntype], image=get_images().icons[btntype],
command=cmd) command=cmd)
btn.pack(padx=2, side=tk.RIGHT) btn.pack(padx=2, side=tk.RIGHT)
hlp = self.set_help(btntype) hlp = self.set_help(btntype)
@ -131,6 +168,7 @@ class Options():
@staticmethod @staticmethod
def set_help(btntype): def set_help(btntype):
""" Set the helptext for option buttons """ """ Set the helptext for option buttons """
logger.debug("Setting help")
hlp = "" hlp = ""
if btntype == "reset": if btntype == "reset":
hlp = "Load/Refresh stats for the currently training session" hlp = "Load/Refresh stats for the currently training session"
@ -145,25 +183,15 @@ class Options():
class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
""" Stats frame of analysis tab """ """ Stats frame of analysis tab """
def __init__(self, def __init__(self, parent, selected_id, helptext):
parent, logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')",
filename, self.__class__.__name__, parent, selected_id, helptext)
selected_id, super().__init__(parent)
helptext, self.pack(side=tk.TOP, padx=5, pady=5, expand=True, fill=tk.X, anchor=tk.N)
scaling_factor):
ttk.Frame.__init__(self, parent)
self.pack(side=tk.TOP,
padx=5,
pady=5,
expand=True,
fill=tk.X,
anchor=tk.N)
self.filename = filename self.session = None # set when loading or clearing from parent
self.loaded_data = None
self.selected_id = selected_id self.selected_id = selected_id
self.popup_positions = list() self.popup_positions = list()
self.scaling_factor = scaling_factor
self.add_label() self.add_label()
self.tree = ttk.Treeview(self, height=1, selectmode=tk.BROWSE) self.tree = ttk.Treeview(self, height=1, selectmode=tk.BROWSE)
@ -171,14 +199,17 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
orient="vertical", orient="vertical",
command=self.tree.yview) command=self.tree.yview)
self.columns = self.tree_configure(helptext) self.columns = self.tree_configure(helptext)
logger.debug("Initialized: %s", self.__class__.__name__)
def add_label(self): def add_label(self):
""" Add Treeview Title """ """ Add Treeview Title """
logger.debug("Adding Treeview title")
lbl = ttk.Label(self, text="Session Stats", anchor=tk.CENTER) lbl = ttk.Label(self, text="Session Stats", anchor=tk.CENTER)
lbl.pack(side=tk.TOP, expand=True, fill=tk.X, padx=5, pady=5) lbl.pack(side=tk.TOP, expand=True, fill=tk.X, padx=5, pady=5)
def tree_configure(self, helptext): def tree_configure(self, helptext):
""" Build a treeview widget to hold the sessions stats """ """ Build a treeview widget to hold the sessions stats """
logger.debug("Configuring Treeview")
self.tree.configure(yscrollcommand=self.scrollbar.set) self.tree.configure(yscrollcommand=self.scrollbar.set)
self.tree.tag_configure("total", self.tree.tag_configure("total",
background="black", background="black",
@ -191,6 +222,7 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
def tree_columns(self): def tree_columns(self):
""" Add the columns to the totals treeview """ """ Add the columns to the totals treeview """
logger.debug("Adding Treeview columns")
columns = (("session", 40, "#"), columns = (("session", 40, "#"),
("start", 130, None), ("start", 130, None),
("end", 130, None), ("end", 130, None),
@ -202,6 +234,7 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
for column in columns: for column in columns:
text = column[2] if column[2] else column[0].title() text = column[2] if column[2] else column[0].title()
logger.debug("Adding heading: '%s'", text)
self.tree.heading(column[0], text=text) self.tree.heading(column[0], text=text)
self.tree.column(column[0], self.tree.column(column[0],
width=column[1], width=column[1],
@ -212,19 +245,21 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
return [column[0] for column in columns] return [column[0] for column in columns]
def tree_insert_data(self, sessions): def tree_insert_data(self, sessions_summary):
""" Insert the data into the totals treeview """ """ Insert the data into the totals treeview """
self.tree.configure(height=len(sessions)) logger.debug("Inserting treeview data")
self.tree.configure(height=len(sessions_summary))
for item in sessions: for item in sessions_summary:
values = [item[column] for column in self.columns] values = [item[column] for column in self.columns]
kwargs = {"values": values, "image": Images().icons["graph"]} kwargs = {"values": values, "image": get_images().icons["graph"]}
if values[0] == "Total": if values[0] == "Total":
kwargs["tags"] = "total" kwargs["tags"] = "total"
self.tree.insert("", "end", **kwargs) self.tree.insert("", "end", **kwargs)
def tree_clear(self): def tree_clear(self):
""" Clear the totals tree """ """ Clear the totals tree """
logger.debug("Clearing treeview data")
self.tree.delete(* self.tree.get_children()) self.tree.delete(* self.tree.get_children())
self.tree.configure(height=1) self.tree.configure(height=1)
@ -235,17 +270,22 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
selection = self.tree.focus() selection = self.tree.focus()
values = self.tree.item(selection, "values") values = self.tree.item(selection, "values")
if values: if values:
logger.debug("Selected values: %s", values)
self.selected_id.set(values[0]) self.selected_id.set(values[0])
if region == "tree": if region == "tree":
self.data_popup() self.data_popup()
def data_popup(self): def data_popup(self):
""" Pop up a window and control it's position """ """ Pop up a window and control it's position """
toplevel = SessionPopUp(self.loaded_data, self.selected_id.get()) logger.debug("Popping up data window")
scaling_factor = get_config().scaling_factor
toplevel = SessionPopUp(self.session.modeldir,
self.session.modelname,
self.selected_id.get())
toplevel.title(self.data_popup_title()) toplevel.title(self.data_popup_title())
position = self.data_popup_get_position() position = self.data_popup_get_position()
height = int(720 * self.scaling_factor) height = int(720 * scaling_factor)
width = int(400 * self.scaling_factor) width = int(400 * scaling_factor)
toplevel.geometry("{}x{}+{}+{}".format(str(height), toplevel.geometry("{}x{}+{}+{}".format(str(height),
str(width), str(width),
str(position[0]), str(position[0]),
@ -254,14 +294,17 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
def data_popup_title(self): def data_popup_title(self):
""" Set the data popup title """ """ Set the data popup title """
logger.debug("Setting poup title")
selected_id = self.selected_id.get() selected_id = self.selected_id.get()
title = "All Sessions" title = "All Sessions"
if selected_id != "Total": if selected_id != "Total":
title = "Session #{}".format(selected_id) title = "{} Model: Session #{}".format(self.session.modelname.title(), selected_id)
return "{} - {}".format(title, self.filename.get()) logger.debug("Title: '%s'", title)
return "{} - {}".format(title, self.session.modeldir)
def data_popup_get_position(self): def data_popup_get_position(self):
""" Get the position of the next window """ """ Get the position of the next window """
logger.debug("getting poup position")
init_pos = [120, 120] init_pos = [120, 120]
pos = init_pos pos = init_pos
while True: while True:
@ -270,25 +313,33 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
break break
pos = [item + 200 for item in pos] pos = [item + 200 for item in pos]
init_pos, pos = self.data_popup_check_boundaries(init_pos, pos) init_pos, pos = self.data_popup_check_boundaries(init_pos, pos)
logger.debug("Position: %s", pos)
return pos return pos
def data_popup_check_boundaries(self, initial_position, position): def data_popup_check_boundaries(self, initial_position, position):
""" Check that the popup remains within the screen boundaries """ """ Check that the popup remains within the screen boundaries """
logger.debug("Checking poup boundaries: (initial_position: %s, position: %s)",
initial_position, position)
boundary_x = self.winfo_screenwidth() - 120 boundary_x = self.winfo_screenwidth() - 120
boundary_y = self.winfo_screenheight() - 120 boundary_y = self.winfo_screenheight() - 120
if position[0] >= boundary_x or position[1] >= boundary_y: if position[0] >= boundary_x or position[1] >= boundary_y:
initial_position = [initial_position[0] + 50, initial_position[1]] initial_position = [initial_position[0] + 50, initial_position[1]]
position = initial_position position = initial_position
logger.debug("Returning poup boundaries: (initial_position: %s, position: %s)",
initial_position, position)
return initial_position, position return initial_position, position
class SessionPopUp(tk.Toplevel): class SessionPopUp(tk.Toplevel):
""" Pop up for detailed grap/stats for selected session """ """ Pop up for detailed graph/stats for selected session """
def __init__(self, data, session_id): def __init__(self, model_dir, model_name, session_id):
tk.Toplevel.__init__(self) logger.debug("Initializing: %s: (model_dir: %s, model_name: %s, session_id: %s)",
self.__class__.__name__, model_dir, model_name, session_id)
super().__init__()
self.is_totals = session_id == "Total" self.session_id = session_id
self.data = self.set_session_data(data, session_id) self.session = Session(model_dir=model_dir, model_name=model_name)
self.initialize_session()
self.graph = None self.graph = None
self.display_data = None self.display_data = None
@ -296,25 +347,35 @@ class SessionPopUp(tk.Toplevel):
self.vars = dict() self.vars = dict()
self.graph_initialised = False self.graph_initialised = False
self.build() self.build()
logger.debug("Initialized: %s", self.__class__.__name__)
def set_session_data(self, sessions, session_id): @property
""" Set the correct list index based on the passed in session is """ def is_totals(self):
if self.is_totals: """ Return True if these are totals else False """
data = SessionsTotals(sessions).stats return bool(self.session_id == "Total")
else:
data = sessions[int(session_id) - 1] def initialize_session(self):
return data """ Initialize the session """
logger.debug("Initializing session")
kwargs = dict(is_training=False)
if not self.is_totals:
kwargs["session_id"] = int(self.session_id)
logger.debug("Session kwargs: %s", kwargs)
self.session.initialize_session(**kwargs)
def build(self): def build(self):
""" Build the popup window """ """ Build the popup window """
logger.debug("Building popup")
optsframe, graphframe = self.layout_frames() optsframe, graphframe = self.layout_frames()
self.opts_build(optsframe) self.opts_build(optsframe)
self.compile_display_data() self.compile_display_data()
self.graph_build(graphframe) self.graph_build(graphframe)
logger.debug("Built popup")
def layout_frames(self): def layout_frames(self):
""" Top level container frames """ """ Top level container frames """
logger.debug("Layout frames")
leftframe = ttk.Frame(self) leftframe = ttk.Frame(self)
leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5) leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5)
@ -323,20 +384,25 @@ class SessionPopUp(tk.Toplevel):
rightframe = ttk.Frame(self) rightframe = ttk.Frame(self)
rightframe.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True) rightframe.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True)
logger.debug("Laid out frames")
return leftframe, rightframe return leftframe, rightframe
def opts_build(self, frame): def opts_build(self, frame):
""" Options in options to the optsframe """ """ Build Options into the options frame """
logger.debug("Building Options")
self.opts_combobox(frame) self.opts_combobox(frame)
self.opts_checkbuttons(frame) self.opts_checkbuttons(frame)
self.opts_loss_keys(frame)
self.opts_entry(frame) self.opts_entry(frame)
self.opts_buttons(frame) self.opts_buttons(frame)
sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) sep = ttk.Frame(frame, height=2, relief=tk.RIDGE)
sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
logger.debug("Built Options")
def opts_combobox(self, frame): def opts_combobox(self, frame):
""" Add the options combo boxes """ """ Add the options combo boxes """
logger.debug("Building Combo boxes")
choices = {"Display": ("Loss", "Rate"), choices = {"Display": ("Loss", "Rate"),
"Scale": ("Linear", "Log")} "Scale": ("Linear", "Log")}
@ -362,9 +428,11 @@ class SessionPopUp(tk.Toplevel):
hlp = self.set_help(item) hlp = self.set_help(item)
Tooltip(cmbframe, text=hlp, wraplength=200) Tooltip(cmbframe, text=hlp, wraplength=200)
logger.debug("Built Combo boxes")
def opts_checkbuttons(self, frame): def opts_checkbuttons(self, frame):
""" Add the options check buttons """ """ Add the options check buttons """
logger.debug("Building Check Buttons")
for item in ("raw", "trend", "avg", "outliers"): for item in ("raw", "trend", "avg", "outliers"):
if item == "avg": if item == "avg":
text = "Show Rolling Average" text = "Show Rolling Average"
@ -384,9 +452,35 @@ class SessionPopUp(tk.Toplevel):
hlp = self.set_help(item) hlp = self.set_help(item)
Tooltip(ctl, text=hlp, wraplength=200) Tooltip(ctl, text=hlp, wraplength=200)
logger.debug("Built Check Buttons")
def opts_loss_keys(self, frame):
""" Add loss key selections """
logger.debug("Building Loss Key Check Buttons")
loss_keys = self.session.loss_keys
lk_vars = dict()
for loss_key in sorted(loss_keys):
text = loss_key.replace("_", " ").title()
helptext = "Display {}".format(text)
var = tk.BooleanVar()
var.set(True)
var.trace("w", self.optbtn_reset)
lk_vars[loss_key] = var
if len(loss_keys) == 1:
# Don't display if there's only one item
break
ctl = ttk.Checkbutton(frame, variable=var, text=text)
ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W)
Tooltip(ctl, text=helptext, wraplength=200)
self.vars["loss_keys"] = lk_vars
logger.debug("Built Loss Key Check Buttons")
def opts_entry(self, frame): def opts_entry(self, frame):
""" Add the options entry boxes """ """ Add the options entry boxes """
logger.debug("Building Entry Boxes")
for item in ("avgiterations", ): for item in ("avgiterations", ):
if item == "avgiterations": if item == "avgiterations":
text = "Iterations to Average:" text = "Iterations to Average:"
@ -405,27 +499,32 @@ class SessionPopUp(tk.Toplevel):
Tooltip(entframe, text=hlp, wraplength=200) Tooltip(entframe, text=hlp, wraplength=200)
self.vars[item] = ctl self.vars[item] = ctl
logger.debug("Built Entry Boxes")
def opts_buttons(self, frame): def opts_buttons(self, frame):
""" Add the option buttons """ """ Add the option buttons """
logger.debug("Building Buttons")
btnframe = ttk.Frame(frame) btnframe = ttk.Frame(frame)
btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM) btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM)
for btntype in ("reset", "save"): for btntype in ("reset", "save"):
cmd = getattr(self, "optbtn_{}".format(btntype)) cmd = getattr(self, "optbtn_{}".format(btntype))
btn = ttk.Button(btnframe, btn = ttk.Button(btnframe,
image=Images().icons[btntype], image=get_images().icons[btntype],
command=cmd) command=cmd)
btn.pack(padx=2, side=tk.RIGHT) btn.pack(padx=2, side=tk.RIGHT)
hlp = self.set_help(btntype) hlp = self.set_help(btntype)
Tooltip(btn, text=hlp, wraplength=200) Tooltip(btn, text=hlp, wraplength=200)
logger.debug("Built Buttons")
def optbtn_save(self): def optbtn_save(self):
""" Action for save button press """ """ Action for save button press """
logger.debug("Saving File")
savefile = FileHandler("save", "csv").retfile savefile = FileHandler("save", "csv").retfile
if not savefile: if not savefile:
logger.debug("Save Cancelled")
return return
logger.debug("Saving to: %s", savefile)
save_data = self.display_data.stats save_data = self.display_data.stats
fieldnames = sorted(key for key in save_data.keys()) fieldnames = sorted(key for key in save_data.keys())
@ -434,16 +533,21 @@ class SessionPopUp(tk.Toplevel):
csvout.writerow(fieldnames) csvout.writerow(fieldnames)
csvout.writerows(zip(*[save_data[key] for key in fieldnames])) csvout.writerows(zip(*[save_data[key] for key in fieldnames]))
def optbtn_reset(self, *args): def optbtn_reset(self, *args): # pylint: disable=unused-argument
""" Action for reset button press and checkbox changes""" """ Action for reset button press and checkbox changes"""
logger.debug("Refreshing Graph")
if not self.graph_initialised: if not self.graph_initialised:
return return
self.compile_display_data() valid = self.compile_display_data()
if not valid:
logger.debug("Invalid data")
return
self.graph.refresh(self.display_data, self.graph.refresh(self.display_data,
self.vars["display"].get(), self.vars["display"].get(),
self.vars["scale"].get()) self.vars["scale"].get())
logger.debug("Refreshed Graph")
def graph_scale(self, *args): def graph_scale(self, *args): # pylint: disable=unused-argument
""" Action for changing graph scale """ """ Action for changing graph scale """
if not self.graph_initialised: if not self.graph_initialised:
return return
@ -477,25 +581,53 @@ class SessionPopUp(tk.Toplevel):
def compile_display_data(self): def compile_display_data(self):
""" Compile the data to be displayed """ """ Compile the data to be displayed """
self.display_data = Calculations(self.data, logger.debug("Compiling Display Data")
self.vars["display"].get(),
self.selections_to_list(), loss_keys = [key for key, val in self.vars["loss_keys"].items()
self.vars["avgiterations"].get(), if val.get()]
self.vars["outliers"].get(), logger.debug("Selected loss_keys: %s", loss_keys)
self.is_totals)
selections = self.selections_to_list()
if not self.check_valid_selection(loss_keys, selections):
return False
self.display_data = Calculations(session=self.session,
display=self.vars["display"].get(),
loss_keys=loss_keys,
selections=selections,
avg_samples=self.vars["avgiterations"].get(),
flatten_outliers=self.vars["outliers"].get(),
is_totals=self.is_totals)
logger.debug("Compiled Display Data")
return True
def check_valid_selection(self, loss_keys, selections):
""" Check that there will be data to display """
display = self.vars["display"].get().lower()
logger.debug("Validating selection. (loss_keys: %s, selections: %s, display: %s)",
loss_keys, selections, display)
if not selections or (display == "loss" and not loss_keys):
msg = "No data to display. Not refreshing"
logger.debug(msg)
print(msg)
return False
return True
def selections_to_list(self): def selections_to_list(self):
""" Compile checkbox selections to list """ """ Compile checkbox selections to list """
logger.debug("Compiling selections to list")
selections = list() selections = list()
for key, val in self.vars.items(): for key, val in self.vars.items():
if (isinstance(val, tk.BooleanVar) if (isinstance(val, tk.BooleanVar)
and key != "outliers" and key != "outliers"
and val.get()): and val.get()):
selections.append(key) selections.append(key)
logger.debug("Compiling selections to list: %s", selections)
return selections return selections
def graph_build(self, frame): def graph_build(self, frame):
""" Build the graph in the top right paned window """ """ Build the graph in the top right paned window """
logger.debug("Building Graph")
self.graph = SessionGraph(frame, self.graph = SessionGraph(frame,
self.display_data, self.display_data,
self.vars["display"].get(), self.vars["display"].get(),
@ -503,3 +635,4 @@ class SessionPopUp(tk.Toplevel):
self.graph.pack(expand=True, fill=tk.BOTH) self.graph.pack(expand=True, fill=tk.BOTH)
self.graph.build() self.graph.build()
self.graph_initialised = True self.graph_initialised = True
logger.debug("Built Graph")


@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Command specific tabs of Display Frame of the Faceswap GUI """ """ Command specific tabs of Display Frame of the Faceswap GUI """
import datetime import datetime
import logging
import os import os
import tkinter as tk import tkinter as tk
@ -11,19 +12,23 @@ from .display_graph import TrainingGraph
from .display_page import DisplayOptionalPage from .display_page import DisplayOptionalPage
from .tooltip import Tooltip from .tooltip import Tooltip
from .stats import Calculations from .stats import Calculations
from .utils import Images, FileHandler from .utils import FileHandler, get_config, get_images
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class PreviewExtract(DisplayOptionalPage): class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Tab to display output preview images for extract and convert """ """ Tab to display output preview images for extract and convert """
def display_item_set(self): def display_item_set(self):
""" Load the latest preview if available """ """ Load the latest preview if available """
Images().load_latest_preview() logger.trace("Loading latest preview")
self.display_item = Images().previewoutput get_images().load_latest_preview()
self.display_item = get_images().previewoutput
def display_item_process(self): def display_item_process(self):
""" Display the preview """ """ Display the preview """
logger.trace("Displaying preview")
if not self.subnotebook.children: if not self.subnotebook.children:
self.add_child() self.add_child()
else: else:
@ -31,15 +36,17 @@ class PreviewExtract(DisplayOptionalPage):
def add_child(self): def add_child(self):
""" Add the preview label child """ """ Add the preview label child """
logger.debug("Adding child")
preview = self.subnotebook_add_page(self.tabname, widget=None) preview = self.subnotebook_add_page(self.tabname, widget=None)
lblpreview = ttk.Label(preview, image=Images().previewoutput[1]) lblpreview = ttk.Label(preview, image=get_images().previewoutput[1])
lblpreview.pack(side=tk.TOP, anchor=tk.NW) lblpreview.pack(side=tk.TOP, anchor=tk.NW)
Tooltip(lblpreview, text=self.helptext, wraplength=200) Tooltip(lblpreview, text=self.helptext, wraplength=200)
def update_child(self): def update_child(self):
""" Update the preview image on the label """ """ Update the preview image on the label """
logger.trace("Updating preview")
for widget in self.subnotebook_get_widgets(): for widget in self.subnotebook_get_widgets():
widget.configure(image=Images().previewoutput[1]) widget.configure(image=get_images().previewoutput[1])
def save_items(self): def save_items(self):
""" Open save dialogue and save preview """ """ Open save dialogue and save preview """
@ -52,41 +59,56 @@ class PreviewExtract(DisplayOptionalPage):
"{}_{}.{}".format(filename, "{}_{}.{}".format(filename,
now, now,
"png")) "png"))
Images().previewoutput[0].save(filename) get_images().previewoutput[0].save(filename)
logger.debug("Saved preview to %s", filename)
print("Saved preview to {}".format(filename)) print("Saved preview to {}".format(filename))
class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Training preview image(s) """ """ Training preview image(s) """
def __init__(self, *args, **kwargs):
self.update_preview = get_config().tk_vars["updatepreview"]
super().__init__(*args, **kwargs)
def display_item_set(self): def display_item_set(self):
""" Load the latest preview if available """ """ Load the latest preview if available """
Images().load_training_preview() logger.trace("Loading latest preview")
self.display_item = Images().previewtrain if not self.update_preview.get():
logger.trace("Preview not updated")
return
get_images().load_training_preview()
self.display_item = get_images().previewtrain
def display_item_process(self): def display_item_process(self):
""" Display the preview(s) resized as appropriate """ """ Display the preview(s) resized as appropriate """
sortednames = sorted([name for name in Images().previewtrain.keys()]) logger.trace("Displaying preview")
sortednames = sorted(list(get_images().previewtrain.keys()))
existing = self.subnotebook_get_titles_ids() existing = self.subnotebook_get_titles_ids()
should_update = self.update_preview.get()
for name in sortednames: for name in sortednames:
if name not in existing.keys(): if name not in existing.keys():
self.add_child(name) self.add_child(name)
else: elif should_update:
tab_id = existing[name] tab_id = existing[name]
self.update_child(tab_id, name) self.update_child(tab_id, name)
if should_update:
self.update_preview.set(False)
def add_child(self, name): def add_child(self, name):
""" Add the preview canvas child """ """ Add the preview canvas child """
logger.debug("Adding child")
preview = PreviewTrainCanvas(self.subnotebook, name) preview = PreviewTrainCanvas(self.subnotebook, name)
preview = self.subnotebook_add_page(name, widget=preview) preview = self.subnotebook_add_page(name, widget=preview)
Tooltip(preview, text=self.helptext, wraplength=200) Tooltip(preview, text=self.helptext, wraplength=200)
self.vars["modified"].set(Images().previewtrain[name][2]) self.vars["modified"].set(get_images().previewtrain[name][2])
def update_child(self, tab_id, name): def update_child(self, tab_id, name):
""" Update the preview canvas """ """ Update the preview canvas """
if self.vars["modified"].get() != Images().previewtrain[name][2]: logger.debug("Updating preview")
self.vars["modified"].set(Images().previewtrain[name][2]) if self.vars["modified"].get() != get_images().previewtrain[name][2]:
self.vars["modified"].set(get_images().previewtrain[name][2])
widget = self.subnotebook_page_from_id(tab_id) widget = self.subnotebook_page_from_id(tab_id)
widget.reload() widget.reload()
@ -102,11 +124,12 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors
""" Canvas to hold a training preview image """ """ Canvas to hold a training preview image """
def __init__(self, parent, previewname): def __init__(self, parent, previewname):
logger.debug("Initializing %s: (previewname: '%s')", self.__class__.__name__, previewname)
ttk.Frame.__init__(self, parent) ttk.Frame.__init__(self, parent)
self.name = previewname self.name = previewname
Images().resize_image(self.name, None) get_images().resize_image(self.name, None)
self.previewimage = Images().previewtrain[self.name][1] self.previewimage = get_images().previewtrain[self.name][1]
self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True) self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
@ -115,18 +138,21 @@ class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors
image=self.previewimage, image=self.previewimage,
anchor=tk.NW) anchor=tk.NW)
self.bind("<Configure>", self.resize) self.bind("<Configure>", self.resize)
logger.debug("Initialized %s:", self.__class__.__name__)
def resize(self, event): def resize(self, event):
""" Resize the image to fit the frame, maintaining aspect ratio """ """ Resize the image to fit the frame, maintaining aspect ratio """
logger.trace("Resizing preview image")
framesize = (event.width, event.height) framesize = (event.width, event.height)
# Sometimes image is resized before frame is drawn # Sometimes image is resized before frame is drawn
framesize = None if framesize == (1, 1) else framesize framesize = None if framesize == (1, 1) else framesize
Images().resize_image(self.name, framesize) get_images().resize_image(self.name, framesize)
self.reload() self.reload()
def reload(self): def reload(self):
""" Reload the preview image """ """ Reload the preview image """
self.previewimage = Images().previewtrain[self.name][1] logger.trace("Reloading preview image")
self.previewimage = get_images().previewtrain[self.name][1]
self.canvas.itemconfig(self.imgcanvas, image=self.previewimage) self.canvas.itemconfig(self.imgcanvas, image=self.previewimage)
def save_preview(self, location): def save_preview(self, location):
@ -137,40 +163,63 @@ class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors
"{}_{}.{}".format(filename, "{}_{}.{}".format(filename,
now, now,
"png")) "png"))
Images().previewtrain[self.name][0].save(filename) get_images().previewtrain[self.name][0].save(filename)
logger.debug("Saved preview to %s", filename)
print("Saved preview to {}".format(filename)) print("Saved preview to {}".format(filename))
class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" The Graph Tab of the Display section """ """ The Graph Tab of the Display section """
def add_options(self):
""" Add the additional options """
self.add_option_refresh()
super().add_options()
def add_option_refresh(self):
""" Add refresh button to refresh graph immediately """
logger.debug("Adding refresh option")
tk_var = get_config().tk_vars["refreshgraph"]
btnrefresh = ttk.Button(self.optsframe,
image=get_images().icons["reset"],
command=lambda: tk_var.set(True))
btnrefresh.pack(padx=2, side=tk.RIGHT)
Tooltip(btnrefresh,
text="Graph updates every 100 iterations. Click to refresh now.",
wraplength=200)
def display_item_set(self): def display_item_set(self):
""" Load the graph(s) if available """ """ Load the graph(s) if available """
if self.session.stats["iterations"] == 0: session = get_config().session
if session.initialized and session.logging_disabled:
logger.trace("Logs disabled. Hiding graph")
self.set_info("Graph is disabled as 'no-logs' has been selected")
self.display_item = None self.display_item = None
elif session.initialized:
logger.trace("Loading graph")
self.display_item = session
else: else:
self.display_item = self.session.stats self.display_item = None
def display_item_process(self): def display_item_process(self):
""" Add a single graph to the graph window """ """ Add a single graph to the graph window """
losskeys = self.display_item["losskeys"] logger.trace("Adding graph")
loss = self.display_item["loss"] existing = list(self.subnotebook_get_titles_ids().keys())
tabcount = int(len(losskeys) / 2)
existing = self.subnotebook_get_titles_ids() for loss_key in self.display_item.loss_keys:
for i in range(tabcount): tabname = loss_key.replace("_", " ").title()
selectedkeys = losskeys[i * 2:(i + 1) * 2] if tabname in existing:
name = " - ".join(selectedkeys).title().replace("_", " ") continue
if name not in existing.keys():
selectedloss = loss[i * 2:(i + 1) * 2] data = Calculations(session=get_config().session,
selection = {"loss": selectedloss, display="loss",
"losskeys": selectedkeys} loss_keys=[loss_key],
data = Calculations(session=selection, selections=["raw", "trend"])
display="loss", self.add_child(tabname, data)
selections=["raw", "trend"])
self.add_child(name, data)
def add_child(self, name, data): def add_child(self, name, data):
""" Add the graph for the selected keys """ """ Add the graph for the selected keys """
logger.debug("Adding child: %s", name)
graph = TrainingGraph(self.subnotebook, data, "Loss") graph = TrainingGraph(self.subnotebook, data, "Loss")
graph.build() graph.build()
graph = self.subnotebook_add_page(name, widget=graph) graph = self.subnotebook_add_page(name, widget=graph)
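
The new display_item_process derives one tab title per loss key and skips any title that already exists. A minimal sketch of that naming and filtering step, assuming loss keys are plain strings such as "loss_a". The helper names `tab_name` and `tabs_to_add` are hypothetical and are not part of this commit.

```python
def tab_name(loss_key):
    """Turn a loss key into a tab title, as display_item_process does."""
    return loss_key.replace("_", " ").title()


def tabs_to_add(loss_keys, existing_titles):
    """Return the titles for loss keys that do not yet have a tab."""
    existing = set(existing_titles)
    # Keep the order of loss_keys; skip titles that are already present
    return [tab_name(key) for key in loss_keys if tab_name(key) not in existing]
```

Unlike the removed code, which grouped loss keys in pairs and joined them with " - ", this gives every loss key its own tab.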


@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Graph functions for Display Frame of the Faceswap GUI """ """ Graph functions for Display Frame of the Faceswap GUI """
import datetime import datetime
import logging
import os import os
import tkinter as tk import tkinter as tk
@ -8,14 +9,17 @@ from tkinter import ttk
from math import ceil, floor from math import ceil, floor
import matplotlib import matplotlib
# pylint: disable=wrong-import-position
matplotlib.use("TkAgg") matplotlib.use("TkAgg")
import matplotlib.animation as animation
from matplotlib import pyplot as plt, style
from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
NavigationToolbar2Tk)
from .tooltip import Tooltip from matplotlib import pyplot as plt, style # noqa
from .utils import Images from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
NavigationToolbar2Tk) # noqa
from .tooltip import Tooltip # noqa
from .utils import get_config, get_images # noqa
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ancestors class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ancestors
@ -26,28 +30,24 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
t[0] in ("Home", "Pan", "Zoom", "Save")] t[0] in ("Home", "Pan", "Zoom", "Save")]
@staticmethod @staticmethod
def _Button(frame, text, file, command, extension=".gif"): def _Button(frame, text, file, command, extension=".gif"): # pylint: disable=arguments-differ
""" Map Buttons to their own frame. """ Map Buttons to their own frame.
Use custom button icons, Use custom button icons, Use ttk buttons pack to the right """
Use ttk buttons
pack to the right """
iconmapping = {"home": "reset", iconmapping = {"home": "reset",
"filesave": "save", "filesave": "save",
"zoom_to_rect": "zoom"} "zoom_to_rect": "zoom"}
icon = iconmapping[file] if iconmapping.get(file, None) else file icon = iconmapping[file] if iconmapping.get(file, None) else file
img = Images().icons[icon] img = get_images().icons[icon]
btn = ttk.Button(frame, text=text, image=img, command=command) btn = ttk.Button(frame, text=text, image=img, command=command)
btn.pack(side=tk.RIGHT, padx=2) btn.pack(side=tk.RIGHT, padx=2)
return btn return btn
def _init_toolbar(self): def _init_toolbar(self):
""" Same as original but ttk widgets and standard """ Same as original but ttk widgets and standard tooltips used. Separator added and
tooltips used. Separator added and message label message label packed to the left """
packed to the left """
xmin, xmax = self.canvas.figure.bbox.intervalx xmin, xmax = self.canvas.figure.bbox.intervalx
height, width = 50, xmax-xmin height, width = 50, xmax-xmin
ttk.Frame.__init__(self, master=self.window, ttk.Frame.__init__(self, master=self.window, width=int(width), height=int(height))
width=int(width), height=int(height))
sep = ttk.Frame(self, height=2, relief=tk.RIDGE) sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP) sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP)
@ -76,14 +76,14 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
""" Base class for matplotlib line graphs """ """ Base class for matplotlib line graphs """
def __init__(self, parent, data, ylabel): def __init__(self, parent, data, ylabel):
ttk.Frame.__init__(self, parent) logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(parent)
style.use("ggplot") style.use("ggplot")
self.calcs = data self.calcs = data
self.ylabel = ylabel self.ylabel = ylabel
self.colourmaps = ["Reds", "Blues", "Greens", self.colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges",
"Purples", "Oranges", "Greys", "Greys", "copper", "summer", "bone"]
"copper", "summer", "bone"]
self.lines = list() self.lines = list()
self.toolbar = None self.toolbar = None
self.fig = plt.figure(figsize=(4, 4), dpi=75) self.fig = plt.figure(figsize=(4, 4), dpi=75)
@ -92,26 +92,29 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self.initiate_graph() self.initiate_graph()
self.update_plot(initiate=True) self.update_plot(initiate=True)
logger.debug("Initialized %s", self.__class__.__name__)
def initiate_graph(self): def initiate_graph(self):
""" Place the graph canvas """ """ Place the graph canvas """
self.plotcanvas.get_tk_widget().pack(side=tk.TOP, logger.debug("Setting plotcanvas")
padx=5, self.plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True)
fill=tk.BOTH,
expand=True)
plt.subplots_adjust(left=0.100, plt.subplots_adjust(left=0.100,
bottom=0.100, bottom=0.100,
right=0.95, right=0.95,
top=0.95, top=0.95,
wspace=0.2, wspace=0.2,
hspace=0.2) hspace=0.2)
logger.debug("Set plotcanvas")
def update_plot(self, initiate=True): def update_plot(self, initiate=True):
""" Update the plot with incoming data """ """ Update the plot with incoming data """
logger.trace("Updating plot")
if initiate: if initiate:
logger.debug("Initializing plot")
self.lines = list() self.lines = list()
self.ax1.clear() self.ax1.clear()
self.axes_labels_set() self.axes_labels_set()
logger.debug("Initialized plot")
fulldata = [item for item in self.calcs.stats.values()] fulldata = [item for item in self.calcs.stats.values()]
self.axes_limits_set(fulldata) self.axes_limits_set(fulldata)
@ -120,37 +123,37 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
keys = list(self.calcs.stats.keys()) keys = list(self.calcs.stats.keys())
for idx, item in enumerate(self.lines_sort(keys)): for idx, item in enumerate(self.lines_sort(keys)):
if initiate: if initiate:
self.lines.extend(self.ax1.plot(xrng, self.lines.extend(self.ax1.plot(xrng, self.calcs.stats[item[0]],
self.calcs.stats[item[0]], label=item[1], linewidth=item[2], color=item[3]))
label=item[1],
linewidth=item[2],
color=item[3]))
else: else:
self.lines[idx].set_data(xrng, self.calcs.stats[item[0]]) self.lines[idx].set_data(xrng, self.calcs.stats[item[0]])
if initiate: if initiate:
self.legend_place() self.legend_place()
logger.trace("Updated plot")
def axes_labels_set(self): def axes_labels_set(self):
""" Set the axes label and range """ """ Set the axes label and range """
logger.debug("Setting axes labels. y-label: '%s'", self.ylabel)
self.ax1.set_xlabel("Iterations") self.ax1.set_xlabel("Iterations")
self.ax1.set_ylabel(self.ylabel) self.ax1.set_ylabel(self.ylabel)
def axes_limits_set_default(self): def axes_limits_set_default(self):
""" Set default axes limits """ """ Set default axes limits """
logger.debug("Setting default axes ranges")
self.ax1.set_ylim(0.00, 100.0) self.ax1.set_ylim(0.00, 100.0)
self.ax1.set_xlim(0, 1) self.ax1.set_xlim(0, 1)
def axes_limits_set(self, data): def axes_limits_set(self, data):
""" Set the axes limits """ """ Set the axes limits """
xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1 xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1
if data: if data:
ymin, ymax = self.axes_data_get_min_max(data) ymin, ymax = self.axes_data_get_min_max(data)
self.ax1.set_ylim(ymin, ymax) self.ax1.set_ylim(ymin, ymax)
self.ax1.set_xlim(0, xmax) self.ax1.set_xlim(0, xmax)
else: else:
self.axes_limits_set_default() self.axes_limits_set_default()
logger.trace("axes ranges: (y: (%s, %s), x: (0, %s))", ymin, ymax, xmax)
@staticmethod @staticmethod
def axes_data_get_min_max(data): def axes_data_get_min_max(data):
@ -164,15 +167,18 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
ymax.append(max(dataset) * 1000) ymax.append(max(dataset) * 1000)
ymin = floor(min(ymin)) / 1000 ymin = floor(min(ymin)) / 1000
ymax = ceil(max(ymax)) / 1000 ymax = ceil(max(ymax)) / 1000
logger.trace("ymin: %s, ymax: %s", ymin, ymax)
return ymin, ymax return ymin, ymax
def axes_set_yscale(self, scale): def axes_set_yscale(self, scale):
""" Set the Y-Scale to log or linear """ """ Set the Y-Scale to log or linear """
logger.debug("yscale: '%s'", scale)
self.ax1.set_yscale(scale) self.ax1.set_yscale(scale)
def lines_sort(self, keys): def lines_sort(self, keys):
""" Sort the data keys into consistent order """ Sort the data keys into consistent order
and set line colourmap and line width """ and set line color map and line width """
logger.trace("Sorting lines")
raw_lines = list() raw_lines = list()
sorted_lines = list() sorted_lines = list()
for key in sorted(keys): for key in sorted(keys):
@ -184,29 +190,28 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
groupsize = self.lines_groupsize(raw_lines, sorted_lines) groupsize = self.lines_groupsize(raw_lines, sorted_lines)
sorted_lines = raw_lines + sorted_lines sorted_lines = raw_lines + sorted_lines
lines = self.lines_style(sorted_lines, groupsize) lines = self.lines_style(sorted_lines, groupsize)
return lines return lines
@staticmethod @staticmethod
def lines_groupsize(raw_lines, sorted_lines): def lines_groupsize(raw_lines, sorted_lines):
""" Get the number of items in each group. """ Get the number of items in each group.
If raw data isn't selected, then check If raw data isn't selected, then check the length of
the length of remaining groups until remaining groups until something is found """
something is found """
groupsize = 1 groupsize = 1
if raw_lines: if raw_lines:
groupsize = len(raw_lines) groupsize = len(raw_lines)
else: else:
for check in ("avg", "trend"): for check in ("avg", "trend"):
if any(item[0].startswith(check) for item in sorted_lines): if any(item[0].startswith(check) for item in sorted_lines):
groupsize = len([item for item in sorted_lines groupsize = len([item for item in sorted_lines if item[0].startswith(check)])
if item[0].startswith(check)])
break break
logger.trace(groupsize)
return groupsize return groupsize
def lines_style(self, lines, groupsize): def lines_style(self, lines, groupsize):
""" Set the colourmap and linewidth for each group """ """ Set the color map and line width for each group """
logger.trace("Setting lines style")
groups = int(len(lines) / groupsize) groups = int(len(lines) / groupsize)
colours = self.lines_create_colors(groupsize, groups) colours = self.lines_create_colors(groupsize, groups)
for idx, item in enumerate(lines): for idx, item in enumerate(lines):
@ -215,21 +220,24 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return lines return lines
def lines_create_colors(self, groupsize, groups): def lines_create_colors(self, groupsize, groups):
""" Create the colours """ """ Create the colors """
colours = list() colours = list()
for i in range(1, groups + 1): for i in range(1, groups + 1):
for colour in self.colourmaps[0:groupsize]: for colour in self.colourmaps[0:groupsize]:
cmap = matplotlib.cm.get_cmap(colour) cmap = matplotlib.cm.get_cmap(colour)
cpoint = 1 - (i / 5) cpoint = 1 - (i / 5)
colours.append(cmap(cpoint)) colours.append(cmap(cpoint))
logger.trace(colours)
return colours return colours
def legend_place(self): def legend_place(self):
""" Place and format legend """ """ Place and format legend """
logger.debug("Placing legend")
self.ax1.legend(loc="upper right", ncol=2) self.ax1.legend(loc="upper right", ncol=2)
def toolbar_place(self, parent): def toolbar_place(self, parent):
""" Add Graph Navigation toolbar """ """ Add Graph Navigation toolbar """
logger.debug("Placing toolbar")
self.toolbar = NavigationToolbar(self.plotcanvas, parent) self.toolbar = NavigationToolbar(self.plotcanvas, parent)
self.toolbar.pack(side=tk.BOTTOM) self.toolbar.pack(side=tk.BOTTOM)
self.toolbar.update() self.toolbar.update()
@ -240,72 +248,48 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def __init__(self, parent, data, ylabel): def __init__(self, parent, data, ylabel):
GraphBase.__init__(self, parent, data, ylabel) GraphBase.__init__(self, parent, data, ylabel)
self.add_callback()
self.anim = None def add_callback(self):
""" Add the variable trace to update graph on recent button or save iteration """
get_config().tk_vars["refreshgraph"].trace("w", self.refresh)
def build(self): def build(self):
""" Update the plot area with loss values and cycle through to """ Update the plot area with loss values """
animate """ logger.debug("Building training graph")
self.anim = animation.FuncAnimation(self.fig,
self.animate,
interval=200,
blit=False)
self.plotcanvas.draw() self.plotcanvas.draw()
logger.debug("Built training graph")
def animate(self, i): def refresh(self, *args): # pylint: disable=unused-argument
""" Read loss data and apply to graph """ """ Read loss data and apply to graph """
logger.debug("Updating plot")
self.calcs.refresh() self.calcs.refresh()
self.update_plot(initiate=False) self.update_plot(initiate=False)
self.plotcanvas.draw()
def set_animation_rate(self, iterations): get_config().tk_vars["refreshgraph"].set(False)
""" Change the animation update interval based on how
many iterations have been
There's no point calculating a graph over thousands of
points of data when the change will be miniscule """
if iterations > 30000:
speed = 60000 # 1 min updates
elif iterations > 20000:
speed = 30000 # 30 sec updates
elif iterations > 10000:
speed = 10000 # 10 sec updates
elif iterations > 5000:
speed = 5000 # 5 sec updates
elif iterations > 1000:
speed = 2000 # 2 sec updates
elif iterations > 500:
speed = 1000 # 1 sec updates
elif iterations > 100:
speed = 500 # 0.5 sec updates
else:
speed = 200 # 200ms updates
if not self.anim.event_source.interval == speed:
self.anim.event_source.interval = speed
def save_fig(self, location): def save_fig(self, location):
""" Save the figure to file """ """ Save the figure to file """
keys = sorted([key.replace("raw_", "") logger.debug("Saving graph: '%s'", location)
for key in self.calcs.stats.keys() keys = sorted([key.replace("raw_", "") for key in self.calcs.stats.keys()
if key.startswith("raw_")]) if key.startswith("raw_")])
filename = " - ".join(keys) filename = " - ".join(keys)
now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = os.path.join(location, filename = os.path.join(location, "{}_{}.{}".format(filename, now, "png"))
"{}_{}.{}".format(filename,
now,
"png"))
self.fig.set_size_inches(16, 9) self.fig.set_size_inches(16, 9)
self.fig.savefig(filename, bbox_inches="tight", dpi=120) self.fig.savefig(filename, bbox_inches="tight", dpi=120)
print("Saved graph to {}".format(filename)) print("Saved graph to {}".format(filename))
logger.debug("Saved graph: '%s'", filename)
self.resize_fig() self.resize_fig()
def resize_fig(self): def resize_fig(self):
""" Resize the figure back to the canvas """ """ Resize the figure back to the canvas """
class Event(): class Event(): # pylint: disable=too-few-public-methods
""" Event class that needs to be passed to """ Event class that needs to be passed to plotcanvas.resize """
plotcanvas.resize """
pass pass
Event.width = self.winfo_width() Event.width = self.winfo_width()
Event.height = self.winfo_height() Event.height = self.winfo_height()
self.plotcanvas.resize(Event) self.plotcanvas.resize(Event) # pylint: disable=no-value-for-parameter
class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors
@ -316,18 +300,24 @@ class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors
def build(self): def build(self):
""" Build the session graph """ """ Build the session graph """
logger.debug("Building session graph")
self.toolbar_place(self) self.toolbar_place(self)
self.plotcanvas.draw() self.plotcanvas.draw()
logger.debug("Built session graph")
def refresh(self, data, ylabel, scale): def refresh(self, data, ylabel, scale):
""" Refresh graph data """ """ Refresh graph data """
logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale)
self.calcs = data self.calcs = data
self.ylabel = ylabel self.ylabel = ylabel
self.set_yscale_type(scale) self.set_yscale_type(scale)
logger.debug("Refreshed session graph")
def set_yscale_type(self, scale): def set_yscale_type(self, scale):
""" switch the y-scale and redraw """ """ switch the y-scale and redraw """
logger.debug("Updating scale type: '%s'", scale)
self.scale = scale self.scale = scale
self.update_plot(initiate=True) self.update_plot(initiate=True)
self.axes_set_yscale(self.scale) self.axes_set_yscale(self.scale)
self.plotcanvas.draw() self.plotcanvas.draw()
logger.debug("Updated scale type")
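
The `axes_data_get_min_max` hunk above shows only the tail of the method. A rough sketch of the rounding it appears to perform follows, assuming each dataset is a non-empty sequence of floats and that the lower bound is collected the same way as the visible upper bound. The function name `padded_y_limits` is hypothetical and is not part of the commit.

```python
from math import ceil, floor


def padded_y_limits(datasets):
    """Return (ymin, ymax) rounded outward to three decimal places.

    Hypothetical stand-alone version of the rounding seen in the diff:
    values are scaled by 1000, floored or ceiled, then scaled back.
    """
    # Assumption: the lower bound is gathered like the visible upper bound.
    lows = [min(dataset) * 1000 for dataset in datasets]
    highs = [max(dataset) * 1000 for dataset in datasets]
    ymin = floor(min(lows)) / 1000
    ymax = ceil(max(highs)) / 1000
    return ymin, ymax


limits = padded_y_limits([[0.1234, 0.3], [0.25, 0.5678]])
print(limits)
```

Rounding outward rather than to the nearest value keeps every plotted point inside the y-axis range.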


@ -1,21 +1,25 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Display Page parent classes for display section of the Faceswap GUI """ """ Display Page parent classes for display section of the Faceswap GUI """
import logging
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from .tooltip import Tooltip from .tooltip import Tooltip
from .utils import Images from .utils import get_images
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class DisplayPage(ttk.Frame): class DisplayPage(ttk.Frame):
""" Parent frame holder for each tab. """ Parent frame holder for each tab.
Defines uniform structure for each tab to inherit from """ Defines uniform structure for each tab to inherit from """
def __init__(self, parent, tabname, helptext): def __init__(self, parent, tabname, helptext):
logger.debug("Initializing %s: (tabname: '%s', helptext: %s)",
self.__class__.__name__, tabname, helptext)
ttk.Frame.__init__(self, parent) ttk.Frame.__init__(self, parent)
self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW) self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW)
self.session = parent.session
self.runningtask = parent.runningtask self.runningtask = parent.runningtask
self.helptext = helptext self.helptext = helptext
self.tabname = tabname self.tabname = tabname
@ -30,32 +34,37 @@ class DisplayPage(ttk.Frame):
self.add_frame_separator() self.add_frame_separator()
self.set_mainframe_single_tab_style() self.set_mainframe_single_tab_style()
parent.add(self, text=self.tabname.title()) parent.add(self, text=self.tabname.title())
logger.debug("Initialized %s", self.__class__.__name__,)
def add_optional_vars(self, varsdict): def add_optional_vars(self, varsdict):
""" Add page specific variables """ """ Add page specific variables """
if isinstance(varsdict, dict): if isinstance(varsdict, dict):
for key, val in varsdict.items(): for key, val in varsdict.items():
logger.debug("Adding: (%s: %s)", key, val)
self.vars[key] = val self.vars[key] = val
@staticmethod @staticmethod
def set_vars(): def set_vars():
""" Overide to return a dict of page specific variables """ """ Override to return a dict of page specific variables """
return dict() return dict()
def add_subnotebook(self): def add_subnotebook(self):
""" Add the main frame notebook """ """ Add the main frame notebook """
logger.debug("Adding subnotebook")
notebook = ttk.Notebook(self) notebook = ttk.Notebook(self)
notebook.pack(side=tk.TOP, anchor=tk.NW, fill=tk.BOTH, expand=True) notebook.pack(side=tk.TOP, anchor=tk.NW, fill=tk.BOTH, expand=True)
return notebook return notebook
def add_options_frame(self): def add_options_frame(self):
""" Add the display tab options """ """ Add the display tab options """
logger.debug("Adding options frame")
optsframe = ttk.Frame(self) optsframe = ttk.Frame(self)
optsframe.pack(side=tk.BOTTOM, padx=5, pady=5, fill=tk.X) optsframe.pack(side=tk.BOTTOM, padx=5, pady=5, fill=tk.X)
return optsframe return optsframe
def add_options_info(self): def add_options_info(self):
""" Add the info bar """ """ Add the info bar """
logger.debug("Adding options info")
lblinfo = ttk.Label(self.optsframe, lblinfo = ttk.Label(self.optsframe,
textvariable=self.vars["info"], textvariable=self.vars["info"],
anchor=tk.W, anchor=tk.W,
@ -64,22 +73,26 @@ class DisplayPage(ttk.Frame):
def set_info(self, msg): def set_info(self, msg):
""" Set the info message """ """ Set the info message """
logger.debug("Setting info: %s", msg)
self.vars["info"].set(msg) self.vars["info"].set(msg)
def add_frame_separator(self): def add_frame_separator(self):
""" Add a separator between top and bottom frames """ """ Add a separator between top and bottom frames """
logger.debug("Adding frame separator")
sep = ttk.Frame(self, height=2, relief=tk.RIDGE) sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
@staticmethod @staticmethod
def set_mainframe_single_tab_style(): def set_mainframe_single_tab_style():
""" Configure ttk notebook style to represent a single frame """ """ Configure ttk notebook style to represent a single frame """
logger.debug("Setting main frame single tab style")
nbstyle = ttk.Style() nbstyle = ttk.Style()
nbstyle.configure("single.TNotebook", borderwidth=0) nbstyle.configure("single.TNotebook", borderwidth=0)
nbstyle.layout("single.TNotebook.Tab", []) nbstyle.layout("single.TNotebook.Tab", [])
def subnotebook_add_page(self, tabtitle, widget=None): def subnotebook_add_page(self, tabtitle, widget=None):
""" Add a page to the sub notebook """ """ Add a page to the sub notebook """
logger.debug("Adding subnotebook page: %s", tabtitle)
frame = widget if widget else ttk.Frame(self.subnotebook) frame = widget if widget else ttk.Frame(self.subnotebook)
frame.pack(padx=5, pady=5, fill=tk.BOTH, expand=True) frame.pack(padx=5, pady=5, fill=tk.BOTH, expand=True)
self.subnotebook.add(frame, text=tabtitle) self.subnotebook.add(frame, text=tabtitle)
@ -89,28 +102,32 @@ class DisplayPage(ttk.Frame):
def subnotebook_configure(self): def subnotebook_configure(self):
""" Configure notebook to display or hide tabs """ """ Configure notebook to display or hide tabs """
if len(self.subnotebook.children) == 1: if len(self.subnotebook.children) == 1:
logger.debug("Setting single page style")
self.subnotebook.configure(style="single.TNotebook") self.subnotebook.configure(style="single.TNotebook")
else: else:
logger.debug("Setting multi page style")
self.subnotebook.configure(style="TNotebook") self.subnotebook.configure(style="TNotebook")
def subnotebook_hide(self): def subnotebook_hide(self):
""" Hide the subnotebook. Used for hiding """ Hide the subnotebook. Used for hiding
Optional displays """ Optional displays """
if self.subnotebook.winfo_ismapped(): if self.subnotebook and self.subnotebook.winfo_ismapped():
logger.debug("Hiding subnotebook")
self.subnotebook.pack_forget() self.subnotebook.pack_forget()
self.subnotebook.destroy()
self.subnotebook = None
def subnotebook_show(self): def subnotebook_show(self):
""" Show subnotebook. Used for displaying """ Show subnotebook. Used for displaying
Optional displays """ Optional displays """
if not self.subnotebook.winfo_ismapped(): if not self.subnotebook:
self.subnotebook.pack(side=tk.TOP, logger.debug("Showing subnotebook")
anchor=tk.NW, self.subnotebook = self.add_subnotebook()
fill=tk.BOTH,
expand=True)
def subnotebook_get_widgets(self): def subnotebook_get_widgets(self):
""" Return each widget that sits within each """ Return each widget that sits within each
subnotebook frame """ subnotebook frame """
logger.debug("Getting subnotebook widgets")
for child in self.subnotebook.winfo_children(): for child in self.subnotebook.winfo_children():
for widget in child.winfo_children(): for widget in child.winfo_children():
yield widget yield widget
@ -120,11 +137,13 @@ class DisplayPage(ttk.Frame):
tabs = dict() tabs = dict()
for tab_id in range(0, self.subnotebook.index("end")): for tab_id in range(0, self.subnotebook.index("end")):
tabs[self.subnotebook.tab(tab_id, "text")] = tab_id tabs[self.subnotebook.tab(tab_id, "text")] = tab_id
logger.debug(tabs)
return tabs return tabs
def subnotebook_page_from_id(self, tab_id): def subnotebook_page_from_id(self, tab_id):
""" Return subnotebook tab widget from its ID """ """ Return subnotebook tab widget from its ID """
tab_name = self.subnotebook.tabs()[tab_id].split(".")[-1] tab_name = self.subnotebook.tabs()[tab_id].split(".")[-1]
logger.debug(tab_name)
return self.subnotebook.children[tab_name] return self.subnotebook.children[tab_name]
@ -155,19 +174,23 @@ class DisplayOptionalPage(DisplayPage):
modified = tk.DoubleVar() modified = tk.DoubleVar()
modified.set(None) modified.set(None)
return {"enabled": enabled, tk_vars = {"enabled": enabled,
"ready": ready, "ready": ready,
"modified": modified} "modified": modified}
logger.debug(tk_vars)
return tk_vars
# INFO LABEL # INFO LABEL
def set_info_text(self): def set_info_text(self):
""" Set waiting for display text """ """ Set waiting for display text """
if not self.vars["enabled"].get(): if not self.vars["enabled"].get():
self.set_info("{} disabled".format(self.tabname.title())) msg = "{} disabled".format(self.tabname.title())
elif self.vars["enabled"].get() and not self.vars["ready"].get(): elif self.vars["enabled"].get() and not self.vars["ready"].get():
self.set_info("Waiting for {}...".format(self.tabname)) msg = "Waiting for {}...".format(self.tabname)
else: else:
self.set_info("Displaying {}".format(self.tabname)) msg = "Displaying {}".format(self.tabname)
logger.debug(msg)
self.set_info(msg)
# DISPLAY OPTIONS BAR # DISPLAY OPTIONS BAR
def add_options(self): def add_options(self):
@ -177,8 +200,9 @@ class DisplayOptionalPage(DisplayPage):
def add_option_save(self): def add_option_save(self):
""" Add save button to save page output to file """ """ Add save button to save page output to file """
logger.debug("Adding save option")
btnsave = ttk.Button(self.optsframe, btnsave = ttk.Button(self.optsframe,
image=Images().icons["save"], image=get_images().icons["save"],
command=self.save_items) command=self.save_items)
btnsave.pack(padx=2, side=tk.RIGHT) btnsave.pack(padx=2, side=tk.RIGHT)
Tooltip(btnsave, Tooltip(btnsave,
@ -187,6 +211,7 @@ class DisplayOptionalPage(DisplayPage):
def add_option_enable(self): def add_option_enable(self):
""" Add checkbutton to enable/disable page """ """ Add checkbutton to enable/disable page """
logger.debug("Adding enable option")
chkenable = ttk.Checkbutton(self.optsframe, chkenable = ttk.Checkbutton(self.optsframe,
variable=self.vars["enabled"], variable=self.vars["enabled"],
text="Enable {}".format(self.tabname), text="Enable {}".format(self.tabname),
@ -202,6 +227,7 @@ class DisplayOptionalPage(DisplayPage):
def on_chkenable_change(self): def on_chkenable_change(self):
""" Update the display immediately on a checkbutton change """ """ Update the display immediately on a checkbutton change """
logger.debug("Enabled checkbox changed")
if self.vars["enabled"].get(): if self.vars["enabled"].get():
self.subnotebook_show() self.subnotebook_show()
else: else:
@ -213,6 +239,7 @@ class DisplayOptionalPage(DisplayPage):
if not self.runningtask.get(): if not self.runningtask.get():
return return
if self.vars["enabled"].get(): if self.vars["enabled"].get():
logger.trace("Updating page")
self.display_item_set() self.display_item_set()
self.load_display() self.load_display()
self.after(waittime, lambda t=waittime: self.update_page(t)) self.after(waittime, lambda t=waittime: self.update_page(t))
@ -225,6 +252,7 @@ class DisplayOptionalPage(DisplayPage):
""" Load the display """ """ Load the display """
if not self.display_item: if not self.display_item:
return return
logger.debug("Loading display")
self.display_item_process() self.display_item_process()
self.vars["ready"].set(True) self.vars["ready"].set(True)
self.set_info_text() self.set_info_text()
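
The reworked `set_info_text` above builds the message first and then logs and sets it once. A minimal sketch of that three-state selection follows, with the tk variables replaced by plain booleans. The function name `info_message` is hypothetical.

```python
def info_message(tabname, enabled, ready):
    """Pick the info-bar text for an optional display page.

    Mirrors the branch order in the diff: disabled pages win, then
    enabled-but-not-ready pages, then pages that are displaying.
    """
    if not enabled:
        # Only the "disabled" message title-cases the tab name in the diff.
        return "{} disabled".format(tabname.title())
    if not ready:
        return "Waiting for {}...".format(tabname)
    return "Displaying {}".format(tabname)


for state in ((False, False), (True, False), (True, True)):
    print(info_message("graph", *state))
```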

134
lib/gui/menu.py Normal file

@ -0,0 +1,134 @@
#!/usr/bin/env python3
""" The Menu Bars for faceswap GUI """
import logging
import os
import sys
import tkinter as tk
from importlib import import_module
from lib.Serializer import JSONSerializer
from .utils import get_config
from .popup_configure import popup_config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class MainMenuBar(tk.Menu):
""" GUI Main Menu Bar """
def __init__(self, master=None):
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(master)
self.root = master
self.config = get_config()
self.file_menu = tk.Menu(self, tearoff=0)
self.recent_menu = tk.Menu(self.file_menu, tearoff=0, postcommand=self.refresh_recent_menu)
self.edit_menu = tk.Menu(self, tearoff=0)
self.build_file_menu()
self.build_edit_menu()
logger.debug("Initialized %s", self.__class__.__name__)
def build_file_menu(self):
""" Add the file menu to the menu bar """
logger.debug("Building File menu")
self.file_menu.add_command(
label="Load full config...", underline=0, command=self.config.load)
self.file_menu.add_command(
label="Save full config...", underline=0, command=self.config.save)
self.file_menu.add_separator()
self.file_menu.add_cascade(label="Open recent", underline=6, menu=self.recent_menu)
self.file_menu.add_separator()
self.file_menu.add_command(
label="Reset all to default", underline=0, command=self.config.cli_opts.reset)
self.file_menu.add_command(
label="Clear all", underline=0, command=self.config.cli_opts.clear)
self.file_menu.add_separator()
self.file_menu.add_command(label="Quit", underline=0, command=self.root.close_app)
self.add_cascade(label="File", menu=self.file_menu, underline=0)
logger.debug("Built File menu")
def build_recent_menu(self):
""" Load recent files into menu bar """
logger.debug("Building Recent Files menu")
serializer = JSONSerializer
menu_file = os.path.join(self.config.pathcache, ".recent.json")
if not os.path.isfile(menu_file):
self.clear_recent_files(serializer, menu_file)
with open(menu_file, "rb") as inp:
recent_files = serializer.unmarshal(inp.read().decode("utf-8"))
logger.debug("Loaded recent files: %s", recent_files)
for recent_item in recent_files:
filename, command = recent_item
logger.debug("processing: ('%s', %s)", filename, command)
if not os.path.isfile(filename):
logger.debug("File does not exist")
continue
lbl_command = command if command else "All"
self.recent_menu.add_command(
label="{} ({})".format(filename, lbl_command.title()),
command=lambda fnm=filename, cmd=command: self.config.load(cmd, fnm))
self.recent_menu.add_separator()
self.recent_menu.add_command(
label="Clear recent files",
underline=0,
command=lambda srl=serializer, mnu=menu_file: self.clear_recent_files(srl, mnu))
logger.debug("Built Recent Files menu")
@staticmethod
def clear_recent_files(serializer, menu_file):
""" Creates or clears recent file list """
logger.debug("clearing recent files list: '%s'", menu_file)
recent_files = serializer.marshal(list())
with open(menu_file, "wb") as out:
out.write(recent_files.encode("utf-8"))
def refresh_recent_menu(self):
""" Refresh recent menu on save/load of files """
self.recent_menu.delete(0, "end")
self.build_recent_menu()
def build_edit_menu(self):
""" Add the edit menu to the menu bar """
logger.debug("Building Edit menu")
edit_menu = tk.Menu(self, tearoff=0)
configs = self.scan_for_configs()
for name in sorted(list(configs.keys())):
label = "Configure {} Plugins...".format(name.title())
config = configs[name]
edit_menu.add_command(
label=label,
underline=10,
command=lambda conf=(name, config), root=self.root: popup_config(conf, root))
self.add_cascade(label="Edit", menu=edit_menu, underline=0)
logger.debug("Built Edit menu")
def scan_for_configs(self):
""" Scan for config.ini file locations """
root_path = os.path.abspath(os.path.dirname(sys.argv[0]))
plugins_path = os.path.join(root_path, "plugins")
logger.debug("Scanning path: '%s'", plugins_path)
configs = dict()
for dirpath, _, filenames in os.walk(plugins_path):
if "_config.py" in filenames:
plugin_type = os.path.split(dirpath)[-1]
config = self.load_config(plugin_type)
configs[plugin_type] = config
logger.debug("Configs loaded: %s", sorted(list(configs.keys())))
return configs
@staticmethod
def load_config(plugin_type):
""" Load the config to generate config file if it doesn't exist and get filename """
# Load config to generate default if doesn't exist
mod = ".".join(("plugins", plugin_type, "_config"))
module = import_module(mod)
config = module.Config(None)
logger.debug("Found '%s' config at '%s'", plugin_type, config.configfile)
return config
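
The recent-files handling in `build_recent_menu` and `clear_recent_files` can be sketched without tkinter. This is an illustration under stated assumptions, not the commit's code: the standard `json` module stands in for `JSONSerializer`, the function `load_recent_files` and the file name `saved_config.json` are hypothetical, and menu construction is left out.

```python
import json
import os
import tempfile


def clear_recent_files(menu_file):
    """Create or reset the recent-files list as an empty JSON array."""
    with open(menu_file, "wb") as out:
        out.write(json.dumps([]).encode("utf-8"))


def load_recent_files(menu_file):
    """Return (filename, command) pairs whose file still exists on disk."""
    if not os.path.isfile(menu_file):
        # As in the diff, a missing list is created before it is read.
        clear_recent_files(menu_file)
    with open(menu_file, "rb") as inp:
        recent_files = json.loads(inp.read().decode("utf-8"))
    return [(filename, command)
            for filename, command in recent_files
            if os.path.isfile(filename)]


with tempfile.TemporaryDirectory() as cache:
    menu_file = os.path.join(cache, ".recent.json")
    first_load = load_recent_files(menu_file)  # list file is created on demand

    existing = os.path.join(cache, "saved_config.json")
    with open(existing, "w") as handle:
        handle.write("{}")
    missing = os.path.join(cache, "deleted_config.json")
    entries = [[existing, "train"], [missing, None]]
    with open(menu_file, "wb") as out:
        out.write(json.dumps(entries).encode("utf-8"))
    second_load = load_recent_files(menu_file)

print(first_load, second_load)
```

Entries that point at deleted files are skipped on load rather than removed from the stored list, which matches the `continue` in the diff.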


@ -1,14 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Cli Options and Config functions for the GUI """ """ Cli Options for the GUI """
import inspect import inspect
from argparse import SUPPRESS from argparse import SUPPRESS
import logging import logging
from tkinter import ttk from tkinter import ttk
from lib import cli from lib import cli
from lib.Serializer import JSONSerializer
import tools.cli as ToolsCli import tools.cli as ToolsCli
from .utils import FileHandler, Images from .utils import get_images
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -93,8 +92,7 @@ class CliOptions():
logger.trace("Skipping suppressed option: %s", opt) logger.trace("Skipping suppressed option: %s", opt)
continue continue
ctl, sysbrowser, filetypes, action_option = self.set_control(opt) ctl, sysbrowser, filetypes, action_option = self.set_control(opt)
opt["control_title"] = self.set_control_title( opt["control_title"] = self.set_control_title(opt.get("opts", ""))
opt.get("opts", ""))
opt["control"] = ctl opt["control"] = ctl
opt["filesystem_browser"] = sysbrowser opt["filesystem_browser"] = sysbrowser
opt["filetypes"] = filetypes opt["filetypes"] = filetypes
@ -126,6 +124,8 @@ class CliOptions():
sysbrowser, filetypes = self.set_sysbrowser(action, sysbrowser, filetypes = self.set_sysbrowser(action,
filetypes, filetypes,
action_option) action_option)
elif option.get("min_max", None):
ctl = ttk.Scale
elif option.get("choices", "") != "": elif option.get("choices", "") != "":
ctl = ttk.Combobox ctl = ttk.Combobox
elif option.get("action", "") == "store_true": elif option.get("action", "") == "store_true":
@ -226,7 +226,7 @@ class CliOptions():
optval = str(option.get("value", "").get()) optval = str(option.get("value", "").get())
opt = option["opts"][0] opt = option["opts"][0]
if command in ("extract", "convert") and opt == "-o": if command in ("extract", "convert") and opt == "-o":
Images().pathoutput = optval get_images().pathoutput = optval
if optval in ("False", ""): if optval in ("False", ""):
continue continue
elif optval == "True": elif optval == "True":
@ -238,59 +238,3 @@ class CliOptions():
else: else:
opt = (opt, optval) opt = (opt, optval)
yield opt yield opt
class Config():
""" Actions for loading and saving Faceswap GUI command configurations """
def __init__(self, cli_opts, tk_vars):
logger.debug("Initializing %s", self.__class__.__name__)
self.cli_opts = cli_opts
self.serializer = JSONSerializer
self.tk_vars = tk_vars
logger.debug("Initialized %s", self.__class__.__name__)
def load(self, command=None):
""" Load a saved config file """
logger.debug("Loading config: (command: '%s')", command)
cfgfile = FileHandler("open", "config").retfile
if not cfgfile:
return
cfg = self.serializer.unmarshal(cfgfile.read())
opts = self.get_command_options(cfg, command) if command else cfg
for cmd, opts in opts.items():
self.set_command_args(cmd, opts)
logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile)
def get_command_options(self, cfg, command):
""" return the saved options for the requested
command, if not loading global options """
opts = cfg.get(command, None)
if not opts:
self.tk_vars["consoleclear"].set(True)
print("No {} section found in file".format(command))
logger.info("No %s section found in file", command)
retval = {command: opts}
logger.debug(retval)
return retval
def set_command_args(self, command, options):
""" Pass the saved config items back to the CliOptions """
if not options:
return
for srcopt, srcval in options.items():
optvar = self.cli_opts.get_one_option_variable(command, srcopt)
if not optvar:
continue
optvar.set(srcval)
def save(self, command=None):
""" Save the current GUI state to a config file in json format """
logger.debug("Saving config: (command: '%s')", command)
cfgfile = FileHandler("save", "config").retfile
if not cfgfile:
return
cfg = self.cli_opts.get_option_values(command)
cfgfile.write(self.serializer.marshal(cfg))
cfgfile.close()
logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile)

348
lib/gui/popup_configure.py Normal file

@ -0,0 +1,348 @@
#!/usr/bin/env python3
""" Configure Plugins popup of the Faceswap GUI """
from configparser import ConfigParser
import logging
import tkinter as tk
from tkinter import ttk
from .tooltip import Tooltip
from .utils import get_config, ContextMenu, set_slider_rounding
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
POPUP = dict()
def popup_config(config, root):
""" Close any open popup and open requested popup """
if POPUP:
p_key = list(POPUP.keys())[0]
logger.debug("Closing open popup: '%s'", p_key)
POPUP[p_key].destroy()
del POPUP[p_key]
window = ConfigurePlugins(config, root)
POPUP[config[0]] = window
class ConfigurePlugins(tk.Toplevel):
""" Pop up for configuring plugin options """
def __init__(self, config, root):
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__()
name, self.config = config
self.title("{} Plugins".format(name.title()))
self.set_geometry(root)
self.page_frame = ttk.Frame(self)
self.page_frame.pack(fill=tk.BOTH, expand=True)
self.plugin_info = dict()
self.config_dict_gui = self.get_config()
self.build()
self.update()
logger.debug("Initialized %s", self.__class__.__name__)
def set_geometry(self, root):
""" Set pop-up geometry """
scaling_factor = get_config().scaling_factor
pos_x = root.winfo_x() + 80
pos_y = root.winfo_y() + 80
width = int(720 * scaling_factor)
height = int(400 * scaling_factor)
logger.debug("Pop up Geometry: %sx%s, %s+%s", width, height, pos_x, pos_y)
self.geometry("{}x{}+{}+{}".format(width, height, pos_x, pos_y))
def get_config(self):
""" Format config into useful format for GUI and pull default value if a value has not
been supplied """
logger.debug("Formatting Config for GUI")
conf = dict()
for section in self.config.config.sections():
self.config.section = section
category = section.split(".")[0]
options = self.config.defaults[section]
conf.setdefault(category, dict())[section] = options
for key in options.keys():
if key == "helptext":
self.plugin_info[section] = options[key]
continue
options[key]["value"] = self.config.config_dict.get(key, options[key]["default"])
logger.debug("Formatted Config for GUI: %s", conf)
return conf
def build(self):
""" Build the config popup """
logger.debug("Building plugin config popup")
container = ttk.Notebook(self.page_frame)
container.pack(fill=tk.BOTH, expand=True)
categories = sorted(list(key for key in self.config_dict_gui.keys()))
if "global" in categories: # Move global to first item
categories.insert(0, categories.pop(categories.index("global")))
for category in categories:
page = self.build_page(container, category)
container.add(page, text=category.title())
self.add_frame_separator()
self.add_actions()
logger.debug("Built plugin config popup")
def build_page(self, container, category):
""" Build a plugin config page """
logger.debug("Building plugin config page: '%s'", category)
plugins = sorted(list(key for key in self.config_dict_gui[category].keys()))
if any(plugin != category for plugin in plugins):
page = ttk.Notebook(container)
page.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
for plugin in plugins:
frame = ConfigFrame(page,
self.config_dict_gui[category][plugin],
self.plugin_info[plugin])
title = plugin[plugin.rfind(".") + 1:]
title = title.replace("_", " ").title()
page.add(frame, text=title)
else:
page = ConfigFrame(container,
self.config_dict_gui[category][plugins[0]],
self.plugin_info[plugins[0]])
logger.debug("Built plugin config page: '%s'", category)
return page
def add_frame_separator(self):
""" Add a separator between top and bottom frames """
logger.debug("Add frame separator")
sep = ttk.Frame(self.page_frame, height=2, relief=tk.RIDGE)
sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
logger.debug("Added frame separator")
def add_actions(self):
""" Add Action buttons """
logger.debug("Add action buttons")
frame = ttk.Frame(self.page_frame)
frame.pack(fill=tk.BOTH, padx=5, pady=5, side=tk.BOTTOM)
btn_cls = ttk.Button(frame, text="Cancel", width=10, command=self.destroy)
btn_cls.pack(padx=2, side=tk.RIGHT)
btn_ok = ttk.Button(frame, text="OK", width=10, command=self.save_config)
btn_ok.pack(padx=2, side=tk.RIGHT)
logger.debug("Added action buttons")
def save_config(self):
""" Save the config file """
logger.debug("Saving config")
options = {sect: opts
for value in self.config_dict_gui.values()
for sect, opts in value.items()}
new_config = ConfigParser(allow_no_value=True)
for section, items in self.config.defaults.items():
logger.debug("Adding section: '%s'", section)
self.config.insert_config_section(section, items["helptext"], config=new_config)
for item, def_opt in items.items():
if item == "helptext":
continue
new_opt = options[section][item]
logger.debug("Adding option: (item: '%s', default: '%s', new: '%s')",
item, def_opt, new_opt)
helptext = def_opt["helptext"]
helptext += self.config.set_helptext_choices(def_opt)
helptext += "\n[Default: {}]".format(def_opt["default"])
helptext = self.config.format_help(helptext, is_section=False)
new_config.set(section, helptext)
new_config.set(section, item, str(new_opt["selected"].get()))
self.config.config = new_config
self.config.save_config()
print("Saved config: '{}'".format(self.config.configfile))
self.destroy()
logger.debug("Saved config")
class ConfigFrame(ttk.Frame): # pylint: disable=too-many-ancestors
""" Config Frame - Holds the Options for config """
def __init__(self, parent, options, plugin_info):
logger.debug("Initializing %s", self.__class__.__name__)
ttk.Frame.__init__(self, parent)
self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
self.options = options
self.plugin_info = plugin_info
self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
self.optsframe = ttk.Frame(self.canvas)
self.optscanvas = self.canvas.create_window((0, 0), window=self.optsframe, anchor=tk.NW)
self.build_frame()
logger.debug("Initialized %s", self.__class__.__name__)
def build_frame(self):
""" Build the options frame for this command """
logger.debug("Add Config Frame")
self.add_scrollbar()
self.canvas.bind("<Configure>", self.resize_frame)
self.add_info()
for key, val in self.options.items():
if key == "helptext":
continue
OptionControl(key, val, self.optsframe)
logger.debug("Added Config Frame")
def add_scrollbar(self):
""" Add a scrollbar to the options frame """
logger.debug("Add Config Scrollbar")
scrollbar = ttk.Scrollbar(self, command=self.canvas.yview)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
self.canvas.config(yscrollcommand=scrollbar.set)
self.optsframe.bind("<Configure>", self.update_scrollbar)
logger.debug("Added Config Scrollbar")
def update_scrollbar(self, event): # pylint: disable=unused-argument
""" Update the options frame scrollbar """
self.canvas.configure(scrollregion=self.canvas.bbox("all"))
def resize_frame(self, event):
""" Resize the options frame to fit the canvas """
logger.debug("Resize Config Frame")
canvas_width = event.width
self.canvas.itemconfig(self.optscanvas, width=canvas_width)
logger.debug("Resized Config Frame")
def add_info(self):
""" Plugin information """
info_frame = ttk.Frame(self.optsframe)
info_frame.pack(fill=tk.X, expand=True)
lbl = ttk.Label(info_frame, text="About:", width=20, anchor=tk.W)
lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
info = ttk.Label(info_frame, text=self.plugin_info)
info.pack(padx=5, pady=5, fill=tk.X, expand=True)
class OptionControl():
""" Build the correct control for the option parsed and place it on the
frame """
def __init__(self, title, values, option_frame):
logger.debug("Initializing %s", self.__class__.__name__)
self.title = title
self.values = values
self.option_frame = option_frame
self.control = self.set_control()
self.control_frame = self.set_control_frame()
self.tk_var = self.set_tk_var()
self.build_full_control()
logger.debug("Initialized %s", self.__class__.__name__)
@property
def helptext(self):
""" Format the help text for tooltips """
logger.debug("Format control help: '%s'", self.title)
helptext = self.values.get("helptext", "")
helptext = helptext.replace("\n\t", "\n - ").replace("%%", "%")
helptext = self.title + " - " + helptext
logger.debug("Formatted control help: (title: '%s', help: '%s')", self.title, helptext)
return helptext
def set_control(self):
""" Set the correct control type for this option """
dtype = self.values["type"]
choices = self.values["choices"]
if choices:
control = ttk.Combobox
elif dtype == bool:
control = ttk.Checkbutton
elif dtype in (int, float):
control = ttk.Scale
else:
control = ttk.Entry
logger.debug("Setting control '%s' to %s", self.title, control)
return control
def set_control_frame(self):
""" Frame to hold control and its label """
logger.debug("Build config control frame")
frame = ttk.Frame(self.option_frame)
frame.pack(fill=tk.X, expand=True)
logger.debug("Built config control frame")
return frame
def set_tk_var(self):
""" Correct variable type for control """
logger.debug("Setting config variable type: '%s'", self.title)
var = tk.BooleanVar if self.control == ttk.Checkbutton else tk.StringVar
var = var(self.control_frame)
logger.debug("Set config variable type: ('%s': %s)", self.title, type(var))
return var
def build_full_control(self):
""" Build the correct control type for the option passed through """
logger.debug("Build config option control")
self.build_control_label()
self.build_one_control()
self.values["selected"] = self.tk_var
logger.debug("Built option control")
def build_control_label(self):
""" Label for control """
logger.debug("Build config control label: '%s'", self.title)
title = self.title.replace("_", " ").title()
lbl = ttk.Label(self.control_frame, text=title, width=20, anchor=tk.W)
lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
logger.debug("Built config control label: '%s'", self.title)
def build_one_control(self):
""" Build and place the option controls """
logger.debug("Build control: (title: '%s', values: %s)", self.title, self.values)
self.tk_var.set(self.values["value"])
if self.control == ttk.Scale:
self.slider_control()
else:
self.control_to_optionsframe()
logger.debug("Built control: '%s'", self.title)
def slider_control(self):
""" A slider control with corresponding Entry box """
logger.debug("Add slider control to Config Options Frame: %s", self.control)
d_type = self.values["type"]
rnd = self.values["rounding"]
min_max = self.values["min_max"]
tbox = ttk.Entry(self.control_frame, width=8, textvariable=self.tk_var, justify=tk.RIGHT)
tbox.pack(padx=(0, 5), side=tk.RIGHT)
ctl = self.control(
self.control_frame,
variable=self.tk_var,
command=lambda val, var=self.tk_var, dt=d_type, rn=rnd, mm=min_max:
set_slider_rounding(val, var, dt, rn, mm))
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
rc_menu = ContextMenu(ctl)
rc_menu.cm_bind()
ctl["from_"] = min_max[0]
ctl["to"] = min_max[1]
Tooltip(ctl, text=self.helptext, wraplength=720)
Tooltip(tbox, text=self.helptext, wraplength=720)
logger.debug("Added slider control to Options Frame: %s", self.control)
def control_to_optionsframe(self):
""" Standard non-check buttons sit in the main options frame """
logger.debug("Add control to Options Frame: %s", self.control)
choices = self.values["choices"]
if self.control == ttk.Checkbutton:
ctl = self.control(self.control_frame, variable=self.tk_var, text=None)
else:
ctl = self.control(self.control_frame, textvariable=self.tk_var)
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
rc_menu = ContextMenu(ctl)
rc_menu.cm_bind()
if choices:
logger.debug("Adding combo choices: %s", choices)
ctl["values"] = [choice for choice in choices]
Tooltip(ctl, text=self.helptext, wraplength=720)
logger.debug("Added control to Options Frame: %s", self.control)

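As a rough illustration of the precedence used by OptionControl.set_control above, the sketch below reproduces the same ordering without Tk and returns widget names as plain strings. The helper name choose_control is hypothetical and not part of faceswap; it only mirrors the checks shown in the diff (choices first, then bool, then numeric types, otherwise a text entry).

```python
def choose_control(dtype, choices=None):
    """Return the name of the widget the precedence above would select.

    Hypothetical helper for illustration only; the real code returns
    ttk widget classes rather than strings.
    """
    if choices:                    # any non-empty list of choices wins
        return "Combobox"
    if dtype == bool:              # booleans become a check button
        return "Checkbutton"
    if dtype in (int, float):      # numeric values get a slider
        return "Scale"
    return "Entry"                 # everything else is free text


if __name__ == "__main__":
    print(choose_control(str, choices=["option_a", "option_b"]))
    print(choose_control(bool))
    print(choose_control(float))
    print(choose_control(str))
```

Because choices are tested first, a numeric option that also lists choices would be shown as a combo box rather than a slider.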

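The TensorBoardLogs.set_log_filenames method in the hunk that follows derives a side and a session id from each log directory path. A minimal sketch of that parsing step, assuming a layout of the form <logs>/<side>/<prefix>_<n>: the helper name parse_log_dir and the example directory names are hypothetical, and only the split/rfind logic is taken from the diff.

```python
import os


def parse_log_dir(dirpath):
    """Split a log directory path into (side, session_id).

    Hypothetical helper: the last path component is assumed to end in
    "_<number>", and its parent folder is taken as the side.
    """
    parent, session = os.path.split(dirpath)
    side = os.path.split(parent)[1]
    session_id = int(session[session.rfind("_") + 1:])
    return side, session_id


if __name__ == "__main__":
    example = os.path.join("model", "original_logs", "a", "session_3")
    print(parse_log_dir(example))
```

Note that a path with a trailing separator would leave the last component empty, so callers would need to normalise paths first.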
@ -9,14 +9,14 @@ import warnings
from math import ceil, sqrt from math import ceil, sqrt
import numpy as np import numpy as np
import tensorflow as tf
from lib.Serializer import PickleSerializer from lib.Serializer import JSONSerializer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def convert_time(timestamp): def convert_time(timestamp):
""" Convert time stamp to total hours, mins and second """ """ Convert time stamp to total hours, minutes and seconds """
hrs = int(timestamp // 3600) hrs = int(timestamp // 3600)
if hrs < 10: if hrs < 10:
hrs = "{0:02d}".format(hrs) hrs = "{0:02d}".format(hrs)
@ -25,164 +25,279 @@ def convert_time(timestamp):
return hrs, mins, secs return hrs, mins, secs
class SavedSessions(): class TensorBoardLogs():
""" Saved Training Session """ """ Parse and return data from TensorBoard logs """
def __init__(self, sessions_data): def __init__(self, logs_folder):
self.serializer = PickleSerializer self.folder_base = logs_folder
self.sessions = self.load_sessions(sessions_data) self.log_filenames = self.set_log_filenames()
def load_sessions(self, filename): def set_log_filenames(self):
""" Load previously saved sessions """ """ Set the TensorBoard log filenames for all existing sessions """
stats = list() logger.debug("Loading log filenames. base_dir: '%s'", self.folder_base)
if os.path.isfile(filename): log_filenames = dict()
with open(filename, self.serializer.roptions) as sessions: for dirpath, _, filenames in os.walk(self.folder_base):
stats = self.serializer.unmarshal(sessions.read()) if not any(filename.startswith("events.out.tfevents") for filename in filenames):
return stats continue
logfiles = [filename for filename in filenames
if filename.startswith("events.out.tfevents")]
# Take the last logfile, in case of previous crash
logfile = os.path.join(dirpath, sorted(logfiles)[-1])
side, session = os.path.split(dirpath)
side = os.path.split(side)[1]
session = int(session[session.rfind("_") + 1:])
log_filenames.setdefault(session, dict())[side] = logfile
logger.debug("logfiles: %s", log_filenames)
return log_filenames
def save_sessions(self, filename): def get_loss(self, side=None, session=None):
""" Save the session file """ """ Read the loss from the TensorBoard logs
with open(filename, self.serializer.woptions) as session: Specify a side or a session or leave at None for all
session.write(self.serializer.marshal(self.sessions)) """
logger.info("Saved session stats to: %s", filename) logger.debug("Getting loss: (side: %s, session: %s)", side, session)
all_loss = dict()
for sess, sides in self.log_filenames.items():
if session is not None and sess != session:
logger.debug("Skipping session: %s", sess)
continue
loss = dict()
for sde, logfile in sides.items():
if side is not None and sde != side:
logger.debug("Skipping side: %s", sde)
continue
for event in tf.train.summary_iterator(logfile):
for summary in event.summary.value:
if "loss" not in summary.tag:
continue
tag = summary.tag.replace("batch_", "")
loss.setdefault(tag,
dict()).setdefault(sde,
list()).append(summary.simple_value)
all_loss[sess] = loss
return all_loss
def get_timestamps(self, session=None):
""" Read the timestamps from the TensorBoard logs
Specify a session or leave at None for all
NB: For all intents and purposes timestamps are the same for
both sides, so just read from one side """
logger.debug("Getting timestamps")
all_timestamps = dict()
for sess, sides in self.log_filenames.items():
if session is not None and sess != session:
logger.debug("Skipping sessions: %s", sess)
continue
for logfile in sides.values():
timestamps = [event.wall_time
for event in tf.train.summary_iterator(logfile)]
logger.debug("Total timestamps for session %s: %s", sess, len(timestamps))
all_timestamps[sess] = timestamps
break # break after first file read
return all_timestamps
class CurrentSession(): class Session():
""" The current training session """ """ The Loaded or current training session """
def __init__(self): def __init__(self, model_dir=None, model_name=None):
self.stats = {"iterations": 0, logger.debug("Initializing %s", self.__class__.__name__)
"batchsize": None, # Set and reset by wrapper self.serializer = JSONSerializer
"timestamps": [], self.state = None
"loss": [], self.modeldir = model_dir # Set and reset by wrapper for training sessions
"losskeys": []} self.modelname = model_name # Set and reset by wrapper for training sessions
self.timestats = {"start": None, self.tb_logs = None
"elapsed": None} self.initialized = False
self.modeldir = None # Set and reset by wrapper self.session_id = None # Set to specific session_id or current training session
self.filename = None self.summary = SessionsSummary(self)
self.historical = None logger.debug("Initialized %s", self.__class__.__name__)
def initialise_session(self, currentloss): @property
""" Initialise the training session """ def batchsize(self):
self.load_historical() """ Return the session batchsize """
for item in currentloss: return self.session["batchsize"]
self.stats["losskeys"].append(item[0])
self.stats["loss"].append(list())
self.timestats["start"] = time.time()
def load_historical(self): @property
""" Load historical data and add current session to the end """ def config(self):
self.filename = os.path.join(self.modeldir, "trainingstats.fss") """ Return config and other information """
self.historical = SavedSessions(self.filename) retval = {key: val for key, val in self.state["config"].items()}
self.historical.sessions.append(self.stats) retval["training_size"] = self.state["training_size"]
retval["input_size"] = [val[0] for key, val in self.state["inputs"].items()
if key.startswith("face")][0]
return retval
def add_loss(self, currentloss): @property
""" Add a loss item from the training process """ def full_summary(self):
if self.stats["iterations"] == 0: """ Return all sessions summary data """
self.initialise_session(currentloss) return self.summary.compile_stats()
self.stats["iterations"] += 1 @property
self.add_timestats() def iterations(self):
""" Return session iterations """
return self.session["iterations"]
for idx, item in enumerate(currentloss): @property
self.stats["loss"][idx].append(float(item[1])) def logging_disabled(self):
""" Return whether logging is disabled for this session """
return self.session["no_logs"]
def add_timestats(self): @property
""" Add timestats to loss dict and timestats """ def loss(self):
now = time.time() """ Return loss from logs for current session """
self.stats["timestamps"].append(now) loss_dict = self.tb_logs.get_loss(session=self.session_id)[self.session_id]
elapsed_time = now - self.timestats["start"] return loss_dict
hrs, mins, secs = convert_time(elapsed_time)
self.timestats["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
def save_session(self): @property
""" Save the session file to the modeldir """ def loss_keys(self):
if self.stats["iterations"] > 0: """ Return list of unique session loss keys """
logger.info("Saving session stats...") if self.session_id is None:
self.historical.save_sessions(self.filename) loss_keys = self.total_loss_keys
else:
loss_keys = set(loss_key for side_keys in self.session["loss_names"].values()
for loss_key in side_keys)
return list(loss_keys)
@property
def lowest_loss(self):
""" Return the lowest average loss per save iteration seen """
return self.state["lowest_avg_loss"]
class SessionsTotals(): @property
""" The compiled totals of all saved sessions """ def session(self):
def __init__(self, all_sessions): """ Return current session dictionary """
self.stats = {"split": [], return self.state["sessions"][str(self.session_id)]
"iterations": 0,
"batchsize": [],
"timestamps": [],
"loss": [],
"losskeys": []}
self.initiate(all_sessions) @property
self.compile(all_sessions) def session_ids(self):
""" Return sorted list of all existing session ids in the state file """
return sorted([int(key) for key in self.state["sessions"].keys()])
def initiate(self, sessions): @property
""" Initiate correct losskey titles and number of loss lists """ def timestamps(self):
for losskey in sessions[0]["losskeys"]: """ Return timestamps from logs for current session """
self.stats["losskeys"].append(losskey) ts_dict = self.tb_logs.get_timestamps(session=self.session_id)
self.stats["loss"].append(list()) return ts_dict[self.session_id]
def compile(self, sessions): @property
""" Compile all of the sessions into totals """ def total_batchsize(self):
current_split = 0 """ Return all session batch sizes """
for session in sessions: return {int(sess_id): sess["batchsize"]
iterations = session["iterations"] for sess_id, sess in self.state["sessions"].items()}
current_split += iterations
self.stats["split"].append(current_split)
self.stats["iterations"] += iterations
self.stats["timestamps"].extend(session["timestamps"])
self.stats["batchsize"].append(session["batchsize"])
self.add_loss(session["loss"])
def add_loss(self, session_loss): @property
""" Add loss vals to each of their respective lists """ def total_iterations(self):
for idx, loss in enumerate(session_loss): """ Return total iterations across all sessions """
self.stats["loss"][idx].extend(loss) return self.state["iterations"]
@property
def total_loss(self):
""" Return collated loss for all session """
loss_dict = dict()
for sess in self.tb_logs.get_loss().values():
for loss_key, side_loss in sess.items():
for side, loss in side_loss.items():
loss_dict.setdefault(loss_key, dict()).setdefault(side, list()).extend(loss)
return loss_dict
@property
def total_loss_keys(self):
""" Return list of unique session loss keys across all sessions """
loss_keys = set(loss_key
for session in self.state["sessions"].values()
for loss_keys in session["loss_names"].values()
for loss_key in loss_keys)
return list(loss_keys)
@property
def total_timestamps(self):
""" Return timestamps from logs separated per session for all sessions """
return self.tb_logs.get_timestamps()
def initialize_session(self, is_training=False, session_id=None):
""" Initialize the training session """
logger.debug("Initializing session: (is_training: %s, session_id: %s)",
is_training, session_id)
self.load_state_file()
self.tb_logs = TensorBoardLogs(os.path.join(self.modeldir,
"{}_logs".format(self.modelname)))
if is_training:
self.session_id = max(int(key) for key in self.state["sessions"].keys())
else:
self.session_id = session_id
self.initialized = True
logger.debug("Initialized session")
def load_state_file(self):
""" Load the current state file """
state_file = os.path.join(self.modeldir, "{}_state.json".format(self.modelname))
logger.debug("Loading State: '%s'", state_file)
try:
with open(state_file, "rb") as inp:
state = self.serializer.unmarshal(inp.read().decode("utf-8"))
self.state = state
logger.debug("Loaded state: %s", state)
except IOError as err:
logger.warning("Unable to load state file. Graphing disabled: %s", str(err))
class SessionsSummary(): class SessionsSummary():
""" Calculations for analysis summary stats """ """ Calculations for analysis summary stats """
def __init__(self, raw_data): def __init__(self, session):
self.summary = list() logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session)
self.summary_stats_compile(raw_data) self.session = session
logger.debug("Initialized %s", self.__class__.__name__)
def summary_stats_compile(self, raw_data): @property
""" Compile summary stats """ def iterations(self):
raw_summaries = list() """ Return session iterations sizes """
for idx, session in enumerate(raw_data): return {int(sess_id): sess["iterations"]
raw_summaries.append(self.summarise_session(idx, session)) for sess_id, sess in self.session.state["sessions"].items()}
totals_summary = self.summarise_totals(raw_summaries) @property
raw_summaries.append(totals_summary) def time_stats(self):
self.format_summaries(raw_summaries) """ Return session time stats """
ts_data = self.session.tb_logs.get_timestamps()
time_stats = {sess_id: {"start_time": min(timestamps),
"end_time": max(timestamps)}
for sess_id, timestamps in ts_data.items()}
return time_stats
# Compile Session Summaries @property
@staticmethod def sessions_stats(self):
def summarise_session(idx, session): """ Return compiled stats """
""" Compile stats for session passed in """ compiled = list()
starttime = session["timestamps"][0] for sess_idx, ts_data in self.time_stats.items():
endtime = session["timestamps"][-1] elapsed = ts_data["end_time"] - ts_data["start_time"]
elapsed = endtime - starttime batchsize = self.session.total_batchsize[sess_idx]
# Bump elapsed to 0.1s if no time is recorded iterations = self.iterations[sess_idx]
# to hack around div by zero error compiled.append({"session": sess_idx,
elapsed = 0.1 if elapsed == 0 else elapsed "start": ts_data["start_time"],
rate = (session["batchsize"] * session["iterations"]) / elapsed "end": ts_data["end_time"],
return {"session": idx + 1, "elapsed": elapsed,
"start": starttime, "rate": (batchsize * iterations) / elapsed,
"end": endtime, "batch": batchsize,
"elapsed": elapsed, "iterations": iterations})
"rate": rate, return compiled
"batch": session["batchsize"],
"iterations": session["iterations"]} def compile_stats(self):
""" Compile sessions stats with totals, format and return """
logger.debug("Compiling sessions summary data")
compiled_stats = self.sessions_stats
logger.debug("sessions_stats: %s", compiled_stats)
total_stats = self.total_stats(compiled_stats)
compiled_stats.append(total_stats)
compiled_stats = self.format_stats(compiled_stats)
logger.debug("Final stats: %s", compiled_stats)
return compiled_stats
@staticmethod @staticmethod
def summarise_totals(raw_summaries): def total_stats(sessions_stats):
""" Compile the stats for all sessions combined """ """ Return total stats """
logger.debug("Compiling Totals")
elapsed = 0 elapsed = 0
rate = 0 rate = 0
batchset = set() batchset = set()
iterations = 0 iterations = 0
total_summaries = len(raw_summaries) total_summaries = len(sessions_stats)
for idx, summary in enumerate(sessions_stats):
for idx, summary in enumerate(raw_summaries):
if idx == 0: if idx == 0:
starttime = summary["start"] starttime = summary["start"]
if idx == total_summaries - 1: if idx == total_summaries - 1:
@ -192,150 +307,170 @@ class SessionsSummary():
batchset.add(summary["batch"]) batchset.add(summary["batch"])
iterations += summary["iterations"] iterations += summary["iterations"]
batch = ",".join(str(bs) for bs in batchset) batch = ",".join(str(bs) for bs in batchset)
totals = {"session": "Total",
"start": starttime,
"end": endtime,
"elapsed": elapsed,
"rate": rate / total_summaries,
"batch": batch,
"iterations": iterations}
logger.debug(totals)
return totals
return {"session": "Total", @staticmethod
"start": starttime, def format_stats(compiled_stats):
"end": endtime, """ Format for display """
"elapsed": elapsed, logger.debug("Formatting stats")
"rate": rate / total_summaries, for summary in compiled_stats:
"batch": batch,
"iterations": iterations}
def format_summaries(self, raw_summaries):
""" Format the summaries nicely for display """
for summary in raw_summaries:
summary["start"] = time.strftime("%x %X",
time.gmtime(summary["start"]))
summary["end"] = time.strftime("%x %X",
time.gmtime(summary["end"]))
hrs, mins, secs = convert_time(summary["elapsed"]) hrs, mins, secs = convert_time(summary["elapsed"])
summary["start"] = time.strftime("%x %X", time.gmtime(summary["start"]))
summary["end"] = time.strftime("%x %X", time.gmtime(summary["end"]))
summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs) summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
summary["rate"] = "{0:.1f}".format(summary["rate"]) summary["rate"] = "{0:.1f}".format(summary["rate"])
self.summary = raw_summaries return compiled_stats
class Calculations(): class Calculations():
""" Class to hold calculations against raw session data """ """ Class to pull raw data for given session(s) and perform calculations """
def __init__(self, def __init__(self, session, display="loss", loss_keys=["loss"], selections=["raw"],
session, avg_samples=10, flatten_outliers=False, is_totals=False):
display="loss", logger.debug("Initializing %s: (session: %s, display: %s, loss_keys: %s, selections: %s, "
selections=["raw"], "avg_samples: %s, flatten_outliers: %s, is_totals: %s)",
avg_samples=10, self.__class__.__name__, session, display, loss_keys, selections, avg_samples,
flatten_outliers=False, flatten_outliers, is_totals)
is_totals=False):
warnings.simplefilter("ignore", np.RankWarning) warnings.simplefilter("ignore", np.RankWarning)
self.session = session self.session = session
if display.lower() == "loss": self.display = display
display = self.session["losskeys"] self.loss_keys = loss_keys
else: self.selections = selections
display = [display] self.is_totals = is_totals
self.args = {"display": display, self.args = {"avg_samples": int(avg_samples),
"selections": selections, "flatten_outliers": flatten_outliers}
"avg_samples": int(avg_samples),
"flatten_outliers": flatten_outliers,
"is_totals": is_totals}
self.iterations = 0 self.iterations = 0
self.stats = None self.stats = None
self.refresh() self.refresh()
logger.debug("Initialized %s", self.__class__.__name__)
def refresh(self): def refresh(self):
""" Refresh the stats """ """ Refresh the stats """
logger.debug("Refreshing")
if not self.session.initialized:
logger.warning("Session data is not initialized. Not refreshing")
return
self.iterations = 0 self.iterations = 0
self.stats = self.get_raw() self.stats = self.get_raw()
self.get_calculations() self.get_calculations()
self.remove_raw() self.remove_raw()
logger.debug("Refreshed")
def get_raw(self): def get_raw(self):
""" Add raw data to stats dict """ """ Add raw data to stats dict """
raw = dict() logger.debug("Getting Raw Data")
for idx, item in enumerate(self.args["display"]):
if item.lower() == "rate":
data = self.calc_rate(self.session)
else:
data = self.session["loss"][idx][:]
raw = dict()
iterations = set()
if self.display.lower() == "loss":
loss_dict = self.session.total_loss if self.is_totals else self.session.loss
for loss_name, side_loss in loss_dict.items():
if loss_name not in self.loss_keys:
continue
for side, loss in side_loss.items():
if self.args["flatten_outliers"]:
loss = self.flatten_outliers(loss)
iterations.add(len(loss))
raw["raw_{}_{}".format(loss_name, side)] = loss
self.iterations = 0 if not iterations else min(iterations)
if len(iterations) > 1:
# Crop all losses to the same number of items
if self.iterations == 0:
raw = {lossname: list() for lossname in raw.keys()}
else:
raw = {lossname: loss[:self.iterations] for lossname, loss in raw.items()}
else: # Rate calculation
data = self.calc_rate_total() if self.is_totals else self.calc_rate()
if self.args["flatten_outliers"]: if self.args["flatten_outliers"]:
data = self.flatten_outliers(data) data = self.flatten_outliers(data)
self.iterations = len(data)
raw = {"raw_rate": data}
if self.iterations == 0: logger.debug("Got Raw Data")
self.iterations = len(data)
raw["raw_{}".format(item)] = data
return raw return raw
def remove_raw(self): def remove_raw(self):
""" Remove raw values from stats if not requested """ """ Remove raw values from stats if not requested """
if "raw" in self.args["selections"]: if "raw" in self.selections:
return return
logger.debug("Removing Raw Data from output")
for key in list(self.stats.keys()): for key in list(self.stats.keys()):
if key.startswith("raw"): if key.startswith("raw"):
del self.stats[key] del self.stats[key]
logger.debug("Removed Raw Data from output")
def calc_rate(self, data): def calc_rate(self):
""" Calculate rate per iteration """
logger.debug("Calculating rate")
batchsize = self.session.batchsize
timestamps = self.session.timestamps
iterations = range(len(timestamps) - 1)
rate = [batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations]
logger.debug("Calculated rate: Item_count: %s", len(rate))
return rate
def calc_rate_total(self):
""" Calculate rate per iteration """ Calculate rate per iteration
NB: For totals, gaps between sessions can be large NB: For totals, gaps between sessions can be large
so time diffeence has to be reset for each session's so time difference has to be reset for each session's
rate calculation """ rate calculation """
batchsize = data["batchsize"] logger.debug("Calculating totals rate")
if self.args["is_totals"]: batchsizes = self.session.total_batchsize
split = data["split"] total_timestamps = self.session.total_timestamps
else:
batchsize = [batchsize]
split = [len(data["timestamps"])]
prev_split = 0
rate = list() rate = list()
for sess_id in sorted(total_timestamps.keys()):
for idx, current_split in enumerate(split): batchsize = batchsizes[sess_id]
prev_time = data["timestamps"][prev_split] timestamps = total_timestamps[sess_id]
timestamp_chunk = data["timestamps"][prev_split:current_split] iterations = range(len(timestamps) - 1)
for item in timestamp_chunk: rate.extend([batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations])
current_time = item logger.debug("Calculated totals rate: Item_count: %s", len(rate))
timediff = current_time - prev_time
iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff
rate.append(iter_rate)
prev_time = current_time
prev_split = current_split
if self.args["flatten_outliers"]:
rate = self.flatten_outliers(rate)
return rate return rate
@staticmethod @staticmethod
def flatten_outliers(data): def flatten_outliers(data):
""" Remove the outliers from a provided list """ """ Remove the outliers from a provided list """
logger.debug("Flattening outliers")
retdata = list() retdata = list()
samples = len(data) samples = len(data)
mean = (sum(data) / samples) mean = (sum(data) / samples)
limit = sqrt(sum([(item - mean)**2 for item in data]) / samples) limit = sqrt(sum([(item - mean)**2 for item in data]) / samples)
logger.debug("samples: %s, mean: %s, limit: %s", samples, mean, limit)
for item in data: for idx, item in enumerate(data):
if (mean - limit) <= item <= (mean + limit): if (mean - limit) <= item <= (mean + limit):
retdata.append(item) retdata.append(item)
else: else:
logger.debug("Item idx: %s, value: %s flattened to %s", idx, item, mean)
retdata.append(mean) retdata.append(mean)
logger.debug("Flattened outliers")
return retdata return retdata
def get_calculations(self): def get_calculations(self):
""" Perform the required calculations """ """ Perform the required calculations """
for selection in self.get_selections(): for selection in self.selections:
if selection[0] == "raw": if selection == "raw":
continue continue
method = getattr(self, "calc_{}".format(selection[0])) logger.debug("Calculating: %s", selection)
key = "{}_{}".format(selection[0], selection[1]) method = getattr(self, "calc_{}".format(selection))
raw = self.stats["raw_{}".format(selection[1])] raw_keys = [key for key in self.stats.keys() if key.startswith("raw_")]
self.stats[key] = method(raw) for key in raw_keys:
selected_key = "{}_{}".format(selection, key.replace("raw_", ""))
def get_selections(self): self.stats[selected_key] = method(self.stats[key])
""" Compile a list of data to be calculated """
for summary in self.args["selections"]:
for item in self.args["display"]:
yield summary, item
def calc_avg(self, data): def calc_avg(self, data):
""" Calculate rolling average """ """ Calculate rolling average """
logger.debug("Calculating Average")
avgs = list() avgs = list()
presample = ceil(self.args["avg_samples"] / 2) presample = ceil(self.args["avg_samples"] / 2)
postsample = self.args["avg_samples"] - presample postsample = self.args["avg_samples"] - presample
@ -353,11 +488,13 @@ class Calculations():
avg = sum(data[idx - presample:idx + postsample]) \ avg = sum(data[idx - presample:idx + postsample]) \
/ self.args["avg_samples"] / self.args["avg_samples"]
avgs.append(avg) avgs.append(avg)
logger.debug("Calculated Average")
return avgs return avgs
@staticmethod @staticmethod
def calc_trend(data): def calc_trend(data):
""" Compile trend data """ """ Compile trend data """
logger.debug("Calculating Trend")
points = len(data) points = len(data)
if points < 10: if points < 10:
dummy = [None for i in range(points)] dummy = [None for i in range(points)]
@ -366,4 +503,5 @@ class Calculations():
fit = np.polyfit(x_range, data, 3) fit = np.polyfit(x_range, data, 3)
poly = np.poly1d(fit) poly = np.poly1d(fit)
trend = poly(x_range) trend = poly(x_range)
logger.debug("Calculated Trend")
return trend return trend
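The rate and outlier logic in the `Calculations` class above can be sketched outside the GUI as follows. This is a minimal illustration, not the project's code: the standalone function signatures and the sample timestamps are assumptions made for the example, and only the arithmetic mirrors the diff.

```python
from math import sqrt


def calc_rate(batchsize, timestamps):
    """ Items processed per second between consecutive timestamps. """
    return [batchsize / (timestamps[i + 1] - timestamps[i])
            for i in range(len(timestamps) - 1)]


def calc_rate_total(batchsizes, total_timestamps):
    """ Join per-session rates so the gap between sessions is never measured. """
    rate = []
    for sess_id in sorted(total_timestamps):
        rate.extend(calc_rate(batchsizes[sess_id], total_timestamps[sess_id]))
    return rate


def flatten_outliers(data):
    """ Replace values more than one standard deviation from the mean with the mean. """
    samples = len(data)
    mean = sum(data) / samples
    limit = sqrt(sum((item - mean) ** 2 for item in data) / samples)
    return [item if (mean - limit) <= item <= (mean + limit) else mean
            for item in data]


# Hypothetical data: two sessions separated by a long pause
batchsizes = {1: 64, 2: 32}
total_timestamps = {1: [0.0, 2.0, 4.0], 2: [1000.0, 1001.0, 1003.0]}
rates = calc_rate_total(batchsizes, total_timestamps)
flattened = flatten_outliers(rates)
```

Because each session is processed on its own, the 996 second pause between the two sessions never appears as a time difference. Note that, like the new code in the diff, this sketch does not guard against two identical consecutive timestamps, which would raise `ZeroDivisionError`.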

@ -42,7 +42,7 @@ class Tooltip:
waittime=400, waittime=400,
wraplength=250): wraplength=250):
self.waittime = waittime # in miliseconds, originally 500 self.waittime = waittime # in milliseconds, originally 500
self.wraplength = wraplength # in pixels, originally 180 self.wraplength = wraplength # in pixels, originally 180
self.widget = widget self.widget = widget
self.text = text self.text = text
@ -115,7 +115,7 @@ class Tooltip:
# No further checks will be done. # No further checks will be done.
# TIP: # TIP:
# A further mod might automagically augment the # A further mod might auto-magically augment the
# wraplength when the tooltip is too high to be # wraplength when the tooltip is too high to be
# kept inside the screen. # kept inside the screen.
y_1 = 0 y_1 = 0

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Utility functions for the GUI """ """ Utility functions for the GUI """
import logging import logging
import os import os
import platform import platform
import sys import sys
@ -10,24 +9,50 @@ import tkinter as tk
from tkinter import filedialog, ttk from tkinter import filedialog, ttk
from PIL import Image, ImageTk from PIL import Image, ImageTk
from lib.Serializer import JSONSerializer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG = None
_IMAGES = None
class Singleton(type): def initialize_config(cli_opts, scaling_factor, pathcache, statusbar, session):
""" Instigate a singleton. """ Initialize the config and add to global constant """
From: https://stackoverflow.com/questions/6760685 global _CONFIG # pylint: disable=global-statement
if _CONFIG is not None:
return
logger.debug("Initializing config: (cli_opts: %s, scaling_factor: %s, pathcache: %s, statusbar: %s, "
"session: %s)", cli_opts, scaling_factor, pathcache, statusbar, session)
_CONFIG = Config(cli_opts, scaling_factor, pathcache, statusbar, session)
Singletons are often frowned upon.
Feel free to instigate a better solution """
_instances = {} def get_config():
""" return the _CONFIG constant """
return _CONFIG
def __call__(cls, *args, **kwargs):
if cls not in cls._instances: def initialize_images():
cls._instances[cls] = super(Singleton, """ Initialize the images and add to global constant """
cls).__call__(*args, global _IMAGES # pylint: disable=global-statement
**kwargs) if _IMAGES is not None:
return cls._instances[cls] return
logger.debug("Initializing images")
_IMAGES = Images()
def get_images():
""" return the _IMAGES constant """
return _IMAGES
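The change above swaps the `Singleton` metaclass for a module-level global with an initializer and an accessor. A minimal sketch of that pattern, without any tkinter dependency, might look like this. The one-argument `Config` stand-in and the paths are hypothetical and exist only for illustration.

```python
_CONFIG = None


class Config():
    """ Stand-in for the GUI Config class; holds only a cache path here. """
    def __init__(self, pathcache):
        self.pathcache = pathcache


def initialize_config(pathcache):
    """ Create the shared Config once; later calls are ignored. """
    global _CONFIG  # pylint: disable=global-statement
    if _CONFIG is not None:
        return
    _CONFIG = Config(pathcache)


def get_config():
    """ Return the shared Config (None until initialize_config has run). """
    return _CONFIG


initialize_config("/tmp/cache")       # hypothetical path
initialize_config("/somewhere/else")  # no effect: already initialized
```

One consequence of this design is that callers must not use `get_config()` before `initialize_config()` has run, since the accessor returns `None` until then.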
def set_slider_rounding(value, var, d_type, round_to, min_max):
""" Set the underlying variable to correct number based on slider rounding """
if d_type == float:
var.set(round(float(value), round_to))
else:
steps = range(min_max[0], min_max[1] + round_to, round_to)
value = min(steps, key=lambda x: abs(x - int(float(value))))
var.set(value)
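The integer branch of `set_slider_rounding` snaps the raw slider value to the nearest permitted step. A small sketch of just that step, assuming the slider delivers its value as a string such as `"37.0"`; the helper name `snap_to_step` is hypothetical.

```python
def snap_to_step(value, round_to, min_max):
    """ Snap a raw slider value to the nearest permitted integer step. """
    steps = range(min_max[0], min_max[1] + round_to, round_to)
    # The raw value is truncated to an int before the nearest step is chosen
    return min(steps, key=lambda x: abs(x - int(float(value))))


snapped = snap_to_step("37.0", 8, (0, 64))
truncated = snap_to_step("3.9", 1, (1, 10))
```

If the slider range is not a whole multiple of `round_to`, the generated steps can extend past the maximum: `range(0, 10 + 4, 4)` contains 12.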
class FileHandler(): class FileHandler():
@ -52,8 +77,8 @@ class FileHandler():
("PNG", "*.png"), ("PNG", "*.png"),
("TIFF", "*.tif", "*.tiff"), ("TIFF", "*.tif", "*.tiff"),
all_files), all_files),
"state": (("State files", "*.json"), all_files),
"log": (("Log files", "*.log"), all_files), "log": (("Log files", "*.log"), all_files),
"session": (("Faceswap session files", "*.fss"), all_files),
"video": (("Audio Video Interleave", "*.avi"), "video": (("Audio Video Interleave", "*.avi"),
("Flash Video", "*.flv"), ("Flash Video", "*.flv"),
("Matroska", "*.mkv"), ("Matroska", "*.mkv"),
@ -164,11 +189,15 @@ class FileHandler():
return return
class Images(metaclass=Singleton): class Images():
""" Holds locations of images and actual images """ """ Holds locations of images and actual images
def __init__(self, pathcache=None): Don't call directly. Call get_images()
logger.debug("Initializing %s: (pathcache: '%s'", self.__class__.__name__, pathcache) """
def __init__(self):
logger.debug("Initializing %s", self.__class__.__name__)
pathcache = get_config().pathcache
self.pathicons = os.path.join(pathcache, "icons") self.pathicons = os.path.join(pathcache, "icons")
self.pathpreview = os.path.join(pathcache, "preview") self.pathpreview = os.path.join(pathcache, "preview")
self.pathoutput = None self.pathoutput = None
@ -194,7 +223,7 @@ class Images(metaclass=Singleton):
""" Delete the preview files """ """ Delete the preview files """
logger.debug("Deleting previews") logger.debug("Deleting previews")
for item in os.listdir(self.pathpreview): for item in os.listdir(self.pathpreview):
if item.startswith(".gui_preview_") and item.endswith(".jpg"): if item.startswith(".gui_training_preview") and item.endswith(".jpg"):
fullitem = os.path.join(self.pathpreview, item) fullitem = os.path.join(self.pathpreview, item)
logger.debug("Deleting: '%s'", fullitem) logger.debug("Deleting: '%s'", fullitem)
os.remove(fullitem) os.remove(fullitem)
@ -210,34 +239,34 @@ class Images(metaclass=Singleton):
@staticmethod @staticmethod
def get_images(imgpath): def get_images(imgpath):
""" Get the images stored within the given directory """ """ Get the images stored within the given directory """
logger.debug("Getting images: '%s'", imgpath) logger.trace("Getting images: '%s'", imgpath)
if not os.path.isdir(imgpath): if not os.path.isdir(imgpath):
logger.debug("Folder does not exist") logger.debug("Folder does not exist")
return None return None
files = [os.path.join(imgpath, f) files = [os.path.join(imgpath, f)
for f in os.listdir(imgpath) if f.endswith((".png", ".jpg"))] for f in os.listdir(imgpath) if f.endswith((".png", ".jpg"))]
logger.debug("Image files: %s", files) logger.trace("Image files: %s", files)
return files return files
def load_latest_preview(self): def load_latest_preview(self):
""" Load the latest preview image for extract and convert """ """ Load the latest preview image for extract and convert """
logger.debug("Loading preview image") logger.trace("Loading preview image")
imagefiles = self.get_images(self.pathoutput) imagefiles = self.get_images(self.pathoutput)
if not imagefiles or len(imagefiles) == 1: if not imagefiles or len(imagefiles) == 1:
logger.debug("No preview to display") logger.debug("No preview to display")
self.previewoutput = None self.previewoutput = None
return return
# Get penultimate file so we don't accidently # Get penultimate file so we don't accidentally
# load a file that is being saved # load a file that is being saved
show_file = sorted(imagefiles, key=os.path.getctime)[-2] show_file = sorted(imagefiles, key=os.path.getctime)[-2]
img = Image.open(show_file) img = Image.open(show_file)
img.thumbnail((768, 432)) img.thumbnail((768, 432))
logger.debug("Displaying preview: '%s'", show_file) logger.trace("Displaying preview: '%s'", show_file)
self.previewoutput = (img, ImageTk.PhotoImage(img)) self.previewoutput = (img, ImageTk.PhotoImage(img))
def load_training_preview(self): def load_training_preview(self):
""" Load the training preview images """ """ Load the training preview images """
logger.debug("Loading Training preview images") logger.trace("Loading Training preview images")
imagefiles = self.get_images(self.pathpreview) imagefiles = self.get_images(self.pathpreview)
modified = None modified = None
if not imagefiles: if not imagefiles:
@ -250,7 +279,7 @@ class Images(metaclass=Singleton):
name = os.path.splitext(name)[0] name = os.path.splitext(name)[0]
name = name[name.rfind("_") + 1:].title() name = name[name.rfind("_") + 1:].title()
try: try:
logger.debug("Displaying preview: '%s'", img) logger.trace("Displaying preview: '%s'", img)
size = self.get_current_size(name) size = self.get_current_size(name)
self.previewtrain[name] = [Image.open(img), None, modified] self.previewtrain[name] = [Image.open(img), None, modified]
self.resize_image(name, size) self.resize_image(name, size)
@ -270,20 +299,20 @@ class Images(metaclass=Singleton):
def get_current_size(self, name): def get_current_size(self, name):
""" Return the size of the currently displayed image """ """ Return the size of the currently displayed image """
logger.debug("Getting size: '%s'", name) logger.trace("Getting size: '%s'", name)
if not self.previewtrain.get(name, None): if not self.previewtrain.get(name, None):
return None return None
img = self.previewtrain[name][1] img = self.previewtrain[name][1]
if not img: if not img:
return None return None
logger.debug("Got size: (name: '%s', width: '%s', height: '%s')", logger.trace("Got size: (name: '%s', width: '%s', height: '%s')",
name, img.width(), img.height()) name, img.width(), img.height())
return img.width(), img.height() return img.width(), img.height()
def resize_image(self, name, framesize): def resize_image(self, name, framesize):
""" Resize the training preview image """ Resize the training preview image
based on the passed in frame size """ based on the passed in frame size """
logger.debug("Resizing image: (name: '%s', framesize: %s", name, framesize) logger.trace("Resizing image: (name: '%s', framesize: %s", name, framesize)
displayimg = self.previewtrain[name][0] displayimg = self.previewtrain[name][0]
if framesize: if framesize:
frameratio = float(framesize[0]) / float(framesize[1]) frameratio = float(framesize[0]) / float(framesize[1])
@ -295,7 +324,7 @@ class Images(metaclass=Singleton):
else: else:
scale = framesize[1] / float(displayimg.size[1]) scale = framesize[1] / float(displayimg.size[1])
size = (int(displayimg.size[0] * scale), framesize[1]) size = (int(displayimg.size[0] * scale), framesize[1])
logger.debug("Scaling: (scale: %s, size: %s", scale, size) logger.trace("Scaling: (scale: %s, size: %s", scale, size)
# Hacky fix to force a reload if it happens to find corrupted # Hacky fix to force a reload if it happens to find corrupted
# data, probably due to reading the image whilst it is partially # data, probably due to reading the image whilst it is partially
@ -335,7 +364,9 @@ class ContextMenu(tk.Menu): # pylint: disable=too-many-ancestors
""" Bind the menu to the widget's Right Click event """ """ Bind the menu to the widget's Right Click event """
button = "<Button-2>" if platform.system() == "Darwin" else "<Button-3>" button = "<Button-2>" if platform.system() == "Darwin" else "<Button-3>"
logger.debug("Binding '%s' to '%s'", button, self.widget.winfo_class()) logger.debug("Binding '%s' to '%s'", button, self.widget.winfo_class())
self.widget.bind(button, lambda event: self.tk_popup(event.x_root, event.y_root, 0)) x_offset = int(34 * get_config().scaling_factor)
self.widget.bind(button,
lambda event: self.tk_popup(event.x_root + x_offset, event.y_root, 0))
def select_all(self): def select_all(self):
""" Select all for Text or Entry widgets """ """ Select all for Text or Entry widgets """
@ -351,16 +382,16 @@ class ContextMenu(tk.Menu): # pylint: disable=too-many-ancestors
class ConsoleOut(ttk.Frame): # pylint: disable=too-many-ancestors class ConsoleOut(ttk.Frame): # pylint: disable=too-many-ancestors
""" The Console out section of the GUI """ """ The Console out section of the GUI """
def __init__(self, parent, debug, tk_vars): def __init__(self, parent, debug):
logger.debug("Initializing %s: (parent: %s, debug: %s, tk_vars: %s)", logger.debug("Initializing %s: (parent: %s, debug: %s)",
self.__class__.__name__, parent, debug, tk_vars) self.__class__.__name__, parent, debug)
ttk.Frame.__init__(self, parent) ttk.Frame.__init__(self, parent)
self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0), self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0),
fill=tk.BOTH, expand=True) fill=tk.BOTH, expand=True)
self.console = tk.Text(self) self.console = tk.Text(self)
rc_menu = ContextMenu(self.console) rc_menu = ContextMenu(self.console)
rc_menu.cm_bind() rc_menu.cm_bind()
self.console_clear = tk_vars['consoleclear'] self.console_clear = get_config().tk_vars['consoleclear']
self.set_console_clear_var_trace() self.set_console_clear_var_trace()
self.debug = debug self.debug = debug
self.build_console() self.build_console()
@ -395,7 +426,7 @@ class ConsoleOut(ttk.Frame): # pylint: disable=too-many-ancestors
sys.stderr = SysOutRouter(console=self.console, out_type="stderr") sys.stderr = SysOutRouter(console=self.console, out_type="stderr")
logger.debug("Redirected console") logger.debug("Redirected console")
def clear(self, *args): def clear(self, *args): # pylint: disable=unused-argument
""" Clear the console output screen """ """ Clear the console output screen """
logger.debug("Clear console") logger.debug("Clear console")
if not self.console_clear.get(): if not self.console_clear.get():
@ -427,3 +458,146 @@ class SysOutRouter():
def flush(): def flush():
""" If flush is forced, send it to normal terminal """ """ If flush is forced, send it to normal terminal """
sys.__stdout__.flush() sys.__stdout__.flush()
class Config():
""" Global configuration settings
Don't call directly. Call get_config()
"""
def __init__(self, cli_opts, scaling_factor, pathcache, statusbar, session):
logger.debug("Initializing %s: (cli_opts: %s, scaling_factor: %s, pathcache: %s, "
"statusbar: %s, session: %s)", self.__class__.__name__, cli_opts,
scaling_factor, pathcache, statusbar, session)
self.cli_opts = cli_opts
self.scaling_factor = scaling_factor
self.pathcache = pathcache
self.statusbar = statusbar
self.serializer = JSONSerializer
self.tk_vars = self.set_tk_vars()
self.command_notebook = None # set in command.py
self.session = session
logger.debug("Initialized %s", self.__class__.__name__)
@property
def command_tabs(self):
""" Return dict of command tab titles with their IDs """
return {self.command_notebook.tab(tab_id, "text").lower(): tab_id
for tab_id in range(0, self.command_notebook.index("end"))}
@staticmethod
def set_tk_vars():
""" TK Variables to be triggered to indicate
what state various parts of the GUI should be in """
display = tk.StringVar()
display.set(None)
runningtask = tk.BooleanVar()
runningtask.set(False)
actioncommand = tk.StringVar()
actioncommand.set(None)
generatecommand = tk.StringVar()
generatecommand.set(None)
consoleclear = tk.BooleanVar()
consoleclear.set(False)
refreshgraph = tk.BooleanVar()
refreshgraph.set(False)
updatepreview = tk.BooleanVar()
updatepreview.set(False)
tk_vars = {"display": display,
"runningtask": runningtask,
"action": actioncommand,
"generate": generatecommand,
"consoleclear": consoleclear,
"refreshgraph": refreshgraph,
"updatepreview": updatepreview}
logger.debug(tk_vars)
return tk_vars
def load(self, command=None, filename=None):
""" Pop up load dialog for a saved config file """
logger.debug("Loading config: (command: '%s')", command)
if filename:
with open(filename, "r") as cfgfile:
cfg = self.serializer.unmarshal(cfgfile.read())
else:
cfgfile = FileHandler("open", "config").retfile
if not cfgfile:
return
cfg = self.serializer.unmarshal(cfgfile.read())
if not command and len(cfg.keys()) == 1:
command = list(cfg.keys())[0]
opts = self.get_command_options(cfg, command) if command else cfg
if not opts:
return
for cmd, opts in opts.items():
self.set_command_args(cmd, opts)
if command:
self.command_notebook.select(self.command_tabs[command])
self.add_to_recent(cfgfile.name, command)
logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile)
def get_command_options(self, cfg, command):
""" return the saved options for the requested
command, if not loading global options """
opts = cfg.get(command, None)
retval = {command: opts}
if not opts:
self.tk_vars["consoleclear"].set(True)
print("No {} section found in file".format(command))
logger.info("No %s section found in file", command)
retval = None
logger.debug(retval)
return retval
def set_command_args(self, command, options):
""" Pass the saved config items back to the CliOptions """
if not options:
return
for srcopt, srcval in options.items():
optvar = self.cli_opts.get_one_option_variable(command, srcopt)
if not optvar:
continue
optvar.set(srcval)
def save(self, command=None):
""" Save the current GUI state to a config file in json format """
logger.debug("Saving config: (command: '%s')", command)
cfgfile = FileHandler("save", "config").retfile
if not cfgfile:
return
cfg = self.cli_opts.get_option_values(command)
cfgfile.write(self.serializer.marshal(cfg))
cfgfile.close()
self.add_to_recent(cfgfile.name, command)
logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile)
def add_to_recent(self, filename, command):
""" Add to recent files """
recent_filename = os.path.join(self.pathcache, ".recent.json")
logger.debug("Adding to recent files '%s': (%s, %s)", recent_filename, filename, command)
with open(recent_filename, "rb") as inp:
recent_files = self.serializer.unmarshal(inp.read().decode("utf-8"))
logger.debug("Initial recent files: %s", recent_files)
filenames = [recent[0] for recent in recent_files]
if filename in filenames:
idx = filenames.index(filename)
del recent_files[idx]
recent_files.insert(0, (filename, command))
recent_files = recent_files[:20]
logger.debug("Final recent files: %s", recent_files)
recent_json = self.serializer.marshal(recent_files)
with open(recent_filename, "wb") as out:
out.write(recent_json.encode("utf-8"))
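The list handling inside `add_to_recent` can be separated from the file I/O and tested in isolation. The sketch below assumes entries are `(filename, command)` pairs; the helper name `update_recent` and the sample file names are hypothetical.

```python
def update_recent(recent_files, filename, command, limit=20):
    """ Return a most-recent-first list with (filename, command) at the front. """
    recent = list(recent_files)  # work on a copy, leave the input untouched
    filenames = [entry[0] for entry in recent]
    if filename in filenames:
        # Drop the existing entry so the file is not listed twice
        del recent[filenames.index(filename)]
    recent.insert(0, (filename, command))
    return recent[:limit]


history = [("a.json", "train"), ("b.json", "extract")]
updated = update_recent(history, "b.json", "convert")
```

Indexing with `entry[0]` works whether the entries are tuples or the lists that a JSON round trip produces.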

@ -1,57 +1,40 @@
#!/usr/bin python3 #!/usr/bin python3
""" Process wrapper for underlying faceswap commands for the GUI """ """ Process wrapper for underlying faceswap commands for the GUI """
import os import os
import logging
import re import re
import signal import signal
from subprocess import PIPE, Popen, TimeoutExpired from subprocess import PIPE, Popen, TimeoutExpired
import sys import sys
import tkinter as tk
from threading import Thread from threading import Thread
from time import time from time import time
import psutil import psutil
from .utils import Images from .utils import get_config, get_images
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class ProcessWrapper(): class ProcessWrapper():
""" Builds command, launches and terminates the underlying """ Builds command, launches and terminates the underlying
faceswap process. Updates GUI display depending on state """ faceswap process. Updates GUI display depending on state """
def __init__(self, statusbar, session=None, pathscript=None, cliopts=None): def __init__(self, pathscript=None):
self.tk_vars = self.set_tk_vars() logger.debug("Initializing %s: (pathscript: %s)", self.__class__.__name__, pathscript)
self.session = session self.tk_vars = get_config().tk_vars
self.set_callbacks()
self.pathscript = pathscript self.pathscript = pathscript
self.cliopts = cliopts
self.command = None self.command = None
self.statusbar = statusbar self.statusbar = get_config().statusbar
self.task = FaceswapControl(self) self.task = FaceswapControl(self)
logger.debug("Initialized %s", self.__class__.__name__)
def set_tk_vars(self): def set_callbacks(self):
""" TK Variables to be triggered by ProcessWrapper to indicate """ Set the tk variable callbacks """
what state various parts of the GUI should be in """ logger.debug("Setting tk variable traces")
display = tk.StringVar() self.tk_vars["action"].trace("w", self.action_command)
display.set(None) self.tk_vars["generate"].trace("w", self.generate_command)
runningtask = tk.BooleanVar()
runningtask.set(False)
actioncommand = tk.StringVar()
actioncommand.set(None)
actioncommand.trace("w", self.action_command)
generatecommand = tk.StringVar()
generatecommand.set(None)
generatecommand.trace("w", self.generate_command)
consoleclear = tk.BooleanVar()
consoleclear.set(False)
return {"display": display,
"runningtask": runningtask,
"action": actioncommand,
"generate": generatecommand,
"consoleclear": consoleclear}
def action_command(self, *args): def action_command(self, *args):
""" The action to perform when the action button is pressed """ """ The action to perform when the action button is pressed """
@ -74,28 +57,30 @@ class ProcessWrapper():
category, command = self.tk_vars["generate"].get().split(",") category, command = self.tk_vars["generate"].get().split(",")
args = self.build_args(category, command=command, generate=True) args = self.build_args(category, command=command, generate=True)
self.tk_vars["consoleclear"].set(True) self.tk_vars["consoleclear"].set(True)
logger.debug(" ".join(args))
print(" ".join(args)) print(" ".join(args))
self.tk_vars["generate"].set(None) self.tk_vars["generate"].set(None)
def prepare(self, category): def prepare(self, category):
""" Prepare the environment for execution """ """ Prepare the environment for execution """
logger.debug("Preparing for execution")
self.tk_vars["runningtask"].set(True) self.tk_vars["runningtask"].set(True)
self.tk_vars["consoleclear"].set(True) self.tk_vars["consoleclear"].set(True)
print("Loading...") print("Loading...")
self.statusbar.status_message.set("Executing - " self.statusbar.status_message.set("Executing - {}.py".format(self.command))
+ self.command + ".py") mode = "indeterminate" if self.command in ("effmpeg", "train") else "determinate"
mode = "indeterminate" if self.command in ("effmpeg",
"train") else "determinate"
self.statusbar.progress_start(mode) self.statusbar.progress_start(mode)
args = self.build_args(category) args = self.build_args(category)
self.tk_vars["display"].set(self.command) self.tk_vars["display"].set(self.command)
logger.debug("Prepared for execution")
return args return args
def build_args(self, category, command=None, generate=False): def build_args(self, category, command=None, generate=False):
""" Build the faceswap command and arguments list """ """ Build the faceswap command and arguments list """
logger.debug("Build cli arguments: (category: %s, command: %s, generate: %s)",
category, command, generate)
command = self.command if not command else command command = self.command if not command else command
script = "{}.{}".format(category, "py") script = "{}.{}".format(category, "py")
pathexecscript = os.path.join(self.pathscript, script) pathexecscript = os.path.join(self.pathscript, script)
@ -103,50 +88,60 @@ class ProcessWrapper():
args = [sys.executable] if generate else [sys.executable, "-u"] args = [sys.executable] if generate else [sys.executable, "-u"]
args.extend([pathexecscript, command]) args.extend([pathexecscript, command])
for cliopt in self.cliopts.gen_cli_arguments(command): cli_opts = get_config().cli_opts
for cliopt in cli_opts.gen_cli_arguments(command):
args.extend(cliopt) args.extend(cliopt)
if command == "train" and not generate: if command == "train" and not generate:
self.set_session_stats(cliopt) self.init_training_session(cliopt)
if command == "train" and not generate: if not generate:
args.append("-gui") # Embed the preview pane args.append("-gui") # Indicate to Faceswap that we are running the GUI
logger.debug("Built cli arguments: (%s)", args)
return args return args
def set_session_stats(self, cliopt): @staticmethod
""" Set the session stats for batchsize and modeldir """ def init_training_session(cliopt):
if cliopt[0] == "-bs": """ Set the session stats for disable logging, model folder and model name """
self.session.stats["batchsize"] = int(cliopt[1]) session = get_config().session
if cliopt[0] == "-t":
session.modelname = cliopt[1].lower().replace("-", "_")
logger.debug("modelname: '%s'", session.modelname)
if cliopt[0] == "-m": if cliopt[0] == "-m":
self.session.modeldir = cliopt[1] session.modeldir = cliopt[1]
logger.debug("modeldir: '%s'", session.modeldir)
def terminate(self, message): def terminate(self, message):
""" Finalise wrapper when process has exited """ """ Finalize wrapper when process has exited """
logger.debug("Terminating Faceswap processes")
self.tk_vars["runningtask"].set(False) self.tk_vars["runningtask"].set(False)
self.statusbar.progress_stop() self.statusbar.progress_stop()
self.statusbar.status_message.set(message) self.statusbar.status_message.set(message)
self.tk_vars["display"].set(None) self.tk_vars["display"].set(None)
Images().delete_preview() get_images().delete_preview()
if self.command == "train": get_config().session.__init__()
self.session.save_session()
self.session.__init__()
self.command = None self.command = None
logger.debug("Terminated Faceswap processes")
print("Process exited.") print("Process exited.")
class FaceswapControl(): class FaceswapControl():
""" Control the underlying Faceswap tasks """ """ Control the underlying Faceswap tasks """
def __init__(self, wrapper): def __init__(self, wrapper):
logger.debug("Initializing %s", self.__class__.__name__)
self.wrapper = wrapper self.wrapper = wrapper
self.statusbar = wrapper.statusbar self.statusbar = get_config().statusbar
self.command = None self.command = None
self.args = None self.args = None
self.process = None self.process = None
self.train_stats = {"iterations": 0, "timestamp": None}
self.consoleregex = { self.consoleregex = {
"loss": re.compile(r"([a-zA-Z_]+):.*?(\d+\.\d+)"), "loss": re.compile(r"([a-zA-Z_]+):.*?(\d+\.\d+)"),
"tqdm": re.compile(r"(\d+%|\d+/\d+|\d+:\d+|\d+\.\d+[a-zA-Z/]+)")} "tqdm": re.compile(r".*?(?P<pct>\d+%).*?(?P<itm>\d+/\d+)\W\["
r"(?P<tme>\d+:\d+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]")}
logger.debug("Initialized %s", self.__class__.__name__)
def execute_script(self, command, args): def execute_script(self, command, args):
""" Execute the requested Faceswap Script """ """ Execute the requested Faceswap Script """
logger.debug("Executing Faceswap: (command: '%s', args: %s)", command, args)
self.command = command self.command = command
kwargs = {"stdout": PIPE, kwargs = {"stdout": PIPE,
"stderr": PIPE, "stderr": PIPE,
@ -156,10 +151,12 @@ class FaceswapControl():
self.process = Popen(args, **kwargs, stdin=PIPE) self.process = Popen(args, **kwargs, stdin=PIPE)
self.thread_stdout() self.thread_stdout()
self.thread_stderr() self.thread_stderr()
logger.debug("Executed Faceswap")
def read_stdout(self): def read_stdout(self):
""" Read stdout from the subprocess. If training, pass the loss """ Read stdout from the subprocess. If training, pass the loss
values to Queue """ values to Queue """
logger.debug("Opening stdout reader")
while True: while True:
try: try:
output = self.process.stdout.readline() output = self.process.stdout.readline()
@ -173,14 +170,19 @@ class FaceswapControl():
if (self.command == "train" and self.capture_loss(output)) or ( if (self.command == "train" and self.capture_loss(output)) or (
self.command != "train" and self.capture_tqdm(output)): self.command != "train" and self.capture_tqdm(output)):
continue continue
if self.command == "train" and output.strip().endswith("saved models"):
logger.debug("Trigger update preview")
self.wrapper.tk_vars["updatepreview"].set(True)
print(output.strip()) print(output.strip())
returncode = self.process.poll() returncode = self.process.poll()
message = self.set_final_status(returncode) message = self.set_final_status(returncode)
self.wrapper.terminate(message) self.wrapper.terminate(message)
logger.debug("Terminated stdout reader. returncode: %s", returncode)
def read_stderr(self): def read_stderr(self):
""" Read stdout from the subprocess. If training, pass the loss """ Read stdout from the subprocess. If training, pass the loss
values to Queue """ values to Queue """
logger.debug("Opening stderr reader")
while True: while True:
try: try:
output = self.process.stderr.readline() output = self.process.stderr.readline()
@ -194,81 +196,125 @@ class FaceswapControl():
if self.command != "train" and self.capture_tqdm(output): if self.command != "train" and self.capture_tqdm(output):
continue continue
print(output.strip(), file=sys.stderr) print(output.strip(), file=sys.stderr)
logger.debug("Terminated stderr reader")
def thread_stdout(self): def thread_stdout(self):
""" Put the subprocess stdout so that it can be read without """ Put the subprocess stdout so that it can be read without
blocking """ blocking """
logger.debug("Threading stdout")
thread = Thread(target=self.read_stdout) thread = Thread(target=self.read_stdout)
thread.daemon = True thread.daemon = True
thread.start() thread.start()
logger.debug("Threaded stdout")
def thread_stderr(self): def thread_stderr(self):
""" Put the subprocess stderr so that it can be read without """ Put the subprocess stderr so that it can be read without
blocking """ blocking """
logger.debug("Threading stderr")
thread = Thread(target=self.read_stderr) thread = Thread(target=self.read_stderr)
thread.daemon = True thread.daemon = True
thread.start() thread.start()
logger.debug("Threaded stderr")
def capture_loss(self, string): def capture_loss(self, string):
""" Capture loss values from stdout """ """ Capture loss values from stdout """
logger.trace("Capturing loss")
if not str.startswith(string, "["): if not str.startswith(string, "["):
logger.trace("Not loss message. Returning False")
return False return False
loss = self.consoleregex["loss"].findall(string) loss = self.consoleregex["loss"].findall(string)
if len(loss) < 2: if len(loss) < 2:
logger.trace("Not loss message. Returning False")
return False return False
self.wrapper.session.add_loss(loss)
message = "" message = ""
for item in loss: for item in loss:
message += "{}: {} ".format(item[0], item[1]) message += "{}: {} ".format(item[0], item[1])
if not message: if not message:
logger.trace("Error creating loss message. Returning False")
return False return False
elapsed = self.wrapper.session.timestats["elapsed"] iterations = self.train_stats["iterations"]
iterations = self.wrapper.session.stats["iterations"]
if iterations == 0:
# Initialize session stats and set initial timestamp
self.train_stats["timestamp"] = time()
if not get_config().session.initialized and iterations > 0:
# Don't initialize session until after the first iteration as state
# file must exist first
get_config().session.initialize_session(is_training=True)
self.wrapper.tk_vars["refreshgraph"].set(True)
iterations += 1
if iterations % 100 == 0:
self.wrapper.tk_vars["refreshgraph"].set(True)
self.train_stats["iterations"] = iterations
elapsed = self.calc_elapsed()
message = "Elapsed: {} Iteration: {} {}".format(elapsed, message = "Elapsed: {} Iteration: {} {}".format(elapsed,
iterations, self.train_stats["iterations"], message)
message)
self.statusbar.progress_update(message, 0, False) self.statusbar.progress_update(message, 0, False)
logger.trace("Successfully captured loss: %s", message)
return True return True
def calc_elapsed(self):
""" Calculate and format time since training started """
now = time()
elapsed_time = now - self.train_stats["timestamp"]
try:
hrs = int(elapsed_time // 3600)
if hrs < 10:
hrs = "{0:02d}".format(hrs)
mins = "{0:02d}".format((int(elapsed_time % 3600) // 60))
secs = "{0:02d}".format((int(elapsed_time % 3600) % 60))
except ZeroDivisionError:
hrs = "00"
mins = "00"
secs = "00"
return "{}:{}:{}".format(hrs, mins, secs)
def capture_tqdm(self, string): def capture_tqdm(self, string):
""" Capture tqdm output for progress bar """ """ Capture tqdm output for progress bar """
tqdm = self.consoleregex["tqdm"].findall(string) logger.trace("Capturing tqdm")
if len(tqdm) != 5: tqdm = self.consoleregex["tqdm"].match(string)
if not tqdm:
return False return False
tqdm = tqdm.groupdict()
percent = tqdm[0] if any("?" in val for val in tqdm.values()):
processed = tqdm[1] logger.trace("tqdm initializing. Skipping")
processtime = "Elapsed: {} Remaining: {}".format(tqdm[2], tqdm[3]) return True
rate = tqdm[4] processtime = "Elapsed: {} Remaining: {}".format(tqdm["tme"].split("<")[0],
tqdm["tme"].split("<")[1])
message = "{} | {} | {} | {}".format(processtime, message = "{} | {} | {} | {}".format(processtime,
rate, tqdm["rte"],
processed, tqdm["itm"],
percent) tqdm["pct"])
current, total = processed.split("/") current, total = tqdm["itm"].split("/")
position = int((float(current) / float(total)) * 1000) position = int((float(current) / float(total)) * 1000)
self.statusbar.progress_update(message, position, True) self.statusbar.progress_update(message, position, True)
logger.trace("Successfully captured tqdm message: %s", message)
return True return True
def terminate(self): def terminate(self):
""" Terminate the subprocess """ """ Terminate the subprocess """
if self.command != "train": logger.debug("Terminating wrapper")
if self.command == "train":
logger.debug("Sending Exit Signal")
print("Sending Exit Signal", flush=True) print("Sending Exit Signal", flush=True)
try: try:
now = time() now = time()
if os.name == "nt": if os.name == "nt":
try: try:
logger.debug("Sending carriage return to process")
self.process.communicate(input="\n", timeout=60) self.process.communicate(input="\n", timeout=60)
except TimeoutExpired: except TimeoutExpired:
raise ValueError("Timeout reached sending Exit Signal") raise ValueError("Timeout reached sending Exit Signal")
else: else:
logger.debug("Sending SIGINT to process")
self.process.send_signal(signal.SIGINT) self.process.send_signal(signal.SIGINT)
while True: while True:
timeelapsed = time() - now timeelapsed = time() - now
@ -278,30 +324,37 @@ class FaceswapControl():
raise ValueError("Timeout reached sending Exit Signal") raise ValueError("Timeout reached sending Exit Signal")
return return
except ValueError as err: except ValueError as err:
logger.error("Error terminating process", exc_info=True)
print(err) print(err)
else: else:
logger.debug("Terminating Process...")
print("Terminating Process...") print("Terminating Process...")
children = psutil.Process().children(recursive=True) children = psutil.Process().children(recursive=True)
for child in children: for child in children:
child.terminate() child.terminate()
_, alive = psutil.wait_procs(children, timeout=10) _, alive = psutil.wait_procs(children, timeout=10)
if not alive: if not alive:
logger.debug("Terminated")
print("Terminated") print("Terminated")
return return
logger.debug("Termination timed out. Killing Process...")
print("Termination timed out. Killing Process...") print("Termination timed out. Killing Process...")
for child in alive: for child in alive:
child.kill() child.kill()
_, alive = psutil.wait_procs(alive, timeout=10) _, alive = psutil.wait_procs(alive, timeout=10)
if not alive: if not alive:
logger.debug("Killed")
print("Killed") print("Killed")
else: else:
for child in alive: for child in alive:
print("Process {} survived SIGKILL. " msg = "Process {} survived SIGKILL. Giving up".format(child)
"Giving up".format(child)) logger.debug(msg)
print(msg)
def set_final_status(self, returncode): def set_final_status(self, returncode):
""" Set the status bar output based on subprocess return code """ """ Set the status bar output based on subprocess return code """
logger.debug("Setting final status. returncode: %s", returncode)
if returncode in (0, 3221225786): if returncode in (0, 3221225786):
status = "Ready" status = "Ready"
elif returncode == -15: elif returncode == -15:
@ -311,6 +364,6 @@ class FaceswapControl():
elif returncode == -6: elif returncode == -6:
status = "Aborted - {}.py".format(self.command) status = "Aborted - {}.py".format(self.command)
else: else:
status = "Failed - {}.py. Return Code: {}".format(self.command, status = "Failed - {}.py. Return Code: {}".format(self.command, returncode)
returncode) logger.debug("Set final status: %s", status)
return status return status
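The status bar parsing in `capture_tqdm` and `calc_elapsed` above can be tried outside the GUI. The following is a minimal sketch that reuses the new `tqdm` regex from this diff; the helper names `parse_tqdm` and `format_elapsed` are hypothetical and not part of the commit, and, unlike `capture_tqdm`, the sketch returns `None` both for non-matching lines and for the initial tqdm line that still contains `?` placeholders.

```python
import re

# Same pattern as the new "tqdm" entry in consoleregex
TQDM_RE = re.compile(r".*?(?P<pct>\d+%).*?(?P<itm>\d+/\d+)\W\["
                     r"(?P<tme>\d+:\d+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]")


def parse_tqdm(line):
    """ Hypothetical helper: return the fields of a tqdm line, or None """
    match = TQDM_RE.match(line)
    if not match:
        return None
    fields = match.groupdict()
    if any("?" in val for val in fields.values()):
        # tqdm has not estimated rate/remaining time yet
        return None
    elapsed, remaining = fields["tme"].split("<")
    current, total = fields["itm"].split("/")
    return {"elapsed": elapsed,
            "remaining": remaining,
            "rate": fields["rte"],
            "items": fields["itm"],
            "percent": fields["pct"],
            # Progress bar position on a 0-1000 scale, as in capture_tqdm
            "position": int((float(current) / float(total)) * 1000)}


def format_elapsed(seconds):
    """ Hypothetical helper: format a duration in seconds as HH:MM:SS """
    hrs = int(seconds // 3600)
    mins = int(seconds % 3600) // 60
    secs = int(seconds % 3600) % 60
    return "{:02d}:{:02d}:{:02d}".format(hrs, mins, secs)
```

Named groups make the parser independent of how many tokens a tqdm line happens to contain, which is presumably why the diff moves away from the old `findall` check for exactly five items.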

92
lib/keypress.py Normal file

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Source: http://home.wlu.edu/~levys/software/kbhit.py
A Python class implementing KBHIT, the standard keyboard-interrupt poller.
Works transparently on Windows and Posix (Linux, Mac OS X). Doesn't work
with IDLE.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
"""
import os
# Windows
if os.name == "nt":
import msvcrt # pylint: disable=import-error
# Posix (Linux, OS X)
else:
import sys
import termios
import atexit
from select import select
class KBHit:
""" Creates a KBHit object that you can call to do various keyboard things. """
def __init__(self, is_gui=False):
self.is_gui = is_gui
if os.name == "nt" or self.is_gui:
pass
else:
# Save the terminal settings
self.file_desc = sys.stdin.fileno()
self.new_term = termios.tcgetattr(self.file_desc)
self.old_term = termios.tcgetattr(self.file_desc)
# New terminal setting unbuffered
self.new_term[3] = (self.new_term[3] & ~termios.ICANON & ~termios.ECHO)
termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.new_term)
# Support normal-terminal reset at exit
atexit.register(self.set_normal_term)
def set_normal_term(self):
""" Resets to normal terminal. On Windows this is a no-op. """
if os.name == "nt" or self.is_gui:
pass
else:
termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.old_term)
@staticmethod
def getch():
""" Returns a keyboard character after kbhit() has been called.
Should not be called in the same program as getarrow(). """
if os.name == "nt":
return msvcrt.getch().decode("utf-8")
return sys.stdin.read(1)
@staticmethod
def getarrow():
""" Returns an arrow-key code after kbhit() has been called. Codes are
0 : up
1 : right
2 : down
3 : left
Should not be called in the same program as getch(). """
if os.name == "nt":
msvcrt.getch() # skip 0xE0
char = msvcrt.getch().decode("utf-8") # msvcrt.getch() returns bytes
vals = [72, 77, 80, 75]
else:
char = sys.stdin.read(3)[2] # text-mode stdin already returns str
vals = [65, 67, 66, 68]
return vals.index(ord(char))
@staticmethod
def kbhit():
""" Returns True if keyboard character was hit, False otherwise. """
if os.name == "nt":
return msvcrt.kbhit()
d_r, _, _ = select([sys.stdin], [], [], 0)
return d_r != []
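The arrow-key mapping in `getarrow()` can be checked without a terminal. This is a sketch under the assumption that the POSIX terminal sends the usual three-character escape sequences (`ESC [ A` for up, and so on); the function name `decode_posix_arrow` is hypothetical and not part of `lib/keypress.py`.

```python
# Final characters of the POSIX escape sequences, in the order used by
# getarrow(): 'A' = up, 'C' = right, 'B' = down, 'D' = left
POSIX_ARROW_FINALS = [65, 67, 66, 68]


def decode_posix_arrow(sequence):
    """ Hypothetical helper: map an escape sequence such as "\\x1b[A" to
    0 (up), 1 (right), 2 (down) or 3 (left) """
    if len(sequence) != 3 or not sequence.startswith("\x1b["):
        raise ValueError("Not an arrow-key sequence: {!r}".format(sequence))
    # list.index raises ValueError for any other final character
    return POSIX_ARROW_FINALS.index(ord(sequence[2]))
```

The sequence is handled as `str` here because a text-mode `sys.stdin.read()` returns `str` on Python 3, so no decoding step is needed on this branch.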


@ -44,9 +44,15 @@ class MultiProcessingLogger(logging.Logger):
class FaceswapFormatter(logging.Formatter): class FaceswapFormatter(logging.Formatter):
""" Override formatter to strip newlines and multiple spaces from logger """ """ Override formatter to strip newlines and multiple spaces from logger
Messages that begin with "R|" should be handled as is
"""
def format(self, record): def format(self, record):
record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r")) if record.msg.startswith("R|"):
record.msg = record.msg[2:]
record.strip_spaces = False
elif record.strip_spaces:
record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r"))
return super().format(record) return super().format(record)
@ -92,7 +98,7 @@ def file_handler(loglevel, logfile, log_format, command):
filename = logfile filename = logfile
else: else:
filename = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap") filename = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap")
# Windows has issues sharing the log file with subprocesses, so log GUI seperately # Windows has issues sharing the log file with subprocesses, so log GUI separately
filename += "_gui.log" if command == "gui" else ".log" filename += "_gui.log" if command == "gui" else ".log"
should_rotate = os.path.isfile(filename) should_rotate = os.path.isfile(filename)
@ -152,6 +158,18 @@ def crash_log():
return filename return filename
# Add a flag to logging.LogRecord to not strip formatting from particular records
old_factory = logging.getLogRecordFactory()
def faceswap_logrecord(*args, **kwargs):
record = old_factory(*args, **kwargs)
record.strip_spaces = True
return record
logging.setLogRecordFactory(faceswap_logrecord)
# Set logger class to custom logger # Set logger class to custom logger
logging.setLoggerClass(MultiProcessingLogger) logging.setLoggerClass(MultiProcessingLogger)

0
lib/model/__init__.py Normal file

81
lib/model/initializers.py Normal file

@ -0,0 +1,81 @@
#!/usr/bin/env python3
""" Custom Initializers for faceswap.py
Initializers from:
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
import sys
import inspect
import tensorflow as tf
from keras import initializers
from keras.utils.generic_utils import get_custom_objects
def icnr_keras(shape, dtype=None):
"""
Custom initializer for subpix upscaling
From https://github.com/kostyaev/ICNR
Note: upscale factor is fixed to 2, and the base initializer is fixed to random normal.
"""
# TODO Roll this into ICNR_init when porting GAN 2.2
shape = list(shape)
scale = 2
initializer = tf.keras.initializers.RandomNormal(0, 0.02)
new_shape = shape[:3] + [int(shape[3] / (scale ** 2))]
var_x = initializer(new_shape, dtype)
var_x = tf.transpose(var_x, perm=[2, 0, 1, 3])
var_x = tf.image.resize_nearest_neighbor(var_x, size=(shape[0] * scale, shape[1] * scale))
var_x = tf.space_to_depth(var_x, block_size=scale)
var_x = tf.transpose(var_x, perm=[1, 2, 0, 3])
return var_x
class ICNR(initializers.Initializer): # pylint: disable=invalid-name
'''
ICNR initializer for checkerboard artifact free sub pixel convolution
Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution
https://arxiv.org/pdf/1707.02937.pdf https://distill.pub/2016/deconv-checkerboard/
Parameters:
initializer: initializer used for sub kernels (orthogonal, glorot uniform, etc.)
scale: scale factor of sub pixel convolution (upsampling from 8x8 to 16x16 is scale 2)
Return:
The modified kernel weights
Example:
x = conv2d(... weights_initializer=ICNR(initializer=he_uniform(), scale=2))
'''
def __init__(self, initializer, scale=2):
self.scale = scale
self.initializer = initializer
def __call__(self, shape, dtype='float32'): # tf needs partition_info=None
shape = list(shape)
if self.scale == 1:
return self.initializer(shape)
new_shape = shape[:3] + [shape[3] // (self.scale ** 2)]
if isinstance(self.initializer, dict):
self.initializer = initializers.deserialize(self.initializer)
var_x = self.initializer(new_shape, dtype)
var_x = tf.transpose(var_x, perm=[2, 0, 1, 3])
var_x = tf.image.resize_nearest_neighbor(
var_x,
size=(shape[0] * self.scale, shape[1] * self.scale),
align_corners=True)
var_x = tf.space_to_depth(var_x, block_size=self.scale, data_format='NHWC')
var_x = tf.transpose(var_x, perm=[1, 2, 0, 3])
return var_x
def get_config(self):
config = {'scale': self.scale,
'initializer': self.initializer
}
base_config = super(ICNR, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Update initializers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
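The idea behind ICNR can be illustrated without TensorFlow. The sketch below is a NumPy approximation of the `icnr_keras` path (nearest-neighbour resize without `align_corners`), and it assumes TensorFlow's block-major channel ordering for `space_to_depth`; `icnr_numpy`, `stddev` and `seed` are hypothetical names chosen for illustration.

```python
import numpy as np


def icnr_numpy(shape, scale=2, stddev=0.02, seed=0):
    """ Sketch of an ICNR kernel for shape (rows, cols, in_ch, out_ch).

    A sub-kernel with out_ch / scale**2 output channels is drawn from a
    random normal and repeated once per sub-pixel position, so every pixel
    in an upscaled block starts from the same weights. """
    rows, cols, in_ch, out_ch = shape
    if out_ch % (scale ** 2) != 0:
        raise ValueError("out_ch must be divisible by scale ** 2")
    sub_ch = out_ch // (scale ** 2)
    rng = np.random.default_rng(seed)
    sub_kernel = rng.normal(0.0, stddev, size=(rows, cols, in_ch, sub_ch))
    # Block-major layout: channel index = block_position * sub_ch + c
    return np.tile(sub_kernel, (1, 1, 1, scale ** 2))
```

Because every group of `sub_ch` channels is identical at initialization, the sub-pixel shuffle initially behaves like nearest-neighbour upsampling, which is what suppresses checkerboard artifacts early in training.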

338
lib/model/layers.py Normal file

@ -0,0 +1,338 @@
#!/usr/bin/env python3
""" Custom Layers for faceswap.py
Layers from:
the original https://www.reddit.com/r/deepfakes/ code sample + contribs
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
from __future__ import absolute_import
import sys
import inspect
import tensorflow as tf
import keras.backend as K
from keras.engine import InputSpec, Layer
from keras.utils import conv_utils
from keras.utils.generic_utils import get_custom_objects
from keras import initializers
from keras.layers import ZeroPadding2D
class PixelShuffler(Layer):
""" PixelShuffler layer for Keras
by t-ae: https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 """
# pylint: disable=C0103
def __init__(self, size=(2, 2), data_format=None, **kwargs):
super(PixelShuffler, self).__init__(**kwargs)
self.data_format = K.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 2, 'size')
def call(self, inputs, **kwargs):
input_shape = K.int_shape(inputs)
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape: ' + str(input_shape))
if self.data_format == 'channels_first':
batch_size, c, h, w = input_shape
if batch_size is None:
batch_size = -1
rh, rw = self.size
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
out = K.reshape(out, (batch_size, oc, oh, ow))
elif self.data_format == 'channels_last':
batch_size, h, w, c = input_shape
if batch_size is None:
batch_size = -1
rh, rw = self.size
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
out = K.reshape(out, (batch_size, oh, ow, oc))
return out
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape: ' + str(input_shape))
if self.data_format == 'channels_first':
height = None
width = None
if input_shape[2] is not None:
height = input_shape[2] * self.size[0]
if input_shape[3] is not None:
width = input_shape[3] * self.size[1]
channels = input_shape[1] // self.size[0] // self.size[1]
if channels * self.size[0] * self.size[1] != input_shape[1]:
raise ValueError('channels of input and size are incompatible')
retval = (input_shape[0],
channels,
height,
width)
elif self.data_format == 'channels_last':
height = None
width = None
if input_shape[1] is not None:
height = input_shape[1] * self.size[0]
if input_shape[2] is not None:
width = input_shape[2] * self.size[1]
channels = input_shape[3] // self.size[0] // self.size[1]
if channels * self.size[0] * self.size[1] != input_shape[3]:
raise ValueError('channels of input and size are incompatible')
retval = (input_shape[0],
height,
width,
channels)
return retval
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(PixelShuffler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Scale(Layer):
"""
GAN Custom Scale Layer
Code borrows from https://github.com/flyyufelix/cnn_finetune
"""
def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):
self.axis = axis
self.gamma_init = initializers.get(gamma_init)
self.initial_weights = weights
super(Scale, self).__init__(**kwargs)
def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
# Compatibility with TensorFlow >= 1.0.0
self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name))
self.trainable_weights = [self.gamma]
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def call(self, x, mask=None):
return self.gamma * x
def get_config(self):
config = {"axis": self.axis}
base_config = super(Scale, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class SubPixelUpscaling(Layer):
# pylint: disable=C0103
""" Sub-pixel convolutional upscaling layer based on the paper "Real-Time
Single Image and Video Super-Resolution Using an Efficient Sub-Pixel
Convolutional Neural Network" (https://arxiv.org/abs/1609.05158).
This layer requires a Convolution2D prior to it, having output filters
computed according to the formula :
filters = k * (scale_factor * scale_factor)
where k = a user defined number of filters (generally larger than 32)
scale_factor = the upscaling factor (generally 2)
This layer performs the depth to space operation on the convolution
filters, and returns a tensor with the size as defined below.
# Example :
```python
# A standard subpixel upscaling block
x = Convolution2D(256, 3, 3, padding="same", activation="relu")(...)
u = SubPixelUpscaling(scale_factor=2)(x)
[Optional]
x = Convolution2D(256, 3, 3, padding="same", activation="relu")(u)
```
In practice, it is useful to have a second convolution layer after the
SubPixelUpscaling layer to speed up the learning process.
However, if you are stacking multiple SubPixelUpscaling blocks,
it may increase the number of parameters greatly, so the Convolution
layer after SubPixelUpscaling layer can be removed.
# Arguments
scale_factor: Upscaling factor.
data_format: Can be None, "channels_first" or "channels_last".
# Input shape
4D tensor with shape:
`(samples, k * (scale_factor * scale_factor) channels, rows, cols)`
if data_format="channels_first"
or 4D tensor with shape:
`(samples, rows, cols, k * (scale_factor * scale_factor) channels)`
if data_format="channels_last".
# Output shape
4D tensor with shape:
`(samples, k channels, rows * scale_factor, cols * scale_factor))`
if data_format="channels_first"
or 4D tensor with shape:
`(samples, rows * scale_factor, cols * scale_factor, k channels)`
if data_format="channels_last".
"""
def __init__(self, scale_factor=2, data_format=None, **kwargs):
super(SubPixelUpscaling, self).__init__(**kwargs)
self.scale_factor = scale_factor
self.data_format = K.normalize_data_format(data_format)
def build(self, input_shape):
pass
def call(self, x, mask=None):
y = self.depth_to_space(x, self.scale_factor, self.data_format)
return y
def compute_output_shape(self, input_shape):
if self.data_format == "channels_first":
b, k, r, c = input_shape
return (b,
k // (self.scale_factor ** 2),
r * self.scale_factor,
c * self.scale_factor)
b, r, c, k = input_shape
return (b,
r * self.scale_factor,
c * self.scale_factor,
k // (self.scale_factor ** 2))
@classmethod
def depth_to_space(cls, ipt, scale, data_format=None):
""" Uses phase shift algorithm to convert channels/depth
for spatial resolution """
if data_format is None:
data_format = K.image_data_format()
data_format = data_format.lower()
ipt = cls._preprocess_conv2d_input(ipt, data_format)
out = tf.depth_to_space(ipt, scale)
out = cls._postprocess_conv2d_output(out, data_format)
return out
@staticmethod
def _postprocess_conv2d_output(x, data_format):
"""Transpose and cast the output from conv2d if needed.
# Arguments
x: A tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
# Returns
A tensor.
"""
if data_format == "channels_first":
x = tf.transpose(x, (0, 3, 1, 2))
if K.floatx() == "float64":
x = tf.cast(x, "float64")
return x
@staticmethod
def _preprocess_conv2d_input(x, data_format):
"""Transpose and cast the input before the conv2d.
# Arguments
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
# Returns
A tensor.
"""
if K.dtype(x) == "float64":
x = tf.cast(x, "float32")
if data_format == "channels_first":
# TF uses the last dimension as channel dimension,
# instead of the 2nd one.
# TH input shape: (samples, input_depth, rows, cols)
# TF input shape: (samples, rows, cols, input_depth)
x = tf.transpose(x, (0, 2, 3, 1))
return x
def get_config(self):
config = {"scale_factor": self.scale_factor,
"data_format": self.data_format}
base_config = super(SubPixelUpscaling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class ReflectionPadding2D(Layer):
def __init__(self, stride=2, kernel_size=5, **kwargs):
'''
# Arguments
stride: stride of following convolution (2)
kernel_size: kernel size of following convolution (5,5)
'''
self.stride = stride
self.kernel_size = kernel_size
super(ReflectionPadding2D, self).__init__(**kwargs)
def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
super(ReflectionPadding2D, self).build(input_shape)
def compute_output_shape(self, input_shape):
""" If you are using "channels_last" configuration"""
input_shape = self.input_spec[0].shape
in_width, in_height = input_shape[2], input_shape[1]
kernel_width, kernel_height = self.kernel_size, self.kernel_size
if (in_height % self.stride == 0):
padding_height = max(kernel_height - self.stride, 0)
else:
padding_height = max(kernel_height - (in_height % self.stride), 0)
if (in_width % self.stride == 0):
padding_width = max(kernel_width - self.stride, 0)
else:
padding_width = max(kernel_width - (in_width % self.stride), 0)
return (input_shape[0],
input_shape[1] + padding_height,
input_shape[2] + padding_width,
input_shape[3])
def call(self, x, mask=None):
input_shape = self.input_spec[0].shape
in_width, in_height = input_shape[2], input_shape[1]
kernel_width, kernel_height = self.kernel_size, self.kernel_size
if (in_height % self.stride == 0):
padding_height = max(kernel_height - self.stride, 0)
else:
padding_height = max(kernel_height - (in_height % self.stride), 0)
if (in_width % self.stride == 0):
padding_width = max(kernel_width - self.stride, 0)
else:
padding_width = max(kernel_width - (in_width % self.stride), 0)
padding_top = padding_height // 2
padding_bot = padding_height - padding_top
padding_left = padding_width // 2
padding_right = padding_width - padding_left
return tf.pad(x, [[0,0],
[padding_top, padding_bot],
[padding_left, padding_right],
[0,0] ],
'REFLECT')
def get_config(self):
config = {'stride': self.stride,
'kernel_size': self.kernel_size}
base_config = super(ReflectionPadding2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Update layers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
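The `channels_last` branch of `PixelShuffler.call` is a reshape, a transpose and a second reshape, which can be reproduced in NumPy to see where each channel ends up. This is a sketch only; `pixel_shuffle_nhwc` is a hypothetical name and the Keras backend calls are swapped for their NumPy counterparts.

```python
import numpy as np


def pixel_shuffle_nhwc(inputs, size=(2, 2)):
    """ NumPy sketch of the channels_last pixel shuffle.

    inputs: array of shape (batch, h, w, c) with c divisible by rh * rw
    returns: array of shape (batch, h * rh, w * rw, c // (rh * rw)) """
    batch, height, width, channels = inputs.shape
    r_h, r_w = size
    if channels % (r_h * r_w) != 0:
        raise ValueError("channels of input and size are incompatible")
    out_ch = channels // (r_h * r_w)
    out = inputs.reshape(batch, height, width, r_h, r_w, out_ch)
    # Interleave the block axes with the spatial axes
    out = out.transpose(0, 1, 3, 2, 4, 5)
    return out.reshape(batch, height * r_h, width * r_w, out_ch)
```

With a `(1, 2, 2, 4)` input, each group of four channels becomes one 2x2 block of the `(1, 4, 4, 1)` output.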

844
lib/model/losses.py Normal file

@ -0,0 +1,844 @@
#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py
Losses from:
keras.contrib
dfaker: https://github.com/dfaker/df
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
from __future__ import absolute_import
import keras.backend as K
from keras.layers import Lambda, concatenate
import tensorflow as tf
from tensorflow.contrib.distributions import Beta
from .normalization import InstanceNormalization
class DSSIMObjective():
""" DSSIM Loss Function
Code copied and pasted, with minor amendments, from:
https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
MIT License
Copyright (c) 2017 Fariz Rahman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. """
# pylint: disable=C0103
def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0):
"""
Difference of Structural Similarity (DSSIM loss function). Clipped
between 0 and 0.5
Note : You should add a regularization term like a l2 loss in
addition to this one.
Note : In theano, the `kernel_size` must be a factor of the output
size. So 3 could not be the `kernel_size` for an output of 32.
# Arguments
k1: Parameter of the SSIM (default 0.01)
k2: Parameter of the SSIM (default 0.03)
kernel_size: Size of the sliding window (default 3)
max_value: Max value of the output (default 1.0)
"""
self.__name__ = 'DSSIMObjective'
self.kernel_size = kernel_size
self.k1 = k1
self.k2 = k2
self.max_value = max_value
self.c1 = (self.k1 * self.max_value) ** 2
self.c2 = (self.k2 * self.max_value) ** 2
self.dim_ordering = K.image_data_format()
self.backend = K.backend()
@staticmethod
def __int_shape(x):
return K.int_shape(x)
def __call__(self, y_true, y_pred):
# There are additional parameters for this function
# Note: some of the 'modes' for edge behavior do not yet have a
# gradient definition in the Theano tree and cannot be used for
# learning
kernel = [self.kernel_size, self.kernel_size]
y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:]))
y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:]))
patches_pred = self.extract_image_patches(y_pred,
kernel,
kernel,
'valid',
self.dim_ordering)
patches_true = self.extract_image_patches(y_true,
kernel,
kernel,
'valid',
self.dim_ordering)
# Reshape to get the var in the cells
_, w, h, c1, c2, c3 = self.__int_shape(patches_pred)
patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3])
patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3])
# Get mean
u_true = K.mean(patches_true, axis=-1)
u_pred = K.mean(patches_pred, axis=-1)
# Get variance
var_true = K.var(patches_true, axis=-1)
var_pred = K.var(patches_pred, axis=-1)
# Get covariance
covar_true_pred = K.mean(
patches_true * patches_pred, axis=-1) - u_true * u_pred
ssim = (2 * u_true * u_pred + self.c1) * (
2 * covar_true_pred + self.c2)
denom = (K.square(u_true) + K.square(u_pred) + self.c1) * (
var_pred + var_true + self.c2)
ssim /= denom # no need for clipping, c1 + c2 make the denom non-zero
return K.mean((1.0 - ssim) / 2.0)
@staticmethod
def _preprocess_padding(padding):
"""Convert keras' padding to tensorflow's padding.
# Arguments
padding: string, `"same"` or `"valid"`.
# Returns
a string, `"SAME"` or `"VALID"`.
# Raises
ValueError: if `padding` is invalid.
"""
if padding == 'same':
padding = 'SAME'
elif padding == 'valid':
padding = 'VALID'
else:
raise ValueError('Invalid padding: {}'.format(padding))
return padding
def extract_image_patches(self, x, ksizes, ssizes, padding='same',
data_format='channels_last'):
'''
Extract the patches from an image
# Parameters
x : The input image
ksizes : 2-d tuple with the kernel size
ssizes : 2-d tuple with the strides size
padding : 'same' or 'valid'
data_format : 'channels_last' or 'channels_first'
# Returns
The (k_w,k_h) patches extracted
TF ==> (batch_size,w,h,k_w,k_h,c)
TH ==> (batch_size,w,h,c,k_w,k_h)
'''
kernel = [1, ksizes[0], ksizes[1], 1]
strides = [1, ssizes[0], ssizes[1], 1]
padding = self._preprocess_padding(padding)
if data_format == 'channels_first':
x = K.permute_dimensions(x, (0, 2, 3, 1))
_, _, _, ch_i = K.int_shape(x)
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
padding)
# Reshaping to fit Theano
_, w, h, ch = K.int_shape(patches)
patches = tf.reshape(tf.transpose(tf.reshape(patches,
[-1, w, h,
tf.floordiv(ch, ch_i),
ch_i]),
[0, 1, 2, 4, 3]),
[-1, w, h, ch_i, ksizes[0], ksizes[1]])
if data_format == 'channels_last':
patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
return patches
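The DSSIM computation in this class can be checked by hand on a single pair of patches. The following is a minimal sketch in plain Python rather than Keras tensors. The helper name `dssim_patch` and the constants `k1=0.01`, `k2=0.03` and `max_value=1.0` (from which `c1` and `c2` are derived) are assumptions for illustration, since the definitions of `self.c1` and `self.c2` are not shown in this part of the diff.

```python
# Hypothetical scalar helper: mirrors the mean / variance / covariance steps
# of the DSSIM loss for one pair of flattened patches.
def dssim_patch(true_patch, pred_patch, k1=0.01, k2=0.03, max_value=1.0):
    """Return (1 - SSIM) / 2 for two equally sized lists of pixel values."""
    c1 = (k1 * max_value) ** 2  # assumed stabiliser for the luminance term
    c2 = (k2 * max_value) ** 2  # assumed stabiliser for the contrast term
    count = len(true_patch)
    # Means
    u_true = sum(true_patch) / count
    u_pred = sum(pred_patch) / count
    # Population variances
    var_true = sum((val - u_true) ** 2 for val in true_patch) / count
    var_pred = sum((val - u_pred) ** 2 for val in pred_patch) / count
    # Covariance, computed as E[xy] - E[x]E[y]
    covar = sum(t * p for t, p in zip(true_patch, pred_patch)) / count - u_true * u_pred
    ssim = (2 * u_true * u_pred + c1) * (2 * covar + c2)
    ssim /= (u_true ** 2 + u_pred ** 2 + c1) * (var_true + var_pred + c2)
    return (1.0 - ssim) / 2.0


patch = [0.1, 0.5, 0.9, 0.3]
inverted = [1.0 - val for val in patch]
print(dssim_patch(patch, patch))     # identical patches: very close to 0
print(dssim_patch(patch, inverted))  # anti-correlated patches: above 0.5
```

Identical patches give SSIM of about 1 and therefore a loss of about 0, while anti-correlated patches push the covariance term negative and the loss above 0.5.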
# <<< START: from Dfaker >>> #
class PenalizedLoss(): # pylint: disable=too-few-public-methods
""" Penalized Loss
from: https://github.com/dfaker/df """
def __init__(self, mask, loss_func, mask_prop=1.0):
self.mask = mask
self.loss_func = loss_func
self.mask_prop = mask_prop
self.mask_as_k_inv_prop = 1-mask_prop
def __call__(self, y_true, y_pred):
# pylint: disable=invalid-name
tro, tgo, tbo = tf.split(y_true, 3, 3)
pro, pgo, pbo = tf.split(y_pred, 3, 3)
tr = tro
tg = tgo
tb = tbo
pr = pro
pg = pgo
pb = pbo
m = self.mask
m = m * self.mask_prop
m += self.mask_as_k_inv_prop
tr *= m
tg *= m
tb *= m
pr *= m
pg *= m
pb *= m
y = tf.concat([tr, tg, tb], 3)
p = tf.concat([pr, pg, pb], 3)
# yo = tf.stack([tro,tgo,tbo],3)
# po = tf.stack([pro,pgo,pbo],3)
return self.loss_func(y, p)
# <<< END: from Dfaker >>> #
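The effect of `mask_prop` in `PenalizedLoss` is easiest to see on scalar values. The sketch below is an illustration in plain Python, not the tensor code used in training; `penalty_weight` and `penalized_l1` are hypothetical names, and mean absolute error stands in for whatever `loss_func` is actually passed in.

```python
# Hypothetical scalar illustration of the PenalizedLoss weighting.
def penalty_weight(mask_value, mask_prop=1.0):
    """Weight applied to a pixel: mask * mask_prop + (1 - mask_prop)."""
    return mask_value * mask_prop + (1.0 - mask_prop)


def penalized_l1(y_true, y_pred, mask, mask_prop=1.0):
    """Mean absolute error after both inputs are scaled by the pixel weight."""
    weights = [penalty_weight(val, mask_prop) for val in mask]
    diffs = [abs(wgt * tru - wgt * prd)
             for wgt, tru, prd in zip(weights, y_true, y_pred)]
    return sum(diffs) / len(diffs)


# With mask_prop=1.0 pixels outside the mask (mask == 0) contribute nothing.
print(penalized_l1([1.0, 1.0], [0.0, 0.0], [1.0, 0.0], mask_prop=1.0))
# With mask_prop=0.5 pixels outside the mask still count at half weight.
print(penalized_l1([1.0, 1.0], [0.0, 0.0], [1.0, 0.0], mask_prop=0.5))
```

Lowering `mask_prop` therefore trades focus on the masked face area for some supervision of the background.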
# <<< START: from Shaoanlu GAN >>> #
def first_order(var_x, axis=1):
""" First Order Function from Shaoanlu GAN """
img_nrows = var_x.shape[1]
img_ncols = var_x.shape[2]
if axis == 1:
return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :])
if axis == 2:
return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :])
return None
def calc_loss(pred, target, loss='l2'):
""" Calculate Loss from Shaoanlu GAN """
if loss.lower() == "l2":
return K.mean(K.square(pred - target))
if loss.lower() == "l1":
return K.mean(K.abs(pred - target))
if loss.lower() == "cross_entropy":
return -K.mean(K.log(pred + K.epsilon()) * target +
K.log(1 - pred + K.epsilon()) * (1 - target))
raise ValueError('Received an unknown loss type: {}.'.format(loss))
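For reference, the three loss types handled by `calc_loss` reduce to the following when written for plain Python lists. This is only a sketch: `calc_loss_scalar` is a hypothetical name, and the `epsilon` default of `1e-7` is assumed to correspond to the Keras backend default returned by `K.epsilon()`.

```python
import math


# Hypothetical list-based counterpart of calc_loss, for checking values by hand.
def calc_loss_scalar(pred, target, loss="l2", epsilon=1e-7):
    """Return the mean L2, L1 or binary cross-entropy loss over two lists."""
    kind = loss.lower()
    count = len(pred)
    if kind == "l2":
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / count
    if kind == "l1":
        return sum(abs(p - t) for p, t in zip(pred, target)) / count
    if kind == "cross_entropy":
        # epsilon keeps log() away from zero, as in the tensor version
        total = sum(math.log(p + epsilon) * t + math.log(1 - p + epsilon) * (1 - t)
                    for p, t in zip(pred, target))
        return -total / count
    raise ValueError("Received an unknown loss type: {}.".format(loss))


print(calc_loss_scalar([1.0, 0.0], [0.0, 0.0], "l2"))
print(calc_loss_scalar([1.0, 0.0], [0.0, 0.0], "l1"))
print(calc_loss_scalar([0.5], [1.0], "cross_entropy"))  # close to ln(2)
```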
def cyclic_loss(net_g1, net_g2, real1):
""" Cyclic Loss Function from Shaoanlu GAN """
fake2 = net_g2(real1)[-1] # fake2 ABGR
fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2) # fake2 BGR
cyclic1 = net_g1(fake2)[-1] # cyclic1 ABGR
cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1) # cyclic1 BGR
loss = calc_loss(cyclic1, real1, loss='l1')
return loss
def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
""" Adversarial Loss Function from Shaoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
fake = alpha * fake_bgr + (1-alpha) * distorted
if gan_training == "mixup_LSGAN":
dist = Beta(0.2, 0.2)
lam = dist.sample()
mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
pred_fake = net_d(concatenate([fake, distorted]))
pred_mixup = net_d(mixup)
loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
mixup2 = lam * concatenate([real,
distorted]) + (1 - lam) * concatenate([fake_bgr,
distorted])
pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
pred_mixup2 = net_d(mixup2)
loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2")
elif gan_training == "relativistic_avg_LSGAN":
real_pred = net_d(concatenate([real, distorted]))
fake_pred = net_d(concatenate([fake, distorted]))
loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred)))/2
loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred)))/2
loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))
fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
K.ones_like(fake_pred2)))/2
loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
K.zeros_like(fake_pred2)))/2
loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
K.zeros_like(fake_pred2)))/2
loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
K.ones_like(fake_pred2)))/2
else:
raise ValueError("Received an unknown GAN training method: {}".format(gan_training))
return loss_d, loss_g
def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights):
""" Reconstruction Loss Function from Shaoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
loss_g = 0
loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1")
loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes*(fake_bgr - real)))
for out in model_outputs[:-1]:
out_size = out.get_shape().as_list()
resized_real = tf.image.resize_images(real, out_size[1:3])
loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1")
return loss_g
def edge_loss(real, fake_abgr, mask_eyes, **weights):
""" Edge Loss Function from Shaoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
loss_g = 0
loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1),
first_order(real, axis=1), "l1")
loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2),
first_order(real, axis=2), "l1")
shape_mask_eyes = mask_eyes.get_shape().as_list()
resized_mask_eyes = tf.image.resize_images(mask_eyes,
[shape_mask_eyes[1]-1, shape_mask_eyes[2]-1])
loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
(first_order(fake_bgr, axis=1) -
first_order(real, axis=1))))
loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
(first_order(fake_bgr, axis=2) -
first_order(real, axis=2))))
return loss_g
def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights):
""" Perceptual Loss Function from Shaoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
fake = alpha * fake_bgr + (1-alpha) * distorted
def preprocess_vggface(var_x):
var_x = (var_x + 1) / 2 * 255 # channel order: BGR
var_x -= [91.4953, 103.8827, 131.0912]
return var_x
real_sz224 = tf.image.resize_images(real, [224, 224])
real_sz224 = Lambda(preprocess_vggface)(real_sz224)
dist = Beta(0.2, 0.2)
lam = dist.sample() # use mixup trick here to reduce forward pass from 2 times to 1.
mixup = lam*fake_bgr + (1-lam)*fake
fake_sz224 = tf.image.resize_images(mixup, [224, 224])
fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)
# Apply instance norm on VGG(ResNet) features
# From MUNIT https://github.com/NVlabs/MUNIT
loss_g = 0
def instnorm():
return InstanceNormalization()
loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7),
instnorm()(real_feat7), "l2")
loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28),
instnorm()(real_feat28), "l2")
loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55),
instnorm()(real_feat55), "l2")
loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112),
instnorm()(real_feat112), "l2")
return loss_g
# <<< END: from Shaoanlu GAN >>> #
def generalized_loss_function(y_true, y_pred, a = 1.0, c=1.0/255.0):
'''
generalized function used to return a large variety of mathematical loss functions
primary benefit is smooth, differentiable version of L1 loss
Barron, J. A More General Robust Loss Function
https://arxiv.org/pdf/1701.03077.pdf
Parameters:
a: penalty factor. larger numbers give larger weight to large deviations
c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256 )
Return:
a loss value from the results of function(y_pred - y_true)
Example:
a=1.0, x>>c , c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss
a=1.999999 (lim as a->2), c=1.0/255.0 will give L2 / RMSE loss
'''
x = y_pred - y_true
loss = (K.abs(2.0-a)/a) * ( K.pow( K.pow(x/c, 2.0)/K.abs(2.0-a) + 1.0 , (a/2.0)) - 1.0 )
return K.mean(loss, axis=-1) * c
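The two regimes mentioned in the docstring of `generalized_loss_function` can be verified numerically on a single residual. The sketch below is a scalar rewrite in plain Python for illustration only; `generalized_loss` is a hypothetical name, and it omits the mean over the last axis that the tensor version performs.

```python
# Hypothetical scalar form of the generalized loss for one residual
# x = y_pred - y_true, using the same formula as the tensor version.
def generalized_loss(x, a=1.0, c=1.0 / 255.0):
    """Return the generalized robust loss of a single residual."""
    scaled = (x / c) ** 2 / abs(2.0 - a) + 1.0
    loss = (abs(2.0 - a) / a) * (scaled ** (a / 2.0) - 1.0)
    return loss * c


residual = 0.5
# a = 1: sqrt(x^2 + c^2) - c, a smooth stand-in for |x| when x >> c
print(generalized_loss(residual, a=1.0), abs(residual))
# a -> 2: approaches 0.5 * (x / c)^2 * c, a scaled squared error
print(generalized_loss(residual, a=1.999999), 0.5 * (residual * 255) ** 2 / 255)
```

With `a=1.0` the value differs from `abs(x)` by roughly `c`, and unlike the absolute value it is differentiable at zero.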
def staircase_loss(y_true, y_pred, a = 16.0, c=1.0/255.0):
h = c
w = c
x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0)
loss = h*( K.tanh(a*((x/w)-tf.floor(x/w)-0.5)) / ( 2.0*K.tanh(a/2.0) ) + 0.5 + tf.floor(x/w))
loss += 1e-10
return K.mean(loss, axis=-1)
def gradient_loss(y_true, y_pred):
'''
Calculates the first and second order gradient difference between pixels of an image in the x and y dimensions.
These gradients are then compared between the ground truth and the predicted image and the difference is taken.
The difference used is a smooth L1 norm ( approximate to MAE but differentiable at zero )
When used as a loss, its minimization will result in predicted images approaching the same level of sharpness
/ blurriness as the ground truth.
TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration, Chengwu Lu & Hua Huang, 2014
(http://downloads.hindawi.com/journals/mpe/2014/790547.pdf)
Parameters:
y_pred: The predicted frames at each scale.
y_true: The ground truth frames at each scale
Return:
The GD loss.
'''
assert 4 == K.ndim(y_true)
y_true.set_shape([None,80,80,3])
y_pred.set_shape([None,80,80,3])
TV_weight = 1.0
TV2_weight = 1.0
loss = 0.0
def diff_x(X):
Xleft = X[:, :, 1, :] - X[:, :, 0, :]
Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2)
Xright = X[:, :, -1, :] - X[:, :, -2, :]
Xout = [Xleft] + Xinner + [Xright]
Xout = tf.stack(Xout,axis=2)
return Xout * 0.5
def diff_y(X):
Xtop = X[:, 1, :, :] - X[:, 0, :, :]
Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1)
Xbot = X[:, -1, :, :] - X[:, -2, :, :]
Xout = [Xtop] + Xinner + [Xbot]
Xout = tf.stack(Xout,axis=1)
return Xout * 0.5
def diff_xx(X):
Xleft = X[:, :, 1, :] + X[:, :, 0, :]
Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2)
Xright = X[:, :, -1, :] + X[:, :, -2, :]
Xout = [Xleft] + Xinner + [Xright]
Xout = tf.stack(Xout,axis=2)
return Xout - 2.0 * X
def diff_yy(X):
Xtop = X[:, 1, :, :] + X[:, 0, :, :]
Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1)
Xbot = X[:, -1, :, :] + X[:, -2, :, :]
Xout = [Xtop] + Xinner + [Xbot]
Xout = tf.stack(Xout,axis=1)
return Xout - 2.0 * X
def diff_xy(X):
#xout1
top_left = X[:, 1, 1, :]+X[:, 0, 0, :]
inner_left = tf.unstack(X[:, 2:, 1, :]+X[:, :-2, 0, :], axis=1)
bot_left = X[:, -1, 1, :]+X[:, -2, 0, :]
X_left = [top_left] + inner_left + [bot_left]
X_left = tf.stack(X_left, axis=1)
top_mid = X[:, 1, 2:, :]+X[:, 0, :-2, :]
mid_mid = tf.unstack(X[:, 2:, 2:, :]+X[:, :-2, :-2, :], axis=1)
bot_mid = X[:, -1, 2:, :]+X[:, -2, :-2, :]
X_mid = [top_mid] + mid_mid + [bot_mid]
X_mid = tf.stack(X_mid, axis=1)
top_right = X[:, 1, -1, :]+X[:, 0, -2, :]
inner_right = tf.unstack(X[:, 2:, -1, :]+X[:, :-2, -2, :], axis=1)
bot_right = X[:, -1, -1, :]+X[:, -2, -2, :]
X_right = [top_right] + inner_right + [bot_right]
X_right = tf.stack(X_right, axis=1)
X_mid = tf.unstack(X_mid, axis=2)
Xout1 = [X_left] + X_mid + [X_right]
Xout1 = tf.stack(Xout1, axis=2)
#Xout2
top_left = X[:, 0, 1, :]+X[:, 1, 0, :]
inner_left = tf.unstack(X[:, :-2, 1, :]+X[:, 2:, 0, :], axis=1)
bot_left = X[:, -2, 1, :]+X[:, -1, 0, :]
X_left = [top_left] + inner_left + [bot_left]
X_left = tf.stack(X_left, axis=1)
top_mid = X[:, 0, 2:, :]+X[:, 1, :-2, :]
mid_mid = tf.unstack(X[:, :-2, 2:, :]+X[:, 2:, :-2, :], axis=1)
bot_mid = X[:, -2, 2:, :]+X[:, -1, :-2, :]
X_mid = [top_mid] + mid_mid + [bot_mid]
X_mid = tf.stack(X_mid, axis=1)
top_right = X[:, 0, -1, :]+X[:, 1, -2, :]
inner_right = tf.unstack(X[:, :-2, -1, :]+X[:, 2:, -2, :], axis=1)
bot_right = X[:, -2, -1, :]+X[:, -1, -2, :]
X_right = [top_right] + inner_right + [bot_right]
X_right = tf.stack(X_right, axis=1)
X_mid = tf.unstack(X_mid, axis=2)
Xout2 = [X_left] + X_mid + [X_right]
Xout2 = tf.stack(Xout2, axis=2)
return (Xout1 - Xout2) * 0.25
loss += TV_weight * ( generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) +
generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999) )
loss += TV2_weight * ( generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) +
generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) +
2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred), a=1.999999) )
return loss / ( TV_weight + TV2_weight )
def scharr_edges(image, magnitude):
'''
Returns a tensor holding modified Scharr edge maps.
Arguments:
image: Image tensor with shape [batch_size, h, w, d] and type float32.
The image(s) must be 2x2 or larger.
magnitude: Boolean to determine if the edge magnitude or edge direction is returned
Returns:
Tensor holding edge maps for each channel. Returns a tensor with shape
[batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]],
[dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter.
'''
# Define vertical and horizontal Scharr filters.
static_image_shape = image.get_shape()
image_shape = tf.shape(image)
'''
#modified 3x3 Scharr
kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]],
[[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]]
'''
# 5x5 Scharr
kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0], [-1.0, -2.0, -6.0, -2.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 6.0, 2.0, 1.0], [1.0, 2.0, 3.0, 2.0, 1.0]],
[[-1.0, -1.0, 0.0, 1.0, 1.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-3.0, -6.0, 0.0, 6.0, 3.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-1.0, -1.0, 0.0, 1.0, 1.0]]]
num_kernels = len(kernels)
kernels = numpy.transpose(numpy.asarray(kernels), (1, 2, 0))
kernels = numpy.expand_dims(kernels, -2) / numpy.sum(numpy.abs(kernels))
kernels_tf = tf.constant(kernels, dtype=image.dtype)
kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters')
# Use depth-wise convolution to calculate edge maps per channel.
pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
padded = tf.pad(image, pad_sizes, mode='REFLECT')
# Output tensor has shape [batch_size, h, w, d * num_kernels].
strides = [1, 1, 1, 1]
output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID')
# Reshape to [batch_size, h, w, d, num_kernels].
shape = tf.concat([image_shape, [num_kernels]], 0)
output = tf.reshape(output, shape=shape)
output.set_shape(static_image_shape.concatenate([num_kernels]))
if magnitude: # magnitude of edges
output = tf.sqrt(tf.reduce_sum(tf.square(output),axis=-1))
else: # direction of edges
output = tf.atan(tf.squeeze(tf.div(output[:,:,:,:,0], output[:,:,:,:,1])))
return output
def gmsd_loss(y_true,y_pred):
'''
Improved image quality metric over MS-SSIM with easier calc
http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm
https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf
'''
true_edge_mag = scharr_edges(y_true,True)
pred_edge_mag = scharr_edges(y_pred,True)
c = 0.002
upper = 2.0 * tf.multiply(true_edge_mag,pred_edge_mag) + c
lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c
GMS = tf.div(upper,lower)
_mean, _var = tf.nn.moments(GMS, axes=[1,2], keep_dims=True)
GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1) # single metric value per image in tensor [?,1,1]
return K.tile(GMSD,[1,64,64]) # need to expand to [?,height,width] dimensions for Keras ... modify to not be hard-coded
def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)):
'''
Computes the MS-SSIM between img1 and img2.
This function assumes that `img1` and `img2` are image batches, i.e. the last
three dimensions are [height, width, channels].
Note: The true SSIM is only defined on grayscale. This function does not
perform any colorspace transform. (If input is already YUV, then it will
compute YUV SSIM average.)
Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
structural similarity for image quality assessment." Signals, Systems and
Computers, 2004.
Arguments:
img1: First image batch.
img2: Second image batch. Must have the same rank as img1.
max_val: The dynamic range of the images (i.e., the difference between the
maximum and the minimum allowed values).
power_factors: Iterable of weights for each of the scales. The number of
scales used is the length of the list. Index 0 is the unscaled
resolution's weight and each increasing scale corresponds to the image
being downsampled by 2. Defaults to (0.0517, 0.3295, 0.3462, 0.2726), which
match the first four weights from the original paper (0.0448, 0.2856, 0.3001,
0.2363, 0.1333) renormalised to sum to 1.
Returns:
A tensor containing an MS-SSIM value for each image in batch. The values
are in range [0, 1]. Returns a tensor with shape:
broadcast(img1.shape[:-3], img2.shape[:-3]).
'''
def _verify_compatible_image_shapes(img1, img2):
'''
Checks if two image tensors are compatible for applying SSIM or PSNR.
This function checks if two sets of images have ranks at least 3, and if the
last three dimensions match.
Args:
img1: Tensor containing the first image batch.
img2: Tensor containing the second image batch.
Returns:
A tuple containing: the first tensor shape, the second tensor shape, and a
list of control_flow_ops.Assert() ops implementing the checks.
Raises:
ValueError: When static shape check fails.
'''
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].assert_is_compatible_with(shape2[-3:])
if shape1.ndims is not None and shape2.ndims is not None:
for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])):
if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
raise ValueError('Two images are not compatible: %s and %s' % (shape1, shape2))
# Now assign shape tensors.
shape1, shape2 = tf.shape_n([img1, img2])
# TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable.
checks = []
checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),[shape1, shape2], summarize=10))
checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),[shape1, shape2], summarize=10))
return shape1, shape2, checks
def _ssim_per_channel(img1, img2, max_val=1.0):
'''
Computes SSIM index between img1 and img2 per color channel.
This function matches the standard SSIM implementation from:
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
quality assessment: from error visibility to structural similarity. IEEE
transactions on image processing.
Details:
- 9x9 Gaussian filter of width 1.5 is used (reduced from the standard 11x11).
- k1 = 0.01, k2 = 0.03 as in the original paper.
Args:
img1: First image batch.
img2: Second image batch.
max_val: The dynamic range of the images (i.e., the difference between the
maximum and the minimum allowed values).
Returns:
A pair of tensors containing the channel-wise SSIM and contrast-structure
values. The shape is [..., channels].
'''
def _fspecial_gauss(size, sigma):
'''
Function to mimic the 'fspecial' gaussian MATLAB function.
'''
size = tf.convert_to_tensor(size, 'int32')
sigma = tf.convert_to_tensor(sigma)
coords = tf.cast(tf.range(size), sigma.dtype)
coords -= tf.cast(size - 1, sigma.dtype) / 2.0
g = tf.square(coords)
g *= -0.5 / tf.square(sigma)
g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax().
g = tf.nn.softmax(g)
return tf.reshape(g, shape=[size, size, 1, 1])
def _ssim_helper(x, y, max_val, kernel, compensation=1.0):
'''
Helper function for computing SSIM.
SSIM estimates covariances with weighted sums. The default parameters
use a biased estimate of the covariance:
Suppose `reducer` is a weighted sum, then the mean estimators are
\mu_x = \sum_i w_i x_i,
\mu_y = \sum_i w_i y_i,
where w_i's are the weighted-sum weights, and covariance estimator is
cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
with assumption \sum_i w_i = 1. This covariance estimator is biased, since
E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
For SSIM measure with unbiased covariance estimators, pass as `compensation`
argument (1 - \sum_i w_i ^ 2).
Arguments:
x: First set of images.
y: Second set of images.
reducer: Function that computes 'local' averages from set of images.
For non-convolutional version, this is usually tf.reduce_mean(x, [1, 2]),
and for convolutional version, this is usually tf.nn.avg_pool or
tf.nn.conv2d with weighted-sum kernel.
max_val: The dynamic range (i.e., the difference between the maximum
possible allowed value and the minimum allowed value).
compensation: Compensation factor. See above.
Returns:
A pair containing the luminance measure, and the contrast-structure measure.
'''
def reducer(x, kernel):
shape = tf.shape(x)
x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0))
y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
return tf.reshape(y, tf.concat([shape[:-3],tf.shape(y)[1:]], 0))
_SSIM_K1 = 0.01
_SSIM_K2 = 0.03
c1 = (_SSIM_K1 * max_val) ** 2
c2 = (_SSIM_K2 * max_val) ** 2
# SSIM luminance measure is
# (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
mean0 = reducer(x, kernel)
mean1 = reducer(y, kernel)
num0 = mean0 * mean1 * 2.0
den0 = tf.square(mean0) + tf.square(mean1)
luminance = (num0 + c1) / (den0 + c1)
# SSIM contrast-structure measure is
# (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
# Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then
# cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
# = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
num1 = reducer(x * y, kernel) * 2.0
den1 = reducer(tf.square(x) + tf.square(y), kernel)
c2 *= compensation
cs = (num1 - num0 + c2) / (den1 - den0 + c2)
# SSIM score is the product of the luminance and contrast-structure measures.
return luminance, cs
filter_size = tf.constant(9, dtype='int32') # changed from 11 to 9 due
filter_sigma = tf.constant(1.5, dtype=img1.dtype)
shape1, shape2 = tf.shape_n([img1, img2])
checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),[shape1, filter_size], summarize=8),
tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),[shape2, filter_size], summarize=8)]
# Enforce the check to run before computation.
with tf.control_dependencies(checks):
img1 = tf.identity(img1)
# TODO(sjhwang): Try to cache kernels and compensation factor.
kernel = _fspecial_gauss(filter_size, filter_sigma)
kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1])
# The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
# but to match MATLAB implementation of MS-SSIM, we use 1.0 instead.
compensation = 1.0
# TODO(sjhwang): Try FFT.
# TODO(sjhwang): Gaussian kernel is separable in space. Consider applying
# 1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.
luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation)
# Average over the second and the third from the last: height, width.
axes = tf.constant([-3, -2], dtype='int32')
ssim_val = tf.reduce_mean(luminance * cs, axes)
cs = tf.reduce_mean(cs, axes)
return ssim_val, cs
def do_pad(images, remainder):
padding = tf.expand_dims(remainder, -1)
padding = tf.pad(padding, [[1, 0], [1, 0]])
return [tf.pad(x, padding, mode='SYMMETRIC') for x in images]
# Shape checking.
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].merge_with(shape2[-3:])
with tf.name_scope(None, 'MS-SSIM', [img1, img2]):
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
with tf.control_dependencies(checks):
img1 = tf.identity(img1)
# Need to convert the images to float32. Scale max_val accordingly so that
# SSIM is computed correctly.
max_val = tf.cast(max_val, img1.dtype)
max_val = tf.image.convert_image_dtype(max_val, 'float32')
img1 = tf.image.convert_image_dtype(img1, 'float32')
img2 = tf.image.convert_image_dtype(img2, 'float32')
imgs = [img1, img2]
shapes = [shape1, shape2]
# img1 and img2 are assumed to be a (multi-dimensional) batch of
# 3-dimensional images (height, width, channels). `heads` contain the batch
# dimensions, and `tails` contain the image dimensions.
heads = [s[:-3] for s in shapes]
tails = [s[-3:] for s in shapes]
divisor = [1, 2, 2, 1]
divisor_tensor = tf.constant(divisor[1:], dtype='int32')
mcs = []
for k in range(len(power_factors)):
with tf.name_scope(None, 'Scale%d' % k, imgs):
if k > 0:
# Avg pool takes rank 4 tensors. Flatten leading dimensions.
flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0)) for x, t in zip(imgs, tails)]
remainder = tails[0] % divisor_tensor
need_padding = tf.reduce_any(tf.not_equal(remainder, 0))
padded = tf.cond(need_padding,lambda: do_pad(flat_imgs, remainder),
lambda: flat_imgs)
downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor, padding='VALID')
for x in padded]
tails = [x[1:] for x in tf.shape_n(downscaled)]
imgs = [tf.reshape(x, tf.concat([h, t], 0)) for x, h, t in zip(downscaled, heads, tails)]
# Overwrite previous ssim value since we only need the last one.
ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val)
mcs.append(tf.nn.relu(cs))
# Remove the cs score for the last scale. In the MS-SSIM calculation,
# we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p).
mcs.pop() # Remove the cs score for the last scale.
mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)],axis=-1)
# Take weighted geometric mean across the scale axis.
ms_ssim = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors),[-1])
return tf.reduce_mean(ms_ssim, [-1]) # Avg over color channels.
def ms_ssim_loss(y_true,y_pred):
MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred),axis=-1), axis=-1)
return K.tile(MSSSIM,[1,64,64]) # need to expand to [1,height,width] dimensions for Keras ... modify to not be hard-coded
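The final step of `ms_ssim` is a weighted geometric mean of the per-scale scores, with `power_factors` as exponents. The sketch below illustrates that step in plain Python, and also shows how the four default weights can be obtained from the five weights quoted from the original paper. The names `PAPER_WEIGHTS`, `renormalise` and `combine_scales` are hypothetical and exist only for this example.

```python
# Five per-scale weights quoted in the ms_ssim docstring from the original paper.
PAPER_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)


def renormalise(weights, scales):
    """Keep the first `scales` weights and rescale them to sum to 1."""
    kept = weights[:scales]
    total = sum(kept)
    return tuple(round(weight / total, 4) for weight in kept)


def combine_scales(scores, power_factors):
    """Weighted geometric mean: product of score ** weight over all scales."""
    result = 1.0
    for score, power in zip(scores, power_factors):
        # Negative scores are clipped to zero, like tf.nn.relu in ms_ssim
        result *= max(score, 0.0) ** power
    return result


factors = renormalise(PAPER_WEIGHTS, 4)
print(factors)
print(combine_scales([0.9, 0.8, 0.95, 0.99], factors))
```

Because the exponents sum to 1, equal scores at every scale are returned unchanged, and any single scale scoring zero drives the whole product to zero.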

101
lib/model/masks.py Normal file

@ -0,0 +1,101 @@
#!/usr/bin/env python3
""" Masks functions for faceswap.py
Masks from:
dfaker: https://github.com/dfaker/df"""
import logging
import cv2
import numpy as np
from lib.umeyama import umeyama
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def dfaker(landmarks, face, channels=4):
""" Dfaker model mask
Embeds the mask into the face alpha channel
channels: 1, 3 or 4:
1 - Return a single channel mask
3 - Return a 3 channel mask
4 - Return the original image with the mask in the alpha channel
"""
padding = int(face.shape[0] * 0.1875)
coverage = face.shape[0] - (padding * 2)
logger.trace("face_shape: %s, coverage: %s, landmarks: %s", face.shape, coverage, landmarks)
mat = umeyama(landmarks[17:], True)[0:2]
mat = np.array(mat.ravel()).reshape(2, 3)
mat = mat * coverage
mat[:, 2] += padding
points = np.array(landmarks).reshape((-1, 2))
facepoints = np.array(points).reshape((-1, 2))
mask = np.zeros_like(face, dtype=np.uint8)
hull = cv2.convexHull(facepoints.astype(int)) # pylint: disable=no-member
hull = cv2.transform(hull.reshape(1, -1, 2), # pylint: disable=no-member
mat).reshape(-1, 2).astype(int)
cv2.fillConvexPoly(mask, hull, (255, 255, 255)) # pylint: disable=no-member
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) # pylint: disable=no-member
mask = cv2.dilate(mask, # pylint: disable=no-member
kernel,
iterations=1,
borderType=cv2.BORDER_REFLECT) # pylint: disable=no-member
mask = mask[:, :, :1]
return merge_mask(face, mask, channels)
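The `padding` and `coverage` values in `dfaker` determine where the normalised landmark hull lands in pixel space: coordinates in the unit square are scaled by `coverage` and shifted by `padding`. The sketch below works through that arithmetic in plain Python; `mask_geometry` and `to_pixels` are hypothetical helper names, and only the `0.1875` ratio is taken from the function above.

```python
# Hypothetical helpers reproducing the padding / coverage arithmetic of dfaker().
def mask_geometry(face_size, padding_ratio=0.1875):
    """Return (padding, coverage) in pixels for a square face image."""
    padding = int(face_size * padding_ratio)
    coverage = face_size - padding * 2
    return padding, coverage


def to_pixels(unit_coord, face_size):
    """Map a coordinate from the unit square into image pixel space."""
    padding, coverage = mask_geometry(face_size)
    return unit_coord * coverage + padding


for size in (64, 128, 256):
    print(size, mask_geometry(size), to_pixels(0.5, size))
```

For a 64px face this gives 12px of padding on each side and a 40px central region, so the centre of the unit square maps to pixel 32.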
def dfl_full(landmarks, face, channels=4):
""" DFL Face Full Mask
channels: 1, 3 or 4:
1 - Return a single channel mask
3 - Return a 3 channel mask
4 - Return the original image with the mask in the alpha channel
"""
logger.trace("face_shape: %s, landmarks: %s", face.shape, landmarks)
mask = np.zeros(face.shape[0:2] + (1, ), dtype=np.float32)
jaw = cv2.convexHull(np.concatenate(( # pylint: disable=no-member
landmarks[0:17], # jawline
landmarks[48:68], # mouth
[landmarks[0]], # temple
[landmarks[8]], # chin
[landmarks[16]]))) # temple
nose_ridge = cv2.convexHull(np.concatenate(( # pylint: disable=no-member
landmarks[27:31], # nose line
[landmarks[33]]))) # nose point
eyes = cv2.convexHull(np.concatenate(( # pylint: disable=no-member
landmarks[17:27], # eyebrows
[landmarks[0]], # temple
[landmarks[27]], # nose top
[landmarks[16]], # temple
[landmarks[33]]))) # nose point
cv2.fillConvexPoly(mask, jaw, (255, 255, 255)) # pylint: disable=no-member
cv2.fillConvexPoly(mask, nose_ridge, (255, 255, 255)) # pylint: disable=no-member
cv2.fillConvexPoly(mask, eyes, (255, 255, 255)) # pylint: disable=no-member
return merge_mask(face, mask, channels)
def merge_mask(image, mask, channels):
""" Return the mask in requested shape """
logger.trace("image_shape: %s, mask_shape: %s, channels: %s",
image.shape, mask.shape, channels)
assert channels in (1, 3, 4), "Channels should be 1, 3 or 4"
assert mask.ndim == 3 and mask.shape[2] == 1, "Input mask must be 3 dimensions with 1 channel"
if channels == 3:
retval = np.tile(mask, 3)
elif channels == 4:
retval = np.concatenate((image, mask), -1)
else:
retval = mask
logger.trace("Final mask shape: %s", retval.shape)
return retval

279
lib/model/nn_blocks.py Normal file

@ -0,0 +1,279 @@
#!/usr/bin/env python3
""" Neural Network Blocks for faceswap.py
Blocks from:
the original https://www.reddit.com/r/deepfakes/ code sample + contribs
dfaker: https://github.com/dfaker/df
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
import logging
import tensorflow as tf
import keras.backend as K
from keras.layers import (add, Add, BatchNormalization, concatenate, Lambda, regularizers,
Permute, Reshape, SeparableConv2D, Softmax, UpSampling2D)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation
from keras.initializers import he_uniform, Constant
from .initializers import ICNR
from .layers import PixelShuffler, Scale, SubPixelUpscaling, ReflectionPadding2D
from .normalization import GroupNormalization, InstanceNormalization
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class NNBlocks():
""" Blocks to use for creating models """
def __init__(self, use_subpixel=False, use_icnr_init=False, use_reflect_padding=False):
logger.debug("Initializing %s: (use_subpixel: %s, use_icnr_init: %s, use_reflect_padding: %s)",
self.__class__.__name__, use_subpixel, use_icnr_init, use_reflect_padding)
self.use_subpixel = use_subpixel
self.use_icnr_init = use_icnr_init
self.use_reflect_padding = use_reflect_padding
logger.debug("Initialized %s", self.__class__.__name__)
@staticmethod
def update_kwargs(kwargs):
""" Set the default kernel initializer to he_uniform() """
kwargs["kernel_initializer"] = kwargs.get("kernel_initializer", he_uniform())
return kwargs
# <<< Original Model Blocks >>> #
def conv(self, inp, filters, kernel_size=5, strides=2, padding='same', use_instance_norm=False, res_block_follows=False, **kwargs):
""" Convolution Layer"""
logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, use_instance_norm: %s, "
"kwargs: %s", inp, filters, kernel_size, strides, use_instance_norm, kwargs)
kwargs = self.update_kwargs(kwargs)
if self.use_reflect_padding:
inp = ReflectionPadding2D(stride=strides, kernel_size=kernel_size)(inp)
padding = 'valid'
var_x = Conv2D(filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
**kwargs)(inp)
if use_instance_norm:
var_x = InstanceNormalization()(var_x)
if not res_block_follows:
var_x = LeakyReLU(0.1)(var_x)
return var_x
def upscale(self, inp, filters, kernel_size=3, padding= 'same', use_instance_norm=False, res_block_follows=False, **kwargs):
""" Upscale Layer """
logger.debug("inp: %s, filters: %s, kernel_size: %s, use_instance_norm: %s, kwargs: %s",
inp, filters, kernel_size, use_instance_norm, kwargs)
kwargs = self.update_kwargs(kwargs)
if self.use_reflect_padding:
inp = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(inp)
padding = 'valid'
if self.use_icnr_init:
kwargs["kernel_initializer"] = ICNR(initializer=kwargs["kernel_initializer"])
var_x = Conv2D(filters * 4,
kernel_size=kernel_size,
padding=padding,
**kwargs)(inp)
if use_instance_norm:
var_x = InstanceNormalization()(var_x)
if not res_block_follows:
var_x = LeakyReLU(0.1)(var_x)
if self.use_subpixel:
var_x = SubPixelUpscaling()(var_x)
else:
var_x = PixelShuffler()(var_x)
return var_x
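The `upscale` block produces `filters * 4` channels because the shuffle layer that follows trades channels for resolution: every group of four channels becomes a 2x2 block of output pixels. The sketch below shows the idea on nested Python lists. The function name `pixel_shuffle` and the exact channel ordering are assumptions for illustration; the real `PixelShuffler` and `SubPixelUpscaling` layers live in `lib/model/layers.py`, which is not shown here, and may order channels differently.

```python
# Hypothetical list-based pixel shuffle: (H, W, C * block^2) -> (H * block, W * block, C).
def pixel_shuffle(image, block=2):
    """Rearrange channel groups of a height x width x channels list into space."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    out_channels = channels // (block * block)
    output = [[[0.0] * out_channels for _ in range(width * block)]
              for _ in range(height * block)]
    for row in range(height):
        for col in range(width):
            for d_row in range(block):
                for d_col in range(block):
                    for chan in range(out_channels):
                        # Assumed ordering: block offsets vary slowest, channels fastest
                        src = (d_row * block + d_col) * out_channels + chan
                        output[row * block + d_row][col * block + d_col][chan] = \
                            image[row][col][src]
    return output


# One pixel with four channels becomes a 2x2 single-channel image.
print(pixel_shuffle([[[1, 2, 3, 4]]]))
```

No values are created or averaged in this step. The convolution before it is what learns the content of the new pixels.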
# <<< DFaker Model Blocks >>> #
def res_block(self, inp, filters, kernel_size=3, padding= 'same', **kwargs):
""" Residual block """
logger.debug("inp: %s, filters: %s, kernel_size: %s, kwargs: %s",
inp, filters, kernel_size, kwargs)
kwargs = self.update_kwargs(kwargs)
var_x = LeakyReLU(alpha=0.2)(inp)
if self.use_reflect_padding:
var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
padding = 'valid'
var_x = Conv2D(filters,
kernel_size=kernel_size,
padding=padding,
**kwargs)(var_x)
var_x = LeakyReLU(alpha=0.2)(var_x)
if self.use_reflect_padding:
var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
padding = 'valid'
var_x = Conv2D(filters,
kernel_size=kernel_size,
padding=padding,
**kwargs)(var_x)
var_x = Scale(gamma_init=Constant(value=0.1))(var_x)
var_x = Add()([var_x, inp])
var_x = LeakyReLU(alpha=0.2)(var_x)
return var_x
# <<< Unbalanced Model Blocks >>> #
def conv_sep(self, inp, filters, kernel_size=5, strides=2, **kwargs):
""" Separable Convolution Layer """
logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s",
inp, filters, kernel_size, strides, kwargs)
kwargs = self.update_kwargs(kwargs)
var_x = SeparableConv2D(filters,
kernel_size=kernel_size,
strides=strides,
padding='same',
**kwargs)(inp)
var_x = Activation("relu")(var_x)
return var_x
# <<< GAN V2.2 Blocks >>> #
# TODO Merge these into NNBLock class when porting GAN2.2
# GAN Constants:
GAN22_CONV_INIT = "he_normal"
GAN22_REGULARIZER = 1e-4
# Gan Blocks:
def normalization(inp, norm='none', group=16):
""" GAN Normalization """
if norm == 'layernorm':
var_x = GroupNormalization(group=group)(inp)
elif norm == 'batchnorm':
var_x = BatchNormalization()(inp)
elif norm == 'groupnorm':
var_x = GroupNormalization(group=16)(inp)
elif norm == 'instancenorm':
var_x = InstanceNormalization()(inp)
elif norm == 'hybrid':
if group % 2 == 1:
raise ValueError("Output channels must be an even number for hybrid norm, "
"received {}.".format(group))
filt = group
var_x_0 = Lambda(lambda var_x: var_x[..., :filt // 2])(inp)
var_x_1 = Lambda(lambda var_x: var_x[..., filt // 2:])(inp)
var_x_0 = Conv2D(filt // 2,
kernel_size=1,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=GAN22_CONV_INIT)(var_x_0)
var_x_1 = InstanceNormalization()(var_x_1)
var_x = concatenate([var_x_0, var_x_1], axis=-1)
else:
var_x = inp
return var_x
def upscale_ps(inp, filters, initializer, use_norm=False, norm="none"):
""" GAN Upscaler - Pixel Shuffler """
var_x = Conv2D(filters * 4,
kernel_size=3,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=initializer,
padding="same")(inp)
var_x = LeakyReLU(0.2)(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
var_x = PixelShuffler()(var_x)
return var_x
def upscale_nn(inp, filters, use_norm=False, norm="none"):
""" GAN Upscaler - Nearest Neighbor """
var_x = UpSampling2D()(inp)
var_x = reflect_padding_2d(var_x, 1)
var_x = Conv2D(filters,
kernel_size=3,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer="he_normal")(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
return var_x
def reflect_padding_2d(inp, pad=1):
""" GAN Reflect Padding (2D) """
var_x = Lambda(lambda var_x: tf.pad(var_x,
[[0, 0], [pad, pad], [pad, pad], [0, 0]],
mode="REFLECT"))(inp)
return var_x
def conv_gan(inp, filters, use_norm=False, strides=2, norm='none'):
""" GAN Conv Block """
var_x = Conv2D(filters,
kernel_size=3,
strides=strides,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=GAN22_CONV_INIT,
use_bias=False,
padding="same")(inp)
var_x = Activation("relu")(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
return var_x
def conv_d_gan(inp, filters, use_norm=False, norm='none'):
""" GAN Discriminator Conv Block """
var_x = inp
var_x = Conv2D(filters,
kernel_size=4,
strides=2,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=GAN22_CONV_INIT,
use_bias=False,
padding="same")(var_x)
var_x = LeakyReLU(alpha=0.2)(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
return var_x
def res_block_gan(inp, filters, use_norm=False, norm='none'):
""" GAN Res Block """
var_x = Conv2D(filters,
kernel_size=3,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=GAN22_CONV_INIT,
use_bias=False,
padding="same")(inp)
var_x = LeakyReLU(alpha=0.2)(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
var_x = Conv2D(filters,
kernel_size=3,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
kernel_initializer=GAN22_CONV_INIT,
use_bias=False,
padding="same")(var_x)
var_x = add([var_x, inp])
var_x = LeakyReLU(alpha=0.2)(var_x)
var_x = normalization(var_x, norm, filters) if use_norm else var_x
return var_x
def self_attn_block(inp, n_c, squeeze_factor=8):
""" GAN Self Attention Block
Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow
"""
msg = "Input channels must be >= {}, received nc={}".format(squeeze_factor, n_c)
assert n_c // squeeze_factor > 0, msg
var_x = inp
shape_x = var_x.get_shape().as_list()
var_f = Conv2D(n_c // squeeze_factor, 1,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)
var_g = Conv2D(n_c // squeeze_factor, 1,
kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)
var_h = Conv2D(n_c, 1, kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)
shape_f = var_f.get_shape().as_list()
shape_g = var_g.get_shape().as_list()
shape_h = var_h.get_shape().as_list()
flat_f = Reshape((-1, shape_f[-1]))(var_f)
flat_g = Reshape((-1, shape_g[-1]))(var_g)
flat_h = Reshape((-1, shape_h[-1]))(var_h)
var_s = Lambda(lambda var_x: K.batch_dot(var_x[0],
Permute((2, 1))(var_x[1])))([flat_g, flat_f])
beta = Softmax(axis=-1)(var_s)
var_o = Lambda(lambda var_x: K.batch_dot(var_x[0], var_x[1]))([beta, flat_h])
var_o = Reshape(shape_x[1:])(var_o)
var_o = Scale()(var_o)
out = add([var_o, inp])
return out
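
As a rough illustration of why `upscale` asks its convolution for `filters * 4` channels before the shuffle step: a 2x pixel shuffle trades every four channels for one 2x2 block of pixels. The NumPy sketch below is not the `PixelShuffler` layer used in this commit; the function name `pixel_shuffle` and the exact channel ordering are assumptions made for illustration only.

```python
import numpy as np


def pixel_shuffle(feature_map, scale=2):
    """Rearrange a (H, W, C * scale**2) array into (H * scale, W * scale, C).

    Hypothetical helper: one possible channel ordering, channels-last only.
    """
    height, width, channels = feature_map.shape
    if channels % (scale * scale) != 0:
        raise ValueError("channels must be divisible by scale ** 2")
    out_channels = channels // (scale * scale)
    # Split the channel axis into (row offset, column offset, output channel)
    blocks = feature_map.reshape(height, width, scale, scale, out_channels)
    # Move each offset axis next to the spatial axis it expands
    blocks = blocks.transpose(0, 2, 1, 3, 4)
    return blocks.reshape(height * scale, width * scale, out_channels)


if __name__ == "__main__":
    conv_output = np.zeros((16, 16, 64 * 4), dtype="float32")
    print(pixel_shuffle(conv_output).shape)
```

No values are created or discarded by the shuffle; it only moves them from the channel axis to the spatial axes, which is why the preceding convolution has to supply four times the target filter count.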

289
lib/model/normalization.py Normal file

@@ -0,0 +1,289 @@
#!/usr/bin/env python3
""" Normalization methods for faceswap.py
Code from:
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
import sys
import inspect
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
def to_list(inp):
""" Convert to list """
if not isinstance(inp, (list, tuple)):
return [inp]
return list(inp)
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast
Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
self.beta = None
self.gamma = None
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if self.axis is not None:
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class GroupNormalization(Layer):
""" Group Normalization
from: shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
def __init__(self, axis=-1,
gamma_init='one', beta_init='zero',
gamma_regularizer=None, beta_regularizer=None,
epsilon=1e-6,
group=32,
data_format=None,
**kwargs):
self.beta = None
self.gamma = None
super(GroupNormalization, self).__init__(**kwargs)
self.axis = to_list(axis)
self.gamma_init = initializers.get(gamma_init)
self.beta_init = initializers.get(beta_init)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.epsilon = epsilon
self.group = group
self.data_format = K.normalize_data_format(data_format)
self.supports_masking = True
def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
shape = [1 for _ in input_shape]
if self.data_format == 'channels_last':
channel_axis = -1
shape[channel_axis] = input_shape[channel_axis]
elif self.data_format == 'channels_first':
channel_axis = 1
shape[channel_axis] = input_shape[channel_axis]
# for i in self.axis:
# shape[i] = input_shape[i]
self.gamma = self.add_weight(shape=shape,
initializer=self.gamma_init,
regularizer=self.gamma_regularizer,
name='gamma')
self.beta = self.add_weight(shape=shape,
initializer=self.beta_init,
regularizer=self.beta_regularizer,
name='beta')
self.built = True
def call(self, inputs, mask=None):
input_shape = K.int_shape(inputs)
if len(input_shape) != 4 and len(input_shape) != 2:
raise ValueError('Inputs should have rank ' +
str(4) + " or " + str(2) +
'; Received input shape: ' + str(input_shape))
if len(input_shape) == 4:
if self.data_format == 'channels_last':
batch_size, height, width, channels = input_shape
if batch_size is None:
batch_size = -1
if channels < self.group:
raise ValueError('Input channels should be larger than group size' +
'; Received input channels: ' + str(channels) +
'; Group size: ' + str(self.group))
var_x = K.reshape(inputs, (batch_size,
height,
width,
self.group,
channels // self.group))
mean = K.mean(var_x, axis=[1, 2, 4], keepdims=True)
std = K.sqrt(K.var(var_x, axis=[1, 2, 4], keepdims=True) + self.epsilon)
var_x = (var_x - mean) / std
var_x = K.reshape(var_x, (batch_size, height, width, channels))
retval = self.gamma * var_x + self.beta
elif self.data_format == 'channels_first':
batch_size, channels, height, width = input_shape
if batch_size is None:
batch_size = -1
if channels < self.group:
raise ValueError('Input channels should be larger than group size' +
'; Received input channels: ' + str(channels) +
'; Group size: ' + str(self.group))
var_x = K.reshape(inputs, (batch_size,
self.group,
channels // self.group,
height,
width))
mean = K.mean(var_x, axis=[2, 3, 4], keepdims=True)
std = K.sqrt(K.var(var_x, axis=[2, 3, 4], keepdims=True) + self.epsilon)
var_x = (var_x - mean) / std
var_x = K.reshape(var_x, (batch_size, channels, height, width))
retval = self.gamma * var_x + self.beta
elif len(input_shape) == 2:
reduction_axes = list(range(0, len(input_shape)))
del reduction_axes[0]
batch_size, _ = input_shape
if batch_size is None:
batch_size = -1
mean = K.mean(inputs, keepdims=True)
std = K.sqrt(K.var(inputs, keepdims=True) + self.epsilon)
var_x = (inputs - mean) / std
retval = self.gamma * var_x + self.beta
return retval
def get_config(self):
config = {'epsilon': self.epsilon,
'axis': self.axis,
'gamma_init': initializers.serialize(self.gamma_init),
'beta_init': initializers.serialize(self.beta_init),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'group': self.group}
base_config = super(GroupNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# Update normalizations into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and obj.__module__ == __name__:
get_custom_objects().update({name: obj})
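
To make the difference between the two layers above concrete, here is a minimal NumPy sketch of the statistics each one computes for channels-last input. It is not the Keras code from this file: the function names are hypothetical, the learned `gamma` and `beta` weights are left out, and `instance_norm` corresponds to calling `InstanceNormalization` with `axis=-1` (the layer's default `axis=None` would instead normalize over all non-batch axes).

```python
import numpy as np


def instance_norm(batch, epsilon=1e-3):
    """Normalize each channel of each sample over its spatial axes."""
    mean = batch.mean(axis=(1, 2), keepdims=True)
    stddev = batch.std(axis=(1, 2), keepdims=True) + epsilon
    return (batch - mean) / stddev


def group_norm(batch, groups, epsilon=1e-6):
    """Normalize each group of channels of each sample (channels last)."""
    samples, height, width, channels = batch.shape
    if channels % groups != 0:
        # Assumption for this sketch: channels split evenly into groups
        raise ValueError("channels must be divisible by groups")
    grouped = batch.reshape(samples, height, width, groups, channels // groups)
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    normed = (grouped - mean) / np.sqrt(var + epsilon)
    return normed.reshape(samples, height, width, channels)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    faces = rng.normal(3.0, 2.0, size=(2, 4, 4, 8))
    print(instance_norm(faces).shape, group_norm(faces, groups=4).shape)
```

Note the two different uses of epsilon, mirroring the layers above: the instance version adds it to the standard deviation, while the group version adds it to the variance before taking the square root.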


@@ -117,6 +117,8 @@ class FSThread(threading.Thread):
 self._target(*self._args, **self._kwargs)
 except Exception: # pylint: disable=broad-except
 self.err = sys.exc_info()
+logger.debug("Error in thread (%s): %s", self._name,
+self.err[1].with_traceback(self.err[2]))
 finally:
 # Avoid a refcycle if the thread is running a function with
 # an argument that has a member that points to the thread.
@@ -126,8 +128,8 @@ class FSThread(threading.Thread):
 class MultiThread():
 """ Threading for IO heavy ops
 Catches errors in thread and rethrows to parent """
-def __init__(self, target, *args, thread_count=1, **kwargs):
-self._name = target.__name__
+def __init__(self, target, *args, thread_count=1, name=None, **kwargs):
+self._name = name if name else target.__name__
 logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
 self.__class__.__name__, self._name, thread_count)
 logger.trace("args: %s, kwargs: %s", args, kwargs)
@@ -139,6 +141,16 @@ class MultiThread():
 self._kwargs = kwargs
 logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name)
+@property
+def has_error(self):
+""" Return true if a thread has errored, otherwise false """
+return any(thread.err for thread in self._threads)
+@property
+def errors(self):
+""" Return a list of thread errors """
+return [thread.err for thread in self._threads]
 def start(self):
 """ Start a thread with the given method and args """
 logger.debug("Starting thread(s): '%s'", self._name)
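
The pattern in this hunk (store `sys.exc_info()` on the worker, let the parent poll for it) can be sketched with the standard library alone. The class and function names below are hypothetical stand-ins, not the `FSThread` and `MultiThread` classes from this file, and the sketch leaves out logging and re-raising in the parent.

```python
import sys
import threading


class ErrorCapturingThread(threading.Thread):
    """Thread that records an exception instead of letting it vanish."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.err = None  # Filled with sys.exc_info() if the target raises

    def run(self):
        try:
            super().run()  # Calls the target passed to the constructor
        except Exception:  # pylint: disable=broad-except
            self.err = sys.exc_info()


def any_errored(threads):
    """Return True if at least one thread recorded an exception."""
    return any(thread.err for thread in threads)


def _fail():
    raise ValueError("bad frame")


if __name__ == "__main__":
    workers = [ErrorCapturingThread(target=_fail),
               ErrorCapturingThread(target=lambda: None)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    print(any_errored(workers))
```

Polling a flag like this lets a generator in the parent thread stop cleanly, which is how `has_error` is used by the training data loader later in this commit.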


@@ -9,7 +9,7 @@ import multiprocessing as mp
 import sys
 import threading
-from queue import Empty as QueueEmpty # pylint: disable=unused-import; # noqa
+from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa
 from time import sleep
 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -37,7 +37,7 @@ class QueueManager():
 self._log_queue = self.manager.Queue() if "gui" not in sys.argv else mp.Queue()
 logger.debug("Initialized %s", self.__class__.__name__)
-def add_queue(self, name, maxsize=0):
+def add_queue(self, name, maxsize=0, multiprocessing_queue=True):
 """ Add a queue to the manager
 Adds an event "shutdown" to the queue that can be used to indicate
@@ -46,7 +46,12 @@ class QueueManager():
 logger.debug("QueueManager adding: (name: '%s', maxsize: %s)", name, maxsize)
 if name in self.queues.keys():
 raise ValueError("Queue '{}' already exists.".format(name))
-queue = self.manager.Queue(maxsize=maxsize)
+if multiprocessing_queue:
+queue = self.manager.Queue(maxsize=maxsize)
+else:
+queue = Queue(maxsize=maxsize)
 setattr(queue, "shutdown", self.shutdown)
 self.queues[name] = queue
 logger.debug("QueueManager added: (name: '%s')", name)
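
A stripped-down sketch of the idea in this hunk: a registry that hands out named queues, lets the caller choose the queue type, and attaches a shared stop signal to each queue. `QueueRegistry`, the `factory` argument, and the `stop_event` attribute are all hypothetical; the real `QueueManager` attaches the event under the name `shutdown` and chooses between a manager queue and `queue.Queue` with a boolean flag. The sketch uses a different attribute name because `queue.Queue` gained its own `shutdown()` method in Python 3.13.

```python
import threading
from queue import Queue


class QueueRegistry:
    """Keep named queues that all share one stop signal."""

    def __init__(self):
        self.queues = {}
        self.shutdown = threading.Event()

    def add_queue(self, name, maxsize=0, factory=Queue):
        """Create and register a queue built by `factory`."""
        if name in self.queues:
            raise ValueError("Queue '{}' already exists.".format(name))
        queue = factory(maxsize=maxsize)
        # Attach the shared event so consumers can check it from the queue
        queue.stop_event = self.shutdown
        self.queues[name] = queue
        return queue

    def get_queue(self, name):
        """Return a previously registered queue."""
        return self.queues[name]


if __name__ == "__main__":
    registry = QueueRegistry()
    train_a = registry.add_queue("train_a", maxsize=8)
    train_a.put("batch")
    print(registry.get_queue("train_a").get())
```

A plain thread queue passes object references between threads in the same process, so nothing has to be pickled, which is one plausible reason to prefer it for large numpy arrays.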


@@ -1,106 +1,401 @@
-from random import shuffle
-import cv2
-import numpy
-from .multithreading import BackgroundGenerator
-from .umeyama import umeyama
-class TrainingDataGenerator():
-def __init__(self, random_transform_args, coverage, scale=5, zoom=1): #TODO thos default should stay in the warp function
-self.random_transform_args = random_transform_args
-self.coverage = coverage
-self.scale = scale
-self.zoom = zoom
-def minibatchAB(self, images, batchsize, doShuffle=True):
-batch = BackgroundGenerator(self.minibatch(images, batchsize, doShuffle), 1)
-for ep1, warped_img, target_img in batch.iterator():
-yield ep1, warped_img, target_img
-# A generator function that yields epoch, batchsize of warped_img and batchsize of target_img
-def minibatch(self, data, batchsize, doShuffle=True):
-length = len(data)
-assert length >= batchsize, "Number of images is lower than batch-size (Note that too few images may lead to bad training). # images: {}, batch-size: {}".format(length, batchsize)
-epoch = i = 0
-if doShuffle:
-shuffle(data)
-while True:
-size = batchsize
-if i+size > length:
-if doShuffle:
-shuffle(data)
-i = 0
-epoch+=1
-rtn = numpy.float32([self.read_image(img) for img in data[i:i+size]])
-i+=size
-yield epoch, rtn[:,0,:,:,:], rtn[:,1,:,:,:]
-def color_adjust(self, img):
-return img / 255.0
-def read_image(self, fn):
-try:
-image = self.color_adjust(cv2.imread(fn))
-except TypeError:
-raise Exception("Error while reading image", fn)
-image = cv2.resize(image, (256,256))
-image = self.random_transform( image, **self.random_transform_args )
-warped_img, target_img = self.random_warp( image, self.coverage, self.scale, self.zoom )
-return warped_img, target_img
-def random_transform(self, image, rotation_range, zoom_range, shift_range, random_flip):
-h, w = image.shape[0:2]
-rotation = numpy.random.uniform(-rotation_range, rotation_range)
-scale = numpy.random.uniform(1 - zoom_range, 1 + zoom_range)
-tx = numpy.random.uniform(-shift_range, shift_range) * w
-ty = numpy.random.uniform(-shift_range, shift_range) * h
-mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)
-mat[:, 2] += (tx, ty)
-result = cv2.warpAffine(
-image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE)
-if numpy.random.random() < random_flip:
-result = result[:, ::-1]
-return result
-# get pair of random warped images from aligned face image
-def random_warp(self, image, coverage, scale = 5, zoom = 1):
-assert image.shape == (256, 256, 3)
-range_ = numpy.linspace(128 - coverage//2, 128 + coverage//2, 5)
-mapx = numpy.broadcast_to(range_, (5, 5))
-mapy = mapx.T
-mapx = mapx + numpy.random.normal(size=(5,5), scale=scale)
-mapy = mapy + numpy.random.normal(size=(5,5), scale=scale)
-interp_mapx = cv2.resize(mapx, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
-interp_mapy = cv2.resize(mapy, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
-warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
-src_points = numpy.stack([mapx.ravel(), mapy.ravel() ], axis=-1)
-dst_points = numpy.mgrid[0:65*zoom:16*zoom,0:65*zoom:16*zoom].T.reshape(-1,2)
-mat = umeyama(src_points, dst_points, True)[0:2]
-target_image = cv2.warpAffine(image, mat, (64*zoom,64*zoom))
-return warped_image, target_image
-def stack_images(images):
-def get_transpose_axes(n):
-if n % 2 == 0:
-y_axes = list(range(1, n - 1, 2))
-x_axes = list(range(0, n - 1, 2))
-else:
-y_axes = list(range(0, n - 1, 2))
-x_axes = list(range(1, n - 1, 2))
-return y_axes, x_axes, [n - 1]
-images_shape = numpy.array(images.shape)
-new_axes = get_transpose_axes(len(images_shape))
-new_shape = [numpy.prod(images_shape[x]) for x in new_axes]
-return numpy.transpose(
-images,
-axes=numpy.concatenate(new_axes)
-).reshape(new_shape)
+#!/usr/bin/env python3
+""" Process training data for model training """
+import logging
+from hashlib import sha1
+from random import shuffle
+import cv2
+import numpy as np
+from scipy.interpolate import griddata
+from lib.model import masks
+from lib.multithreading import MultiThread
+from lib.queue_manager import queue_manager
+from lib.umeyama import umeyama
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+class TrainingDataGenerator():
+""" Generate training data for models """
+def __init__(self, model_input_size, model_output_size, training_opts):
+logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, "
+"training_opts: %s, landmarks: %s)",
+self.__class__.__name__, model_input_size, model_output_size,
+{key: val for key, val in training_opts.items() if key != "landmarks"},
+bool(training_opts.get("landmarks", None)))
+self.batchsize = 0
+self.model_input_size = model_input_size
+self.training_opts = training_opts
+self.mask_function = self.set_mask_function()
+self.landmarks = self.training_opts.get("landmarks", None)
+self.processing = ImageManipulation(model_input_size,
+model_output_size,
+training_opts.get("coverage_ratio", 0.625))
+logger.debug("Initialized %s", self.__class__.__name__)
+def set_mask_function(self):
+""" Set the mask function to use if using mask """
+mask_type = self.training_opts.get("mask_type", None)
+if mask_type:
+logger.debug("Mask type: '%s'", mask_type)
+mask_func = getattr(masks, mask_type)
+else:
+mask_func = None
+logger.debug("Mask function: %s", mask_func)
+return mask_func
+def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False):
+""" Keep a queue filled to 8x Batch Size """
+logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
+"is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse)
+self.batchsize = batchsize
+q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side)
+q_size = batchsize * 8
+# Don't use a multiprocessing queue because sometimes the MP Manager borks on numpy arrays
+queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False)
+load_thread = MultiThread(self.load_batches,
+images,
+q_name,
+side,
+is_timelapse,
+do_shuffle)
+load_thread.start()
+logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name)
+return self.minibatch(q_name, load_thread)
+def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True):
+""" Load the warped images and target images to queue """
+logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', "
+"is_timelapse: %s, do_shuffle: %s)",
+len(images), q_name, side, is_timelapse, do_shuffle)
+epoch = 0
+queue = queue_manager.get_queue(q_name)
+self.validate_samples(images)
+while True:
+if do_shuffle:
+shuffle(images)
+for img in images:
+logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side)
+queue.put(self.process_face(img, side, is_timelapse))
+epoch += 1
+logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')",
+epoch, q_name, side)
+def validate_samples(self, data):
+""" Check the total number of images against batchsize and return
+the total number of images """
+length = len(data)
+msg = ("Number of images is lower than batch-size (Note that too few "
+"images may lead to bad training). # images: {}, "
+"batch-size: {}".format(length, self.batchsize))
+assert length >= self.batchsize, msg
+def minibatch(self, q_name, load_thread):
+""" A generator function that yields epoch, batchsize of warped_img
+and batchsize of target_img from the load queue """
+logger.debug("Launching minibatch generator for queue: '%s'", q_name)
+queue = queue_manager.get_queue(q_name)
+while True:
+if load_thread.has_error:
+logger.debug("Thread error detected")
+break
+batch = list()
+for _ in range(self.batchsize):
+images = queue.get()
+for idx, image in enumerate(images):
+if len(batch) < idx + 1:
+batch.append(list())
+batch[idx].append(image)
+batch = [np.float32(image) for image in batch]
+logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s'",
+len(batch), [item.shape for item in batch], q_name)
+yield batch
+logger.debug("Finished minibatch generator for queue: '%s'", q_name)
+load_thread.join()
def process_face(self, filename, side, is_timelapse):
""" Load an image and perform transformation and warping """
logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)",
filename, side, is_timelapse)
try:
image = cv2.imread(filename) # pylint: disable=no-member
except TypeError:
raise Exception("Error while reading image", filename)
if self.mask_function or self.training_opts["warp_to_landmarks"]:
src_pts = self.get_landmarks(filename, image, side)
if self.mask_function:
image = self.mask_function(src_pts, image, channels=4)
image = self.processing.color_adjust(image)
if not is_timelapse:
image = self.processing.random_transform(image)
if not self.training_opts["no_flip"]:
image = self.processing.do_random_flip(image)
sample = image.copy()[:, :, :3]
if self.training_opts["warp_to_landmarks"]:
dst_pts = self.get_closest_match(filename, side, src_pts)
processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
else:
processed = self.processing.random_warp(image)
processed.insert(0, sample)
logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
filename, side, [img.shape for img in processed])
return processed
def get_landmarks(self, filename, image, side):
""" Return the landmarks for this face """
logger.trace("Retrieving landmarks: (filename: '%s', side: '%s')", filename, side)
lm_key = sha1(image).hexdigest()
try:
src_points = self.landmarks[side][lm_key]
except KeyError:
raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key,
filename))
logger.trace("Returning: (src_points: %s)", src_points)
return src_points
def get_closest_match(self, filename, side, src_points):
""" Return closest matched landmarks from opposite set """
logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s')",
filename, src_points)
dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"]
dst_points = list(dst_points.values())
closest = (np.mean(np.square(src_points - dst_points),
axis=(1, 2))).argsort()[:10]
closest = np.random.choice(closest)
dst_points = dst_points[closest]
logger.trace("Returning: (dst_points: %s)", dst_points)
return dst_points
class ImageManipulation():
""" Manipulations to be performed on training images """
def __init__(self, input_size, output_size, coverage_ratio):
""" input_size: Size of the face input into the model
output_size: Size of the face that comes out of the model
coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
"""
logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
self.__class__.__name__, input_size, output_size, coverage_ratio)
# Transform args
self.rotation_range = 10 # Range to randomly rotate the image by
self.zoom_range = 0.05 # Range to randomly zoom the image by
self.shift_range = 0.05 # Range to randomly translate the image by
self.random_flip = 0.5 # Chance to flip the image horizontally
# Transform and Warp args
self.input_size = input_size
self.output_size = output_size
# Warp args
self.coverage_ratio = coverage_ratio # Coverage ratio of full image. Eg: 256 * 0.625 = 160
self.scale = 5 # Normal random variable scale
logger.debug("Initialized %s", self.__class__.__name__)
@staticmethod
def color_adjust(img):
""" Color adjust RGB image """
logger.trace("Color adjusting image")
return img.astype('float32') / 255.0
@staticmethod
def separate_mask(image):
""" Return the image and the mask from a 4 channel image """
mask = None
if image.shape[2] == 4:
logger.trace("Image contains mask")
mask = np.expand_dims(image[:, :, -1], axis=2)
image = image[:, :, :3]
else:
logger.trace("Image has no mask")
return image, mask
def get_coverage(self, image):
""" Return coverage value for given image """
coverage = int(image.shape[0] * self.coverage_ratio)
logger.trace("Coverage: %s", coverage)
return coverage
def random_transform(self, image):
""" Randomly transform an image """
logger.trace("Randomly transforming image")
height, width = image.shape[0:2]
rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
tny = np.random.uniform(-self.shift_range, self.shift_range) * height
mat = cv2.getRotationMatrix2D( # pylint: disable=no-member
(width // 2, height // 2), rotation, scale)
mat[:, 2] += (tnx, tny)
result = cv2.warpAffine( # pylint: disable=no-member
image, mat, (width, height),
borderMode=cv2.BORDER_REPLICATE) # pylint: disable=no-member
logger.trace("Randomly transformed image")
return result
def do_random_flip(self, image):
""" Perform flip on image if random number is within threshold """
logger.trace("Randomly flipping image")
if np.random.random() < self.random_flip:
logger.trace("Flip within threshold. Flipping")
retval = image[:, ::-1]
else:
logger.trace("Flip outside threshold. Not Flipping")
retval = image
logger.trace("Randomly flipped image")
return retval
def random_warp(self, image):
""" get pair of random warped images from aligned face image """
logger.trace("Randomly warping image")
height, width = image.shape[0:2]
coverage = self.get_coverage(image)
assert height == width and height % 2 == 0
range_ = np.linspace(height // 2 - coverage // 2,
height // 2 + coverage // 2,
5, dtype='float32')
mapx = np.broadcast_to(range_, (5, 5)).copy()
mapy = mapx.T
# mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast
pad = int(1.25 * self.input_size)
slices = slice(pad // 10, -pad // 10)
dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4))
interp = np.empty((2, self.input_size, self.input_size), dtype='float32')
####
for i, map_ in enumerate([mapx, mapy]):
map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale)
interp[i] = cv2.resize(map_, (pad, pad))[slices, slices] # pylint: disable=no-member
warped_image = cv2.remap( # pylint: disable=no-member
image, interp[0], interp[1], cv2.INTER_LINEAR) # pylint: disable=no-member
logger.trace("Warped image shape: %s", warped_image.shape)
src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
dst_points = np.mgrid[dst_slice, dst_slice]
mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2]
target_image = cv2.warpAffine( # pylint: disable=no-member
image, mat, (self.output_size, self.output_size))
logger.trace("Target image shape: %s", target_image.shape)
warped_image, warped_mask = self.separate_mask(warped_image)
target_image, target_mask = self.separate_mask(target_image)
if target_mask is None:
logger.trace("Randomly warped image")
return [warped_image, target_image]
logger.trace("Target mask shape: %s", target_mask.shape)
logger.trace("Randomly warped image and mask")
return [warped_image, target_image, target_mask]
def random_warp_landmarks(self, image, src_points=None, dst_points=None):
""" get warped image, target image and target mask
From DFAKER plugin """
logger.trace("Randomly warping landmarks")
size = image.shape[0]
coverage = self.get_coverage(image)
p_mx = size - 1
p_hf = (size // 2) - 1
edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0),
(p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]
grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)]
source = src_points
destination = (dst_points.copy().astype('float32') +
np.random.normal(size=dst_points.shape, scale=2.0))
destination = destination.astype('uint8')
face_core = cv2.convexHull(np.concatenate( # pylint: disable=no-member
[source[17:], destination[17:]], axis=0).astype(int))
source = [(pty, ptx) for ptx, pty in source] + edge_anchors
destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors
indicies_to_remove = set()
for fpl in source, destination:
for idx, (pty, ptx) in enumerate(fpl):
if idx > 17:
break
elif cv2.pointPolygonTest(face_core, # pylint: disable=no-member
(pty, ptx),
False) >= 0:
indicies_to_remove.add(idx)
for idx in sorted(indicies_to_remove, reverse=True):
source.pop(idx)
destination.pop(idx)
grid_z = griddata(destination, source, (grid_x, grid_y), method="linear")
map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size)
map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size)
map_x_32 = map_x.astype('float32')
map_y_32 = map_y.astype('float32')
warped_image = cv2.remap(image, # pylint: disable=no-member
map_x_32,
map_y_32,
cv2.INTER_LINEAR, # pylint: disable=no-member
cv2.BORDER_TRANSPARENT) # pylint: disable=no-member
target_image = image
# TODO Make sure this replacement is correct
slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2)
# slices = slice(size // 32, size - size // 32) # 8px on a 256px image
warped_image = cv2.resize( # pylint: disable=no-member
warped_image[slices, slices, :], (self.input_size, self.input_size),
cv2.INTER_AREA) # pylint: disable=no-member
logger.trace("Warped image shape: %s", warped_image.shape)
target_image = cv2.resize( # pylint: disable=no-member
target_image[slices, slices, :], (self.output_size, self.output_size),
cv2.INTER_AREA) # pylint: disable=no-member
logger.trace("Target image shape: %s", target_image.shape)
warped_image, warped_mask = self.separate_mask(warped_image)
target_image, target_mask = self.separate_mask(target_image)
if target_mask is None:
logger.trace("Randomly warped image")
return [warped_image, target_image]
logger.trace("Target mask shape: %s", target_mask.shape)
logger.trace("Randomly warped image and mask")
return [warped_image, target_image, target_mask]
def stack_images(images):
""" Stack images """
logger.debug("Stack images")
def get_transpose_axes(num):
if num % 2 == 0:
logger.debug("Even number of images to stack")
y_axes = list(range(1, num - 1, 2))
x_axes = list(range(0, num - 1, 2))
else:
logger.debug("Odd number of images to stack")
y_axes = list(range(0, num - 1, 2))
x_axes = list(range(1, num - 1, 2))
return y_axes, x_axes, [num - 1]
images_shape = np.array(images.shape)
new_axes = get_transpose_axes(len(images_shape))
new_shape = [np.prod(images_shape[x]) for x in new_axes]
logger.debug("Stacked images")
return np.transpose(
images,
axes=np.concatenate(new_axes)
).reshape(new_shape)
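
Review note: the axis shuffle in `stack_images` is easier to follow on a concrete grid. The sketch below repeats the same transpose-and-reshape using NumPy only. The name `stack_grid` and the tile sizes are made up for illustration and are not part of this commit.

```python
import numpy as np


def stack_grid(images):
    """Tile e.g. a (rows, cols, h, w, c) array into one (rows*h, cols*w, c) image."""
    num = images.ndim
    if num % 2 == 0:
        y_axes = list(range(1, num - 1, 2))
        x_axes = list(range(0, num - 1, 2))
    else:
        y_axes = list(range(0, num - 1, 2))
        x_axes = list(range(1, num - 1, 2))
    new_axes = [y_axes, x_axes, [num - 1]]
    shape = np.array(images.shape)
    # Each output dimension is the product of the input axes folded into it.
    new_shape = [int(np.prod(shape[axes])) for axes in new_axes]
    return np.transpose(images, axes=np.concatenate(new_axes)).reshape(new_shape)


# Hypothetical 2 x 3 grid of 4 x 4 RGB tiles, each filled with its own index.
tiles = np.zeros((2, 3, 4, 4, 3), dtype="float32")
for row in range(2):
    for col in range(3):
        tiles[row, col] = row * 3 + col
grid = stack_grid(tiles)
```

With five dimensions the row axis and tile-height axis are merged into the output height, and the column axis and tile-width axis into the output width, so tile `(row, col)` ends up at pixel offset `(row * 4, col * 4)`.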


@ -12,8 +12,27 @@
import numpy as np
MEAN_FACE_X = np.array([
0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483,
0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127,
0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122,
0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132,
0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127,
0.551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532,
0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364,
0.490127, 0.42689])
def umeyama(src, dst, estimate_scale):
MEAN_FACE_Y = np.array([
0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891,
0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625,
0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758,
0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758,
0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578,
0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192,
0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182,
0.831803, 0.824182])
def umeyama(src, estimate_scale, dst=None):
"""Estimate N-D similarity transformation with or without scaling.
Parameters
----------
@ -33,6 +52,8 @@ def umeyama(src, dst, estimate_scale):
.. [1] "Least-squares estimation of transformation parameters between two
point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573
"""
if dst is None:
dst = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1)
num = src.shape[0]
dim = src.shape[1]
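
Review note: the hunk above only shows the new signature and the mean-face default for `dst`. For readers who want the estimation step itself, here is a minimal sketch of the least-squares similarity fit from the cited Umeyama paper. The function name `similarity_transform` and its argument order are assumptions for illustration, not the signature in `lib/umeyama.py`, and rank-deficient point sets are not handled.

```python
import numpy as np


def similarity_transform(src, dst, estimate_scale=True):
    """Return the homogeneous matrix that best maps src points onto dst points."""
    num, dim = src.shape
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)
    src_c = src - src_mean
    dst_c = dst - dst_mean

    # Cross-covariance between the centred point sets.
    cov = dst_c.T @ src_c / num
    u, s, vt = np.linalg.svd(cov)

    # Flip the last axis if needed so the result is a rotation, not a reflection.
    d = np.ones(dim)
    if np.linalg.det(u) * np.linalg.det(vt) < 0:
        d[-1] = -1.0

    rot = u @ np.diag(d) @ vt
    scale = (s * d).sum() / src_c.var(axis=0).sum() if estimate_scale else 1.0

    mat = np.eye(dim + 1)
    mat[:dim, :dim] = scale * rot
    mat[:dim, dim] = dst_mean - scale * (rot @ src_mean)
    return mat


# Hypothetical check: recover a known rotation, scale and shift.
theta = np.deg2rad(30.0)
known_rot = np.array([[np.cos(theta), -np.sin(theta)],
                      [np.sin(theta), np.cos(theta)]])
src_pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
dst_pts = 2.0 * src_pts @ known_rot.T + np.array([3.0, -1.0])
estimated = similarity_transform(src_pts, dst_pts)
```

Callers in this commit take the first two rows of the result (`[0:2]`) to get the 2x3 matrix that `cv2.warpAffine` expects.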


@ -8,7 +8,6 @@ import warnings
from hashlib import sha1
from pathlib import Path
from re import finditer
from time import time
import cv2
import numpy as np
@ -16,7 +15,6 @@ import numpy as np
import dlib
from lib.faces_detect import DetectedFace
from lib.training_data import TrainingDataGenerator
from lib.logger import get_loglevel
@ -62,7 +60,7 @@ def get_image_paths(directory):
def hash_image_file(filename):
""" Return the filename with it's sha1 hash """
""" Return an image file's sha1 hash """
img = cv2.imread(filename) # pylint: disable=no-member
img_hash = sha1(img).hexdigest()
logger.trace("filename: '%s', hash: %s", filename, img_hash)
@ -107,33 +105,12 @@ def set_system_verbosity(loglevel):
logger.debug("System Verbosity level: %s", loglevel)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = loglevel
if loglevel != '0':
for warncat in (FutureWarning, DeprecationWarning):
for warncat in (FutureWarning, DeprecationWarning, UserWarning):
warnings.simplefilter(action='ignore', category=warncat)
def add_alpha_channel(image, intensity=100):
""" Add an alpha channel to an image
intensity: The opacity of the alpha channel between 0 and 100
100 = transparent,
0 = solid """
logger.trace("Adding alpha channel: intensity: %s", intensity)
assert 0 <= intensity <= 100, "Invalid intensity supplied"
intensity = (255.0 / 100.0) * intensity
d_type = image.dtype
image = image.astype("float32")
ch_b, ch_g, ch_r = cv2.split(image) # pylint: disable=no-member
ch_a = np.ones(ch_b.shape, dtype="float32") * intensity
image_bgra = cv2.merge( # pylint: disable=no-member
(ch_b, ch_g, ch_r, ch_a))
logger.trace("Added alpha channel", intensity)
return image_bgra.astype(d_type)
def rotate_landmarks(face, rotation_matrix):
# pylint: disable=c-extension-no-member
""" Rotate the landmarks and bounding box for faces
found in rotated images.
Pass in a DetectedFace object, Alignments dict or DLib rectangle"""
@ -223,80 +200,6 @@ def camel_case_split(identifier):
return [m.group(0) for m in matches]
class Timelapse:
""" Time lapse function for training """
@classmethod
def create_timelapse(cls, input_dir_a, input_dir_b, output_dir, trainer):
""" Create the time lapse """
if input_dir_a is None and input_dir_b is None and output_dir is None:
return None
if input_dir_a is None or input_dir_b is None:
raise ValueError("To enable the timelapse, you have to supply "
"all the parameters (--timelapse-input-A and "
"--timelapse-input-B).")
if output_dir is None:
output_dir = get_folder(os.path.join(trainer.model.model_dir,
"timelapse"))
return Timelapse(input_dir_a, input_dir_b, output_dir, trainer)
def __init__(self, input_dir_a, input_dir_b, output, trainer):
self.output_dir = output
self.trainer = trainer
if not os.path.isdir(self.output_dir):
logger.error("'%s' does not exist", self.output_dir)
exit(1)
self.files_a = self.read_input_images(input_dir_a)
self.files_b = self.read_input_images(input_dir_b)
btchsz = min(len(self.files_a), len(self.files_b))
self.images_a = self.get_image_data(self.files_a, btchsz)
self.images_b = self.get_image_data(self.files_b, btchsz)
@staticmethod
def read_input_images(input_dir):
""" Get the image paths """
if not os.path.isdir(input_dir):
logger.error("'%s' does not exist", input_dir)
exit(1)
if not os.listdir(input_dir):
logger.error("'%s' contains no images", input_dir)
exit(1)
return get_image_paths(input_dir)
def get_image_data(self, input_images, batch_size):
""" Get training images """
random_transform_args = {
'rotation_range': 0,
'zoom_range': 0,
'shift_range': 0,
'random_flip': 0
}
zoom = 1
if hasattr(self.trainer.model, 'IMAGE_SHAPE'):
zoom = self.trainer.model.IMAGE_SHAPE[0] // 64
generator = TrainingDataGenerator(random_transform_args, 160, zoom)
batch = generator.minibatchAB(input_images, batch_size,
doShuffle=False)
return next(batch)[2]
def work(self):
""" Write out timelapse image """
image = self.trainer.show_sample(self.images_a, self.images_b)
cv2.imwrite(os.path.join(self.output_dir, # pylint: disable=no-member
str(int(time())) + ".png"), image)
def safe_shutdown():
""" Close queues, threads and processes in event of crash """
logger.debug("Safely shutting down")


@ -1,114 +0,0 @@
#!/usr/bin/env python3
""" Adjust converter for faceswap.py
Based on the original https://www.reddit.com/r/deepfakes/ code sample
Adjust code made by https://github.com/yangchen8710 """
import cv2
import numpy as np
from lib.utils import add_alpha_channel
class Convert():
""" Adjust Converter """
def __init__(self, encoder, smooth_mask=True, avg_color_adjust=True,
draw_transparent=False, **kwargs):
self.encoder = encoder
self.use_smooth_mask = smooth_mask
self.use_avg_color_adjust = avg_color_adjust
self.draw_transparent = draw_transparent
def patch_image(self, frame, detected_face, size):
""" Patch swapped face onto original image """
# pylint: disable=no-member
# assert image.shape == (256, 256, 3)
padding = 48
face_size = 256
detected_face.load_aligned(frame, face_size, padding,
align_eyes=False)
src_face = detected_face.aligned_face
crop = slice(padding, face_size - padding)
process_face = src_face[crop, crop]
old_face = process_face.copy()
process_face = cv2.resize(process_face,
(size, size),
interpolation=cv2.INTER_AREA)
process_face = np.expand_dims(process_face, 0)
new_face = self.encoder(process_face / 255.0)[0]
new_face = np.clip(new_face * 255, 0, 255).astype(src_face.dtype)
new_face = cv2.resize(
new_face,
(face_size - padding * 2, face_size - padding * 2),
interpolation=cv2.INTER_CUBIC)
if self.use_avg_color_adjust:
self.adjust_avg_color(old_face, new_face)
if self.use_smooth_mask:
self.smooth_mask(old_face, new_face)
new_face = self.superpose(src_face, new_face, crop)
new_image = frame.copy()
if self.draw_transparent:
new_image, new_face = self.convert_transparent(new_image,
new_face)
cv2.warpAffine(
new_face,
detected_face.adjusted_matrix,
(detected_face.frame_dims[1], detected_face.frame_dims[0]),
new_image,
flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC,
borderMode=cv2.BORDER_TRANSPARENT)
return new_image
@staticmethod
def adjust_avg_color(old_face, new_face):
""" Perform average color adjustment """
for i in range(new_face.shape[-1]):
old_avg = old_face[:, :, i].mean()
new_avg = new_face[:, :, i].mean()
diff_int = (int)(old_avg - new_avg)
for int_h in range(new_face.shape[0]):
for int_w in range(new_face.shape[1]):
temp = (new_face[int_h, int_w, i] + diff_int)
if temp < 0:
new_face[int_h, int_w, i] = 0
elif temp > 255:
new_face[int_h, int_w, i] = 255
else:
new_face[int_h, int_w, i] = temp
@staticmethod
def smooth_mask(old_face, new_face):
""" Smooth the mask """
width, height, _ = new_face.shape
crop = slice(0, width)
mask = np.zeros_like(new_face)
mask[height // 15:-height // 15, width // 15:-width // 15, :] = 255
mask = cv2.GaussianBlur(mask, # pylint: disable=no-member
(15, 15),
10)
new_face[crop, crop] = (mask / 255 * new_face +
(1 - mask / 255) * old_face)
@staticmethod
def superpose(src_face, new_face, crop):
""" Crop Face """
new_image = src_face.copy()
new_image[crop, crop] = new_face
return new_image
@staticmethod
def convert_transparent(image, new_face):
""" Add alpha channels to images and change to
transparent background """
height, width = image.shape[:2]
image = np.zeros((height, width, 4), dtype=np.uint8)
new_face = add_alpha_channel(new_face, 100)
return image, new_face
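
Review note: the removed `adjust_avg_color` walked every pixel in Python. A vectorized equivalent is sketched below for comparison, assuming `uint8` images of identical shape. The name `shift_channel_means` is hypothetical, and unlike the removed method it returns a new array rather than editing `new_face` in place.

```python
import numpy as np


def shift_channel_means(old_face, new_face):
    """Shift each channel of new_face so its mean moves towards old_face's mean."""
    # Per-channel mean difference, truncated towards zero like the removed int() cast.
    diff = (old_face.mean(axis=(0, 1)) - new_face.mean(axis=(0, 1))).astype(int)
    # Work in a wider integer type so the shift cannot wrap around in uint8.
    shifted = new_face.astype(np.int32) + diff
    return np.clip(shifted, 0, 255).astype(new_face.dtype)


# Hypothetical 2 x 2 faces: the old face is flat grey, the new one is too bright on average.
old_demo = np.full((2, 2, 3), 50, dtype=np.uint8)
new_demo = np.zeros((2, 2, 3), dtype=np.uint8)
new_demo[:, 1, :] = 200
adjusted_demo = shift_channel_means(old_demo, new_demo)
```

The new `apply_fixes` in `plugins/convert/masked.py` takes a related but different approach: it averages the difference over the mask only and works in float32.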


@ -1,230 +0,0 @@
#!/usr/bin/env python3
""" Masked converter for faceswap.py
Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955
found on https://www.reddit.com/r/deepfakes/ """
import logging
import cv2
import numpy
from lib.aligner import get_align_mat
from lib.utils import add_alpha_channel
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Convert():
def __init__(self, encoder, trainer,
blur_size=2, seamless_clone=False, mask_type="facehullandrect",
erosion_kernel_size=None, match_histogram=False, sharpen_image=None,
draw_transparent=False, **kwargs):
self.encoder = encoder
self.trainer = trainer
self.erosion_kernel = None
self.erosion_kernel_size = erosion_kernel_size
if erosion_kernel_size is not None:
if erosion_kernel_size > 0:
self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(erosion_kernel_size,
erosion_kernel_size))
elif erosion_kernel_size < 0:
self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(abs(erosion_kernel_size),
abs(erosion_kernel_size)))
self.blur_size = blur_size
self.seamless_clone = seamless_clone
self.sharpen_image = sharpen_image
self.match_histogram = match_histogram
self.mask_type = mask_type.lower() # Choose in 'FaceHullAndRect', 'FaceHull', 'Rect'
self.draw_transparent = draw_transparent
def patch_image(self, image, face_detected, size):
image_size = image.shape[1], image.shape[0]
mat = numpy.array(get_align_mat(face_detected,
size,
should_align_eyes=False)).reshape(2, 3)
if "GAN" not in self.trainer:
mat = mat * size
else:
padding = int(48/256*size)
mat = mat * (size - 2 * padding)
mat[:, 2] += padding
new_face = self.get_new_face(image, mat, size)
image_mask = self.get_image_mask(image,
new_face,
face_detected.landmarks_as_xy,
mat,
image_size)
return self.apply_new_face(image, new_face, image_mask, mat, image_size, size)
@staticmethod
def convert_transparent(image, new_face, image_mask, image_size):
""" Add alpha channels to images and change to
transparent background """
image = numpy.zeros((image_size[1], image_size[0], 4),
dtype=numpy.uint8)
image_mask = add_alpha_channel(image_mask, 100)
new_face = add_alpha_channel(new_face, 100)
return image, new_face, image_mask
def apply_new_face(self, image, new_face, image_mask, mat, image_size, size):
if self.draw_transparent:
image, new_face, image_mask = self.convert_transparent(image,
new_face,
image_mask,
image_size)
self.seamless_clone = False # Alpha channel not supported in seamless
base_image = numpy.copy(image)
new_image = numpy.copy(image)
cv2.warpAffine(new_face,
mat,
image_size,
new_image,
cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC,
cv2.BORDER_TRANSPARENT)
if self.sharpen_image == "bsharpen":
# Sharpening using filter2D
kernel = numpy.ones((3, 3)) * (-1)
kernel[1, 1] = 9
new_image = cv2.filter2D(new_image, -1, kernel)
elif self.sharpen_image == "gsharpen":
# Sharpening using Weighted Method
gaussain_blur = cv2.GaussianBlur(new_image, (0, 0), 3.0)
new_image = cv2.addWeighted(
new_image, 1.5, gaussain_blur, -0.5, 0, new_image)
outimage = None
if self.seamless_clone:
unitMask = numpy.clip(image_mask * 365, 0, 255).astype(numpy.uint8)
logger.info(unitMask.shape)
logger.info(new_image.shape)
logger.info(base_image.shape)
maxregion = numpy.argwhere(unitMask == 255)
if maxregion.size > 0:
miny, minx = maxregion.min(axis=0)[:2]
maxy, maxx = maxregion.max(axis=0)[:2]
lenx = maxx - minx
leny = maxy - miny
masky = int(minx + (lenx // 2))
maskx = int(miny + (leny // 2))
outimage = cv2.seamlessClone(new_image.astype(numpy.uint8),
base_image.astype(numpy.uint8),
unitMask,
(masky, maskx),
cv2.NORMAL_CLONE)
return outimage
foreground = cv2.multiply(image_mask, new_image.astype(float))
background = cv2.multiply(1.0 - image_mask, base_image.astype(float))
outimage = cv2.add(foreground, background)
return outimage
def hist_match(self, source, template, mask=None):
# Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
masked_source = source
masked_template = template
if mask is not None:
masked_source = source * mask
masked_template = template * mask
oldshape = source.shape
source = source.ravel()
template = template.ravel()
masked_source = masked_source.ravel()
masked_template = masked_template.ravel()
s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True,
return_counts=True)
t_values, t_counts = numpy.unique(template, return_counts=True)
ms_values, mbin_idx, ms_counts = numpy.unique(source, return_inverse=True,
return_counts=True)
mt_values, mt_counts = numpy.unique(template, return_counts=True)
s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64)
s_quantiles /= s_quantiles[-1]
t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64)
t_quantiles /= t_quantiles[-1]
interp_t_values = numpy.interp(s_quantiles, t_quantiles, t_values)
return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(self, src_im, tar_im, mask):
matched_R = self.hist_match(src_im[:, :, 0], tar_im[:, :, 0], mask)
matched_G = self.hist_match(src_im[:, :, 1], tar_im[:, :, 1], mask)
matched_B = self.hist_match(src_im[:, :, 2], tar_im[:, :, 2], mask)
matched = numpy.stack((matched_R, matched_G, matched_B), axis=2).astype(src_im.dtype)
return matched
def get_new_face(self, image, mat, size):
face = cv2.warpAffine(image, mat, (size, size))
face = numpy.expand_dims(face, 0)
face_clipped = numpy.clip(face[0], 0, 255).astype(image.dtype)
new_face = None
mask = None
if "GAN" not in self.trainer:
normalized_face = face / 255.0
new_face = self.encoder(normalized_face)[0]
new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype)
else:
normalized_face = face / 255.0 * 2 - 1
fake_output = self.encoder(normalized_face)
if "128" in self.trainer: # TODO: Another hack to switch between 64 and 128
fake_output = fake_output[0]
mask = fake_output[:, :, :, :1]
new_face = fake_output[:, :, :, 1:]
new_face = mask * new_face + (1 - mask) * normalized_face
new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype(image.dtype)
if self.match_histogram:
new_face = self.color_hist_match(new_face, face_clipped, mask)
return new_face
def get_image_mask(self, image, new_face, landmarks, mat, image_size):
face_mask = numpy.zeros(image.shape, dtype=float)
if 'rect' in self.mask_type:
face_src = numpy.ones(new_face.shape, dtype=float)
cv2.warpAffine(face_src,
mat,
image_size,
face_mask,
cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT)
hull_mask = numpy.zeros(image.shape, dtype=float)
if 'hull' in self.mask_type:
hull = cv2.convexHull(
numpy.array(landmarks).reshape((-1, 2)).astype(int)).flatten().reshape((-1, 2))
cv2.fillConvexPoly(hull_mask, hull, (1, 1, 1))
if self.mask_type == 'rect':
image_mask = face_mask
elif self.mask_type == 'facehull':
image_mask = hull_mask
else:
image_mask = ((face_mask*hull_mask))
if self.erosion_kernel is not None:
if self.erosion_kernel_size > 0:
image_mask = cv2.erode(image_mask, self.erosion_kernel, iterations=1)
elif self.erosion_kernel_size < 0:
dilation_kernel = abs(self.erosion_kernel)
image_mask = cv2.dilate(image_mask, dilation_kernel, iterations=1)
if self.blur_size != 0:
image_mask = cv2.blur(image_mask, (self.blur_size, self.blur_size))
return image_mask
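
Review note: when seamless cloning is off, both the removed converter and its replacement fall back to a plain mask-weighted blend. A minimal NumPy sketch of that blend follows. The helper name `blend_with_mask` is an assumption, and the mask is taken to be a float array in the range 0 to 1 with the same shape as the images.

```python
import numpy as np


def blend_with_mask(new_image, base_image, mask):
    """Mix new_image over base_image, weighting each pixel by mask (0 = base, 1 = new)."""
    mask = np.clip(mask.astype("float32"), 0.0, 1.0)
    foreground = mask * new_image.astype("float32")
    background = (1.0 - mask) * base_image.astype("float32")
    blended = np.clip(foreground + background, 0.0, 255.0)
    return np.rint(blended).astype("uint8")


# Hypothetical 1 x 3 strip: fully masked, half masked and unmasked pixels.
new_strip = np.full((1, 3, 3), 200, dtype=np.uint8)
base_strip = np.zeros((1, 3, 3), dtype=np.uint8)
mask_strip = np.zeros((1, 3, 3), dtype="float32")
mask_strip[0, 0] = 1.0
mask_strip[0, 1] = 0.5
blended_strip = blend_with_mask(new_strip, base_strip, mask_strip)
```

Blurring the mask before blending is what softens the seam: pixels near the mask edge get fractional weights, as in the middle pixel here.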

plugins/convert/masked.py (new file, 358 lines)

@ -0,0 +1,358 @@
#!/usr/bin/env python3
""" Masked converter for faceswap.py
Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955
found on https://www.reddit.com/r/deepfakes/ """
import logging
import cv2
import numpy as np
from lib.model.masks import dfl_full
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Convert():
""" Swap a source face with a target """
def __init__(self, encoder, model, arguments):
logger.debug("Initializing %s: (encoder: '%s', model: %s, arguments: %s)",
self.__class__.__name__, encoder, model, arguments)
self.encoder = encoder
self.args = arguments
self.input_size = model.input_shape[0]
self.training_size = model.state.training_size
self.training_coverage_ratio = model.training_opts["coverage_ratio"]
self.input_mask_shape = model.state.mask_shapes[0] if model.state.mask_shapes else None
self.crop = None
self.mask = None
logger.debug("Initialized %s", self.__class__.__name__)
def patch_image(self, image, detected_face):
""" Patch the image """
logger.trace("Patching image")
image = image.astype('float32')
image_size = (image.shape[1], image.shape[0])
coverage = int(self.training_coverage_ratio * self.training_size)
padding = (self.training_size - coverage) // 2
logger.trace("coverage: %s, padding: %s", coverage, padding)
self.crop = slice(padding, self.training_size - padding)
if not self.mask: # Init the mask on first image
self.mask = Mask(self.args.mask_type, self.training_size, padding, self.crop)
detected_face.load_aligned(image, size=self.training_size, align_eyes=False)
new_image = self.get_new_image(image, detected_face, coverage, image_size)
image_mask = self.get_image_mask(detected_face, image_size)
patched_face = self.apply_fixes(image,
new_image,
image_mask,
image_size)
logger.trace("Patched image")
return patched_face
def get_new_image(self, image, detected_face, coverage, image_size):
""" Get the new face from the predictor """
logger.trace("coverage: %s", coverage)
src_face = detected_face.aligned_face
coverage_face = src_face[self.crop, self.crop]
coverage_face = cv2.resize(coverage_face, # pylint: disable=no-member
(self.input_size, self.input_size),
interpolation=cv2.INTER_AREA) # pylint: disable=no-member
coverage_face = np.expand_dims(coverage_face, 0)
np.clip(coverage_face / 255.0, 0.0, 1.0, out=coverage_face)
if self.input_mask_shape:
mask = np.zeros(self.input_mask_shape, np.float32)
mask = np.expand_dims(mask, 0)
feed = [coverage_face, mask]
else:
feed = [coverage_face]
logger.trace("Input shapes: %s", [item.shape for item in feed])
new_face = self.encoder(feed)[0]
new_face = new_face.squeeze()
logger.trace("Output shape: %s", new_face.shape)
new_face = cv2.resize(new_face, # pylint: disable=no-member
(coverage, coverage),
interpolation=cv2.INTER_CUBIC) # pylint: disable=no-member
np.clip(new_face * 255.0, 0.0, 255.0, out=new_face)
src_face[self.crop, self.crop] = new_face
background = image.copy()
interpolator = detected_face.adjusted_interpolators[1]
new_image = cv2.warpAffine( # pylint: disable=no-member
src_face,
detected_face.adjusted_matrix,
image_size,
background,
flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member
borderMode=cv2.BORDER_TRANSPARENT) # pylint: disable=no-member
return new_image
def get_image_mask(self, detected_face, image_size):
""" Get the image mask """
mask = self.mask.get_mask(detected_face, image_size)
if self.args.erosion_size != 0:
kwargs = {'src': mask,
'kernel': self.set_erosion_kernel(mask),
'iterations': 1}
if self.args.erosion_size > 0:
mask = cv2.erode(**kwargs) # pylint: disable=no-member
else:
mask = cv2.dilate(**kwargs) # pylint: disable=no-member
if self.args.blur_size != 0:
blur_size = self.set_blur_size(mask)
mask = cv2.blur(mask, (blur_size, blur_size)) # pylint: disable=no-member
return np.clip(mask, 0.0, 1.0, out=mask)
def set_erosion_kernel(self, mask):
""" Set the erosion kernel """
erosion_ratio = self.args.erosion_size / 100
mask_radius = np.sqrt(np.sum(mask)) / 2
percent_erode = max(1, int(abs(erosion_ratio * mask_radius)))
erosion_kernel = cv2.getStructuringElement( # pylint: disable=no-member
cv2.MORPH_ELLIPSE, # pylint: disable=no-member
(percent_erode, percent_erode))
logger.trace("erosion_kernel shape: %s", erosion_kernel.shape)
return erosion_kernel
def set_blur_size(self, mask):
""" Set the blur size to absolute or percentage """
blur_ratio = self.args.blur_size / 100
mask_radius = np.sqrt(np.sum(mask)) / 2
blur_size = int(max(1, blur_ratio * mask_radius))
logger.trace("blur_size: %s", blur_size)
return blur_size
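
Review note: `set_erosion_kernel` and `set_blur_size` both turn a percentage into pixels by scaling a "mask radius" of `sqrt(sum(mask)) / 2`. The sketch below works that arithmetic through on a synthetic mask. The helper name `percent_to_pixels` is hypothetical and the mask sizes are invented for the example.

```python
import numpy as np


def percent_to_pixels(mask, percent):
    """Convert a percentage of the approximate mask radius into a pixel count (min 1)."""
    ratio = percent / 100
    mask_radius = np.sqrt(np.sum(mask)) / 2
    return int(max(1, abs(ratio) * mask_radius))


# A 100 x 100 square of ones inside a 200 x 200 single-channel mask.
single = np.zeros((200, 200, 1), dtype="float32")
single[50:150, 50:150] = 1.0

# The same square repeated over 3 channels, as the Mask class in this file produces.
triple = np.repeat(single, 3, axis=2)

pixels_single = percent_to_pixels(single, 10)
pixels_triple = percent_to_pixels(triple, 10)
```

Because the sum runs over all channels, a 3-channel mask gives a radius larger by a factor of sqrt(3) than the same mask with one channel, which may be worth keeping in mind when picking percentages.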
def apply_fixes(self, frame, new_image, image_mask, image_size):
""" Apply fixes """
masked = new_image # * image_mask
if self.args.draw_transparent:
alpha = np.full((image_size[1], image_size[0], 1), 255.0, dtype='float32')
new_image = np.concatenate((new_image, alpha), axis=2)
image_mask = np.concatenate((image_mask, alpha), axis=2)
frame = np.concatenate((frame, alpha), axis=2)
if self.args.sharpen_image is not None:
np.clip(masked, 0.0, 255.0, out=masked)
if self.args.sharpen_image == "box_filter":
kernel = np.ones((3, 3)) * (-1)
kernel[1, 1] = 9
masked = cv2.filter2D(masked, -1, kernel) # pylint: disable=no-member
elif self.args.sharpen_image == "gaussian_filter":
blur = cv2.GaussianBlur(masked, (0, 0), 3.0) # pylint: disable=no-member
masked = cv2.addWeighted(masked, # pylint: disable=no-member
1.5,
blur,
-0.5,
0,
masked)
if self.args.avg_color_adjust:
for _ in [0, 1]:
np.clip(masked, 0.0, 255.0, out=masked)
diff = frame - masked
avg_diff = np.sum(diff * image_mask, axis=(0, 1))
adjustment = avg_diff / np.sum(image_mask, axis=(0, 1))
masked = masked + adjustment
if self.args.match_histogram:
np.clip(masked, 0.0, 255.0, out=masked)
masked = self.color_hist_match(masked, frame, image_mask)
if self.args.seamless_clone and not self.args.draw_transparent:
h, w, _ = frame.shape
h = h // 2
w = w // 2
y_indices, x_indices, _ = np.nonzero(image_mask)
y_crop = slice(np.min(y_indices), np.max(y_indices))
x_crop = slice(np.min(x_indices), np.max(x_indices))
y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2) + h)
x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2) + w)
'''
# test with average of centroid rather than the h /2 , w/2 center
y_center = int(np.rint(np.average(y_indices) + h)
x_center = int(np.rint(np.average(x_indices) + w)
'''
insertion = np.rint(masked[y_crop, x_crop, :]).astype('uint8')
insertion_mask = image_mask[y_crop, x_crop, :]
insertion_mask[insertion_mask != 0] = 255
insertion_mask = insertion_mask.astype('uint8')
prior = np.pad(frame, ((h, h), (w, w), (0, 0)), 'constant').astype('uint8')
blended = cv2.seamlessClone(insertion, # pylint: disable=no-member
prior,
insertion_mask,
(x_center, y_center),
cv2.NORMAL_CLONE) # pylint: disable=no-member
blended = blended[h:-h, w:-w, :]
else:
foreground = masked * image_mask
background = frame * (1.0 - image_mask)
blended = foreground + background
np.clip(blended, 0.0, 255.0, out=blended)
return np.rint(blended).astype('uint8')
def color_hist_match(self, new, frame, image_mask):
""" Match the histogram of each color channel of the new face to the frame """
for channel in [0, 1, 2]:
new[:, :, channel] = self.hist_match(new[:, :, channel],
frame[:, :, channel],
image_mask[:, :, channel])
# source = np.stack([self.hist_match(source[:,:,c], target[:,:,c],image_mask[:,:,c])
# for c in [0,1,2]],
# axis=2)
return new
def hist_match(self, new, frame, image_mask):
""" Match the histogram of a single channel within the masked area """
mask_indices = np.nonzero(image_mask)
if len(mask_indices[0]) == 0:
return new
m_new = new[mask_indices].ravel()
m_frame = frame[mask_indices].ravel()
s_values, bin_idx, s_counts = np.unique(m_new, return_inverse=True, return_counts=True)
t_values, t_counts = np.unique(m_frame, return_counts=True)
s_quants = np.cumsum(s_counts, dtype='float32')
t_quants = np.cumsum(t_counts, dtype='float32')
s_quants /= s_quants[-1] # cdf
t_quants /= t_quants[-1] # cdf
interp_s_values = np.interp(s_quants, t_quants, t_values)
# ndarray.put expects flat indices, so assign through the index tuple instead
new[mask_indices] = interp_s_values[bin_idx]
'''
bins = np.arange(256)
template_CDF, _ = np.histogram(m_frame, bins=bins, density=True)
flat_new_image = np.interp(m_source.ravel(), bins[:-1], template_CDF) * 255.0
return flat_new_image.reshape(m_source.shape) * 255.0
'''
return new
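
Review note: the quantile mapping in `hist_match` can be tried in isolation. The sketch below is a standalone NumPy version for one channel, under the assumption that the mask is any array whose non-zero entries mark the pixels to match. The name `match_histogram` is hypothetical, and the function returns a float copy instead of writing into its input.

```python
import numpy as np


def match_histogram(source, template, mask=None):
    """Remap source values so their distribution inside mask follows template's."""
    if mask is None:
        mask = np.ones(source.shape, dtype=bool)
    selected = mask.astype(bool)
    matched = source.astype("float64")  # astype returns a copy
    if not selected.any():
        return matched

    _, bin_idx, s_counts = np.unique(source[selected],
                                     return_inverse=True, return_counts=True)
    t_values, t_counts = np.unique(template[selected], return_counts=True)

    # Empirical CDFs of the masked source and template pixels.
    s_quants = np.cumsum(s_counts).astype("float64")
    t_quants = np.cumsum(t_counts).astype("float64")
    s_quants /= s_quants[-1]
    t_quants /= t_quants[-1]

    # For each source quantile, look up the template value at the same quantile.
    interp_values = np.interp(s_quants, t_quants, t_values)
    matched[selected] = interp_values[bin_idx]
    return matched


# Hypothetical 2 x 2 channel pair with evenly spread values.
source_demo = np.array([[0, 1], [2, 3]])
template_demo = np.array([[10, 20], [30, 40]])
full_match = match_histogram(source_demo, template_demo)
top_row_only = match_histogram(source_demo, template_demo,
                               mask=np.array([[1, 1], [0, 0]]))
```

Pixels outside the mask are left untouched, which matches the intent of restricting the statistics to the swapped face area.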
class Mask():
""" Return the requested mask """
def __init__(self, mask_type, training_size, padding, crop):
""" Set requested mask """
logger.debug("Initializing %s: (mask_type: '%s', training_size: %s, padding: %s)",
self.__class__.__name__, mask_type, training_size, padding)
self.training_size = training_size
self.padding = padding
self.mask_type = mask_type
self.crop = crop
logger.debug("Initialized %s", self.__class__.__name__)
def get_mask(self, detected_face, image_size):
""" Return a face mask """
kwargs = {"matrix": detected_face.adjusted_matrix,
"interpolators": detected_face.adjusted_interpolators,
"landmarks": detected_face.landmarks_as_xy,
"image_size": image_size}
logger.trace("kwargs: %s", kwargs)
mask = getattr(self, self.mask_type)(**kwargs)
mask = self.finalize_mask(mask)
logger.trace("mask shape: %s", mask.shape)
return mask
def cnn(self, **kwargs):
""" CNN Mask """
# Insert FCN-VGG16 segmentation mask model here
logger.info("cnn not yet implemented, using facehull instead")
return self.facehull(**kwargs)
def smoothed(self, **kwargs):
""" Smoothed Mask """
logger.trace("Getting mask")
interpolator = kwargs["interpolators"][1]
ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32')
# area = self.padding + (self.training_size - 2 * self.padding) // 15
# central_core = slice(area, -area)
ones[self.crop, self.crop] = 1.0
ones = cv2.GaussianBlur(ones, (25, 25), 10) # pylint: disable=no-member
mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
cv2.warpAffine(ones, # pylint: disable=no-member
kwargs["matrix"],
kwargs["image_size"],
mask,
flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member
borderMode=cv2.BORDER_CONSTANT, # pylint: disable=no-member
borderValue=0.0)
return mask
def rect(self, **kwargs):
""" Rect Mask """
logger.trace("Getting mask")
interpolator = kwargs["interpolators"][1]
ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32')
mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
# central_core = slice(self.padding, -self.padding)
ones[self.crop, self.crop] = 1.0
cv2.warpAffine(ones, # pylint: disable=no-member
kwargs["matrix"],
kwargs["image_size"],
mask,
flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member
borderMode=cv2.BORDER_CONSTANT, # pylint: disable=no-member
borderValue=0.0)
return mask
def dfl(self, **kwargs):
""" DFaker Mask """
logger.trace("Getting mask")
dummy = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
mask = dfl_full(kwargs["landmarks"], dummy, channels=3)
return mask
def facehull(self, **kwargs):
""" Facehull Mask """
logger.trace("Getting mask")
mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
hull = cv2.convexHull( # pylint: disable=no-member
np.array(kwargs["landmarks"]).reshape((-1, 2)))
cv2.fillConvexPoly(mask, # pylint: disable=no-member
hull,
(1.0, 1.0, 1.0),
lineType=cv2.LINE_AA) # pylint: disable=no-member
return mask
def facehull_rect(self, **kwargs):
""" Facehull Rect Mask """
logger.trace("Getting mask")
mask = self.rect(**kwargs)
hull_mask = self.facehull(**kwargs)
mask *= hull_mask
return mask
def ellipse(self, **kwargs):
""" Ellipse Mask """
logger.trace("Getting mask")
mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
ell = cv2.fitEllipse( # pylint: disable=no-member
np.array(kwargs["landmarks"]).reshape((-1, 2)))
cv2.ellipse(mask, # pylint: disable=no-member
box=ell,
color=(1.0, 1.0, 1.0),
thickness=-1)
return mask
@staticmethod
def finalize_mask(mask):
""" Finalize the mask """
logger.trace("Finalizing mask")
np.nan_to_num(mask, copy=False)
np.clip(mask, 0.0, 1.0, out=mask)
return mask


@ -0,0 +1,48 @@
#!/usr/bin/env python3
""" Default configurations for extract """
import logging
from lib.config import FaceswapConfig
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Config(FaceswapConfig):
""" Config File for Models """
def set_defaults(self):
""" Set the default values for config """
logger.debug("Setting defaults")
# << GLOBAL OPTIONS >> #
# section = "global"
# self.add_section(title=section,
# info="Options that apply to all models")
# << MTCNN DETECTOR OPTIONS >> #
section = "detect.mtcnn"
self.add_section(title=section,
info="MTCNN Detector options")
self.add_item(
section=section, title="minsize", datatype=int, default=20, rounding=10,
min_max=(20, 1000),
info="The minimum size of a face (in pixels) to be accepted as a positive match.\n"
"Lower values use significantly more VRAM and will detect more false positives")
self.add_item(
section=section, title="threshold_1", datatype=float, default=0.6, rounding=2,
min_max=(0.1, 0.9),
info="First stage threshold for face detection. This stage obtains face candidates")
self.add_item(
section=section, title="threshold_2", datatype=float, default=0.7, rounding=2,
min_max=(0.1, 0.9),
info="Second stage threshold for face detection. This stage refines face candidates")
self.add_item(
section=section, title="threshold_3", datatype=float, default=0.7, rounding=2,
min_max=(0.1, 0.9),
info="Third stage threshold for face detection. This stage further refines face "
"candidates")
self.add_item(
section=section, title="scalefactor", datatype=float, default=0.709, rounding=3,
min_max=(0.1, 0.9),
info="The scale factor for the image pyramid")
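
The `min_max` and `rounding` arguments above imply that each option is constrained to a range and snapped to a granularity. A minimal sketch of how such metadata could be applied to a user-supplied value follows. `clamp_and_round` is a hypothetical helper, not part of `lib.config`, and the reading of `rounding` as "step size for ints, decimal places for floats" is an assumption inferred from the defaults above.

```python
def clamp_and_round(value, datatype, min_max, rounding):
    """Hypothetical helper: constrain `value` to `min_max`, then apply `rounding`.

    Assumption (not taken from lib.config): for ints `rounding` is a step
    size, for floats it is the number of decimal places.
    """
    low, high = min_max
    value = datatype(value)
    # Clamp into the permitted range first
    value = max(low, min(high, value))
    if datatype is int:
        # Snap to the nearest multiple of the step, then clamp again in case
        # snapping moved the value outside the range
        value = int(round(value / rounding) * rounding)
        return max(low, min(high, value))
    # Floats: keep the requested number of decimal places
    return round(value, rounding)


# Example with the "minsize" and "scalefactor" settings defined above
minsize = clamp_and_round(47, int, (20, 1000), 10)
scalefactor = clamp_and_round(0.70912, float, (0.1, 0.9), 3)
```

Clamping before snapping keeps out-of-range input from being rounded to a value that is even further outside the allowed interval.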


@ -22,14 +22,21 @@ from math import sqrt
from lib.gpu_stats import GPUStats from lib.gpu_stats import GPUStats
from lib.utils import rotate_landmarks from lib.utils import rotate_landmarks
from plugins.extract._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_config(plugin_name):
""" Return the config for the requested model """
return Config(plugin_name).config_dict
class Detector(): class Detector():
""" Detector object """ """ Detector object """
def __init__(self, loglevel, rotation=None): def __init__(self, loglevel, rotation=None):
logger.debug("Initializing %s: (rotation: %s)", self.__class__.__name__, rotation) logger.debug("Initializing %s: (rotation: %s)", self.__class__.__name__, rotation)
self.config = get_config(".".join(self.__module__.split(".")[-2:]))
self.loglevel = loglevel self.loglevel = loglevel
self.cachepath = os.path.join(os.path.dirname(__file__), ".cache") self.cachepath = os.path.join(os.path.dirname(__file__), ".cache")
self.rotation = self.get_rotation_angles(rotation) self.rotation = self.get_rotation_angles(rotation)
@ -107,6 +114,7 @@ class Detector():
logger.exception("Traceback:") logger.exception("Traceback:")
tb_buffer = StringIO() tb_buffer = StringIO()
traceback.print_exc(file=tb_buffer) traceback.print_exc(file=tb_buffer)
logger.trace(tb_buffer.getvalue())
exception = {"exception": (os.getpid(), tb_buffer)} exception = {"exception": (os.getpid(), tb_buffer)}
self.queues["out"].put(exception) self.queues["out"].put(exception)
exit(1) exit(1)
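
The `Detector.__init__` change above derives the config section name from the plugin's module path by keeping its last two dotted components. A small sketch of that expression in isolation follows. The module path used in the example is an assumption for illustration, and `plugin_section` is a hypothetical name.

```python
def plugin_section(module_name):
    """Return the last two dotted components of a module path.

    Mirrors the expression ".".join(self.__module__.split(".")[-2:])
    used when the detector looks up its config.
    """
    return ".".join(module_name.split(".")[-2:])


# Assumed module path, chosen to match the "detect.mtcnn" config section
section = plugin_section("plugins.extract.detect.mtcnn")
print(section)  # detect.mtcnn
```

Because the slice takes at most two components, a module path with a single component is returned unchanged.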


@ -111,7 +111,7 @@ class Detect(Detector):
def detect_batch(self, detect_images, disable_message=False): def detect_batch(self, detect_images, disable_message=False):
""" Pass the batch through detector for consistently sized images """ Pass the batch through detector for consistently sized images
or each image seperately for inconsitently sized images """ or each image separately for inconsitently sized images """
logger.trace("Detecting Batch") logger.trace("Detecting Batch")
can_batch = self.check_batch_dims(detect_images) can_batch = self.check_batch_dims(detect_images)
if can_batch: if can_batch:


@ -30,21 +30,24 @@ class Detect(Detector):
""" MTCNN detector for face recognition """ """ MTCNN detector for face recognition """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.kwargs = None self.kwargs = self.validate_kwargs()
self.name = "mtcnn" self.name = "mtcnn"
self.target = 2073600 # Uses approx 1.30 GB of VRAM self.target = 2073600 # Uses approx 1.30 GB of VRAM
self.vram = 1408 self.vram = 1408
@staticmethod def validate_kwargs(self):
def validate_kwargs(kwargs): """ Validate that config options are correct. If not reset to default """
""" Validate that cli kwargs are correct. If not reset to default """
valid = True valid = True
if kwargs['minsize'] < 10: threshold = [self.config["threshold_1"],
self.config["threshold_2"],
self.config["threshold_3"]]
kwargs = {"minsize": self.config["minsize"],
"threshold": threshold,
"factor": self.config["scalefactor"]}
if kwargs["minsize"] < 10:
valid = False valid = False
elif len(kwargs['threshold']) != 3: elif not all(0.0 < threshold <= 1.0 for threshold in kwargs['threshold']):
valid = False
elif not all(0.0 < threshold < 1.0
for threshold in kwargs['threshold']):
valid = False valid = False
elif not 0.0 < kwargs['factor'] < 1.0: elif not 0.0 < kwargs['factor'] < 1.0:
valid = False valid = False
@ -53,7 +56,7 @@ class Detect(Detector):
kwargs = {"minsize": 20, # minimum size of face kwargs = {"minsize": 20, # minimum size of face
"threshold": [0.6, 0.7, 0.7], # three steps threshold "threshold": [0.6, 0.7, 0.7], # three steps threshold
"factor": 0.709} # scale factor "factor": 0.709} # scale factor
logger.warning("Invalid MTCNN arguments received. Running with defaults") logger.warning("Invalid MTCNN options in config. Running with defaults")
logger.debug("Using mtcnn kwargs: %s", kwargs) logger.debug("Using mtcnn kwargs: %s", kwargs)
return kwargs return kwargs
@ -72,7 +75,6 @@ class Detect(Detector):
super().initialize(*args, **kwargs) super().initialize(*args, **kwargs)
logger.info("Initializing MTCNN Detector...") logger.info("Initializing MTCNN Detector...")
is_gpu = False is_gpu = False
self.kwargs = kwargs["mtcnn_kwargs"]
# Must import tensorflow inside the spawned process # Must import tensorflow inside the spawned process
# for Windows machines # for Windows machines
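
The new `validate_kwargs` builds the MTCNN arguments from the config and falls back to defaults when any check fails. The sketch below restates that logic as a standalone function so the checks can be exercised without a detector instance. `validate_mtcnn_options` and `default_mtcnn_kwargs` are hypothetical names, and the plain `dict` stands in for `self.config`.

```python
def default_mtcnn_kwargs():
    """Return a fresh copy of the fallback values shown in the diff."""
    return {"minsize": 20,                 # minimum size of face
            "threshold": [0.6, 0.7, 0.7],  # three steps threshold
            "factor": 0.709}               # scale factor


def validate_mtcnn_options(config):
    """Build MTCNN kwargs from a config dict, falling back to defaults.

    Applies the same three checks as the new side of the diff.
    """
    kwargs = {"minsize": config["minsize"],
              "threshold": [config["threshold_1"],
                            config["threshold_2"],
                            config["threshold_3"]],
              "factor": config["scalefactor"]}
    valid = (kwargs["minsize"] >= 10
             and all(0.0 < thr <= 1.0 for thr in kwargs["threshold"])
             and 0.0 < kwargs["factor"] < 1.0)
    return kwargs if valid else default_mtcnn_kwargs()
```

Returning a freshly built defaults dict avoids sharing the mutable `threshold` list between callers.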


@ -1,187 +0,0 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo (master/temp/faceswap_GAN_keras.ipynb)
import logging
from keras.models import Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
from keras.optimizers import Adam
from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
from lib.utils import backup_file
from keras.utils import multi_gpu_model
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
hdf = {'netGAH5': 'netGA_GAN.h5',
'netGBH5': 'netGB_GAN.h5',
'netDAH5': 'netDA_GAN.h5',
'netDBH5': 'netDB_GAN.h5'}
def __conv_init(a):
logger.info("conv_init %s", a)
k = RandomNormal(0, 0.02)(a) # for convolution kernel
k.conv_weight = True
return k
#def batchnorm():
# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)
def inst_norm():
return InstanceNormalization()
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization
class GANModel():
img_size = 64
channels = 3
img_shape = (img_size, img_size, channels)
encoded_dim = 1024
nc_in = 3 # number of input channels of generators
nc_D_inp = 6 # number of input channels of discriminators
def __init__(self, model_dir, gpus):
self.model_dir = model_dir
self.gpus = gpus
optimizer = Adam(1e-4, 0.5)
# Build and compile the discriminator
self.netDA, self.netDB = self.build_discriminator()
# Build and compile the generator
self.netGA, self.netGB = self.build_generator()
def converter(self, swap):
predictor = self.netGB if not swap else self.netGA
return lambda img: predictor.predict(img)
def build_generator(self):
def conv_block(input_tensor, f):
x = input_tensor
x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = Activation("relu")(x)
return x
def res_block(input_tensor, f):
x = input_tensor
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = add([x, input_tensor])
x = LeakyReLU(alpha=0.2)(x)
return x
def upscale_ps(filters, use_instance_norm=True):
def block(x):
x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def Encoder(nc_in=3, input_size=64):
inp = Input(shape=(input_size, input_size, nc_in))
x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
x = conv_block(x,128)
x = conv_block(x,256)
x = conv_block(x,512)
x = conv_block(x,1024)
x = Dense(1024)(Flatten()(x))
x = Dense(4*4*1024)(x)
x = Reshape((4, 4, 1024))(x)
out = upscale_ps(512)(x)
return Model(inputs=inp, outputs=out)
def Decoder_ps(nc_in=512, input_size=8):
input_ = Input(shape=(input_size, input_size, nc_in))
x = input_
x = upscale_ps(256)(x)
x = upscale_ps(128)(x)
x = upscale_ps(64)(x)
x = res_block(x, 64)
x = res_block(x, 64)
#x = Conv2D(4, kernel_size=5, padding='same')(x)
alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
out = concatenate([alpha, rgb])
return Model(input_, out )
encoder = Encoder()
decoder_A = Decoder_ps()
decoder_B = Decoder_ps()
x = Input(shape=self.img_shape)
netGA = Model(x, decoder_A(encoder(x)))
netGB = Model(x, decoder_B(encoder(x)))
self.netGA_sm = netGA
self.netGB_sm = netGB
try:
netGA.load_weights(str(self.model_dir / hdf['netGAH5']))
netGB.load_weights(str(self.model_dir / hdf['netGBH5']))
logger.info("Generator models loaded.")
except:
logger.info("Generator weights files not found.")
pass
if self.gpus > 1:
netGA = multi_gpu_model( self.netGA_sm , self.gpus)
netGB = multi_gpu_model( self.netGB_sm , self.gpus)
return netGA, netGB
def build_discriminator(self):
def conv_block_d(input_tensor, f, use_instance_norm=True):
x = input_tensor
x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
return x
def Discriminator(nc_in, input_size=64):
inp = Input(shape=(input_size, input_size, nc_in))
#x = GaussianNoise(0.05)(inp)
x = conv_block_d(inp, 64, False)
x = conv_block_d(x, 128, False)
x = conv_block_d(x, 256, False)
out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
return Model(inputs=[inp], outputs=out)
netDA = Discriminator(self.nc_D_inp)
netDB = Discriminator(self.nc_D_inp)
try:
netDA.load_weights(str(self.model_dir / hdf['netDAH5']))
netDB.load_weights(str(self.model_dir / hdf['netDBH5']))
logger.info("Discriminator models loaded.")
except:
logger.info("Discriminator weights files not found.")
pass
return netDA, netDB
def load(self, swapped):
if swapped:
logger.warning("swapping not supported on GAN")
# TODO load is done in __init__ => look how to swap if possible
return True
def save_weights(self):
model_dir = str(self.model_dir)
for model in hdf.values():
backup_file(model_dir, model)
if self.gpus > 1:
self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5']))
self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5']))
else:
self.netGA.save_weights(str(self.model_dir / hdf['netGAH5']))
self.netGB.save_weights(str(self.model_dir / hdf['netGBH5']))
self.netDA.save_weights(str(self.model_dir / hdf['netDAH5']))
self.netDB.save_weights(str(self.model_dir / hdf['netDBH5']))
logger.info("Models saved.")


@ -1,260 +0,0 @@
import time
import cv2
import numpy as np
from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K
from lib.training_data import TrainingDataGenerator, stack_images
class GANTrainingDataGenerator(TrainingDataGenerator):
def __init__(self, random_transform_args, coverage, scale, zoom):
super().__init__(random_transform_args, coverage, scale, zoom)
def color_adjust(self, img):
return img / 255.0 * 2 - 1
class Trainer():
random_transform_args = {
'rotation_range': 20,
'zoom_range': 0.1,
'shift_range': 0.05,
'random_flip': 0.5,
}
def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
K.set_learning_phase(1)
assert batch_size % 2 == 0, "batch_size must be an even number"
self.batch_size = batch_size
self.model = model
self.use_lsgan = True
self.use_mixup = True
self.mixup_alpha = 0.2
self.use_perceptual_loss = perceptual_loss
self.use_instancenorm = False
self.lrD = 1e-4 # Discriminator learning rate
self.lrG = 1e-4 # Generator learning rate
generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 1)
self.train_batchA = generator.minibatchAB(fn_A, batch_size)
self.train_batchB = generator.minibatchAB(fn_B, batch_size)
self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.setup()
def setup(self):
distorted_A, fake_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
distorted_B, fake_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
real_A = Input(shape=self.model.img_shape)
real_B = Input(shape=self.model.img_shape)
if self.use_lsgan:
self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
else:
self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))
# ========== Define Perceptual Loss Model==========
if self.use_perceptual_loss:
from keras.models import Model
from keras_vggface.vggface import VGGFace
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
vggface.trainable = False
out_size55 = vggface.layers[36].output
out_size28 = vggface.layers[78].output
out_size7 = vggface.layers[-2].output
vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
vggface_feat.trainable = False
else:
vggface_feat = None
#TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, distorted_A, vggface_feat)
loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, distorted_B, vggface_feat)
loss_GA += 1e-3 * K.mean(K.abs(mask_A))
loss_GB += 1e-3 * K.mean(K.abs(mask_B))
w_fo = 0.01
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))
weightsDA = self.model.netDA.trainable_weights
weightsGA = self.model.netGA.trainable_weights
weightsDB = self.model.netDB.trainable_weights
weightsGB = self.model.netGB.trainable_weights
# Adam(..).get_updates(...)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA)
self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA)
self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB)
self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB)
self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)
def first_order(self, x, axis=1):
img_nrows = x.shape[1]
img_ncols = x.shape[2]
if axis == 1:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
elif axis == 2:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
else:
return None
def train_one_step(self, iter, viewer):
# ---------------------
# Train Discriminators
# ---------------------
# Select a random half batch of images
epoch, warped_A, target_A = next(self.train_batchA)
epoch, warped_B, target_B = next(self.train_batchB)
# Train discriminators for one batch
errDA = self.netDA_train([warped_A, target_A])
errDB = self.netDB_train([warped_B, target_B])
# Train generators for one batch
errGA = self.netGA_train([warped_A, target_A])
errGB = self.netGB_train([warped_B, target_B])
# For calculating average losses
self.errDA_sum += errDA[0]
self.errDB_sum += errDB[0]
self.errGA_sum += errGA[0]
self.errGB_sum += errGB[0]
self.avg_counter += 1
print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
% (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter),
end='\r')
if viewer is not None:
self.show_sample(viewer)
def cycle_variables(self, netG):
distorted_input = netG.inputs[0]
fake_output = netG.outputs[0]
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output)
rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output)
masked_fake_output = alpha * rgb + (1-alpha) * distorted_input
fn_generate = K.function([distorted_input], [masked_fake_output])
fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
fn_bgr = K.function([distorted_input], [rgb])
return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr
def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None):
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb)
fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb)
fake = alpha * fake_rgb + (1-alpha) * distorted
if self.use_mixup:
dist = Beta(self.mixup_alpha, self.mixup_alpha)
lam = dist.sample()
# ==========
mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
# ==========
output_mixup = netD(mixup)
loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
output_fake = netD(concatenate([fake, distorted])) # dummy
loss_G = .5 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
else:
output_real = netD(concatenate([real, distorted])) # positive sample
output_fake = netD(concatenate([fake, distorted])) # negative sample
loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
loss_D = loss_D_real + loss_D_fake
loss_G = .5 * self.loss_fn(output_fake, K.ones_like(output_fake))
# ==========
loss_G += K.mean(K.abs(fake_rgb - real))
# ==========
# Edge loss (similar with total variation loss)
loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=1) - self.first_order(real, axis=1)))
loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=2) - self.first_order(real, axis=2)))
# Perceptual Loss
if not vggface_feat is None:
def preprocess_vggface(x):
x = (x + 1)/2 * 255 # channel order: BGR
#x[..., 0] -= 93.5940
#x[..., 1] -= 104.7624
#x[..., 2] -= 129.
x -= [91.4953, 103.8827, 131.0912]
return x
pl_params = (0.011, 0.11, 0.1919)
real_sz224 = tf.image.resize_images(real, [224, 224])
real_sz224 = Lambda(preprocess_vggface)(real_sz224)
# ==========
fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
# ==========
real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
return loss_D, loss_G
def show_sample(self, display_fn):
_, wA, tA = next(self.train_batchA)
_, wB, tB = next(self.train_batchB)
display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked")
display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw")
display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
# Reset the averages
self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.avg_counter = 0
def showG(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
], axis=1 )
figure_B = np.stack([
test_B,
np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
], axis=1 )
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure
def showG_mask(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
(np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
(np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
], axis=1 )
figure_B = np.stack([
test_B,
(np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
(np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
], axis=1 )
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure
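
The removed trainer relies on three small pieces of arithmetic: `color_adjust` maps pixels from 0-255 to [-1, 1], `showG` maps them back for display, and `cycle_variables` blends the generated RGB with the distorted input using the predicted alpha. A scalar sketch of those formulas follows. The function names are hypothetical, and plain floats stand in for the arrays and tensors used in the original.

```python
def to_model_range(pixel):
    """Map a 0-255 pixel value to [-1, 1] (same formula as color_adjust)."""
    return pixel / 255.0 * 2 - 1


def to_display_range(value):
    """Map a [-1, 1] value back to 0-255, clipping and truncating to int."""
    return int(min(255, max(0, (value + 1) * 255 / 2)))


def blend(alpha, generated, distorted):
    """Alpha-composite the generated value over the distorted input."""
    return alpha * generated + (1 - alpha) * distorted


# alpha = 1 keeps the generated pixel, alpha = 0 keeps the input pixel
fully_generated = blend(1.0, 0.5, -0.5)
fully_input = blend(0.0, 0.5, -0.5)
```

The clip in `to_display_range` matters because the `tanh` output is bounded but the blended result is only guaranteed to stay in range when both inputs are.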


@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'
from .Model import GANModel as Model
from .Trainer import Trainer


@ -1,145 +0,0 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
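
With `axis=None`, the `call` method above normalizes every value of one instance by that instance's mean and standard deviation, adding `epsilon` to the standard deviation rather than to the variance. A plain Python sketch of that computation for a single flattened instance follows. `instance_normalize` is a hypothetical name, and the use of the population standard deviation is an assumption about how the backend's `K.std` behaves.

```python
import statistics


def instance_normalize(values, epsilon=1e-3):
    """Normalize one flattened instance to roughly zero mean and unit std.

    Assumes the population standard deviation; epsilon is added to the
    standard deviation, as in the layer's call method.
    """
    mean = statistics.mean(values)
    stddev = statistics.pstdev(values) + epsilon
    return [(value - mean) / stddev for value in values]


normed = instance_normalize([1.0, 2.0, 3.0, 4.0])
```

The learned `gamma` and `beta` are left out here; in the layer they rescale and shift the normalized values afterwards.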


@ -1,204 +0,0 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb
import logging
from keras.models import Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
from keras.optimizers import Adam
from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
from lib.utils import backup_file
from keras.utils import multi_gpu_model
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
hdf = {'netGAH5':'netGA_GAN128.h5',
'netGBH5': 'netGB_GAN128.h5',
'netDAH5': 'netDA_GAN128.h5',
'netDBH5': 'netDB_GAN128.h5'}
def __conv_init(a):
logger.info("conv_init %s", a)
k = RandomNormal(0, 0.02)(a) # for convolution kernel
k.conv_weight = True
return k
#def batchnorm():
# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)
def inst_norm():
return InstanceNormalization()
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization
class GANModel():
img_size = 128
channels = 3
img_shape = (img_size, img_size, channels)
encoded_dim = 1024
nc_in = 3 # number of input channels of generators
nc_D_inp = 6 # number of input channels of discriminators
def __init__(self, model_dir, gpus):
self.model_dir = model_dir
self.gpus = gpus
optimizer = Adam(1e-4, 0.5)
# Build and compile the discriminator
self.netDA, self.netDB = self.build_discriminator()
# Build and compile the generator
self.netGA, self.netGB = self.build_generator()
def converter(self, swap):
predictor = self.netGB if not swap else self.netGA
return lambda img: predictor.predict(img)
    def build_generator(self):
        def conv_block(input_tensor, f, use_instance_norm=True):
            x = input_tensor
            x = SeparableConv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = Activation("relu")(x)
            return x

        def res_block(input_tensor, f, dilation=1):
            x = input_tensor
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
            x = add([x, input_tensor])
            #x = LeakyReLU(alpha=0.2)(x)
            return x

        def upscale_ps(filters, use_instance_norm=True):
            def block(x, use_instance_norm=use_instance_norm):
                x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
                if use_instance_norm:
                    x = inst_norm()(x)
                x = LeakyReLU(0.1)(x)
                x = PixelShuffler()(x)
                return x
            return block

        def Encoder(nc_in=3, input_size=128):
            inp = Input(shape=(input_size, input_size, nc_in))
            x = Conv2D(32, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
            x = conv_block(x, 64, use_instance_norm=False)
            x = conv_block(x, 128)
            x = conv_block(x, 256)
            x = conv_block(x, 512)
            x = conv_block(x, 1024)
            x = Dense(1024)(Flatten()(x))
            x = Dense(4*4*1024)(x)
            x = Reshape((4, 4, 1024))(x)
            out = upscale_ps(512)(x)
            return Model(inputs=inp, outputs=out)

        def Decoder_ps(nc_in=512, input_size=8):
            input_ = Input(shape=(input_size, input_size, nc_in))
            x = input_
            x = upscale_ps(256)(x)
            x = upscale_ps(128)(x)
            x = upscale_ps(64)(x)
            x = res_block(x, 64, dilation=2)
            # Auxiliary 64x64 RGB output
            out64 = Conv2D(64, kernel_size=3, padding='same')(x)
            out64 = LeakyReLU(alpha=0.1)(out64)
            out64 = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(out64)
            x = upscale_ps(32)(x)
            x = res_block(x, 32)
            x = res_block(x, 32)
            # Full-size output: one alpha (mask) channel followed by three colour channels
            alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
            rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
            out = concatenate([alpha, rgb])
            return Model(input_, [out, out64])

        encoder = Encoder()
        decoder_A = Decoder_ps()
        decoder_B = Decoder_ps()
        x = Input(shape=self.img_shape)
        netGA = Model(x, decoder_A(encoder(x)))
        netGB = Model(x, decoder_B(encoder(x)))
        # Workaround until https://github.com/keras-team/keras/issues/8962 is fixed.
        netGA.output_names = ["netGA_out_1", "netGA_out_2"]
        netGB.output_names = ["netGB_out_1", "netGB_out_2"]
        self.netGA_sm = netGA
        self.netGB_sm = netGB
        try:
            netGA.load_weights(str(self.model_dir / hdf['netGAH5']))
            netGB.load_weights(str(self.model_dir / hdf['netGBH5']))
            logger.info("Generator models loaded.")
        except Exception:  # any load failure is treated as "no saved weights"
            logger.info("Generator weights files not found.")
        if self.gpus > 1:
            netGA = multi_gpu_model(self.netGA_sm, self.gpus)
            netGB = multi_gpu_model(self.netGB_sm, self.gpus)
        return netGA, netGB

    def build_discriminator(self):
        def conv_block_d(input_tensor, f, use_instance_norm=True):
            x = input_tensor
            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def Discriminator(nc_in, input_size=128):
            inp = Input(shape=(input_size, input_size, nc_in))
            #x = GaussianNoise(0.05)(inp)
            x = conv_block_d(inp, 64, False)
            x = conv_block_d(x, 128, True)
            x = conv_block_d(x, 256, True)
            x = conv_block_d(x, 512, True)
            out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
            return Model(inputs=[inp], outputs=out)

        netDA = Discriminator(self.nc_D_inp)
        netDB = Discriminator(self.nc_D_inp)
        try:
            netDA.load_weights(str(self.model_dir / hdf['netDAH5']))
            netDB.load_weights(str(self.model_dir / hdf['netDBH5']))
            logger.info("Discriminator models loaded.")
        except Exception:  # any load failure is treated as "no saved weights"
            logger.info("Discriminator weights files not found.")
        return netDA, netDB

    def load(self, swapped):
        if swapped:
            logger.warning("swapping not supported on GAN")
            # TODO load is done in __init__ => look how to swap if possible
        return True

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        if self.gpus > 1:
            # Save the single-GPU template models, not the multi-GPU wrappers
            self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5']))
            self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5']))
        else:
            self.netGA.save_weights(str(self.model_dir / hdf['netGAH5']))
            self.netGB.save_weights(str(self.model_dir / hdf['netGBH5']))
        self.netDA.save_weights(str(self.model_dir / hdf['netDAH5']))
        self.netDB.save_weights(str(self.model_dir / hdf['netDBH5']))
        logger.info("Models saved.")
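
The spatial sizes hard-coded in `Encoder` and `Decoder_ps` above (a 128-pixel input, the `Reshape((4, 4, 1024))` bottleneck, and a decoder `input_size` of 8) can be checked with plain arithmetic. The sketch below is not part of the deleted file. It assumes that every stride-2 `padding="same"` convolution halves height and width (rounding up) and that `PixelShuffler` doubles height and width while dividing the channel count by four; the latter is the usual sub-pixel behaviour but is not shown in this diff. All helper names are hypothetical.

```python
import math


def conv_same_stride2(shape, filters):
    """Shape after a stride-2, padding='same' convolution (assumed halving)."""
    height, width, _ = shape
    return (math.ceil(height / 2), math.ceil(width / 2), filters)


def upscale_ps(shape, filters):
    """Shape after Conv2D(filters * 4) followed by an assumed 2x pixel shuffle."""
    height, width, _ = shape
    return (height * 2, width * 2, filters)


def trace_encoder(input_size=128):
    # The first 5x5 convolution has stride 1, so only the channel count changes.
    shape = (input_size, input_size, 32)
    for filters in (64, 128, 256, 512, 1024):
        shape = conv_same_stride2(shape, filters)
    conv_out = shape
    # Dense(1024) -> Dense(4*4*1024) -> Reshape((4, 4, 1024)) -> upscale_ps(512)
    return conv_out, upscale_ps((4, 4, 1024), 512)


def trace_decoder(shape=(8, 8, 512)):
    for filters in (256, 128, 64):
        shape = upscale_ps(shape, filters)
    out64 = (shape[0], shape[1], 3)      # auxiliary RGB output
    shape = upscale_ps(shape, 32)
    out = (shape[0], shape[1], 1 + 3)    # alpha channel concatenated with RGB
    return out, out64


if __name__ == "__main__":
    print(trace_encoder())
    print(trace_decoder())
```

Under these assumptions the encoder output matches the decoder's expected `(8, 8, 512)` input, and the two decoder outputs have 128-pixel and 64-pixel sides respectively.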

@ -1,263 +0,0 @@
import time
import cv2
import numpy as np
from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K
from lib.training_data import TrainingDataGenerator, stack_images


class GANTrainingDataGenerator(TrainingDataGenerator):
    def __init__(self, random_transform_args, coverage, scale, zoom):
        super().__init__(random_transform_args, coverage, scale, zoom)

    def color_adjust(self, img):
        # Map 8-bit pixel values from [0, 255] to [-1, 1] (the range of tanh)
        return img / 255.0 * 2 - 1


class Trainer():
    random_transform_args = {
        'rotation_range': 20,
        'zoom_range': 0.1,
        'shift_range': 0.05,
        'random_flip': 0.5,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
        K.set_learning_phase(1)
        assert batch_size % 2 == 0, "batch_size must be an even number"
        self.batch_size = batch_size
        self.model = model
        self.use_lsgan = True
        self.use_mixup = True
        self.mixup_alpha = 0.2
        self.use_perceptual_loss = perceptual_loss
        self.use_mask_refinement = False  # Optional: can be enabled after about 15k iterations
        self.lrD = 1e-4  # Discriminator learning rate
        self.lrG = 1e-4  # Generator learning rate
        generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 2)
        self.train_batchA = generator.minibatchAB(fn_A, batch_size)
        self.train_batchB = generator.minibatchAB(fn_B, batch_size)
        self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
        self.setup()

    def setup(self):
        distorted_A, fake_A, fake_sz64_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
        distorted_B, fake_B, fake_sz64_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
        real_A = Input(shape=self.model.img_shape)
        real_B = Input(shape=self.model.img_shape)
        if self.use_lsgan:
            # Least-squares GAN loss
            self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
        else:
            # Binary cross-entropy GAN loss
            self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))
        # ========== Define Perceptual Loss Model==========
        if self.use_perceptual_loss:
            from keras.models import Model
            from keras_vggface.vggface import VGGFace
            vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
            vggface.trainable = False
            out_size55 = vggface.layers[36].output
            out_size28 = vggface.layers[78].output
            out_size7 = vggface.layers[-2].output
            vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
            vggface_feat.trainable = False
        else:
            vggface_feat = None
        loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, fake_sz64_A, distorted_A, vggface_feat)
        loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, fake_sz64_B, distorted_B, vggface_feat)
        # Mask regularisation: penalise large mask values
        if self.use_mask_refinement:
            loss_GA += 1e-3 * K.mean(K.square(mask_A))
            loss_GB += 1e-3 * K.mean(K.square(mask_B))
        else:
            loss_GA += 3e-3 * K.mean(K.abs(mask_A))
            loss_GB += 3e-3 * K.mean(K.abs(mask_B))
        # First-order (neighbouring pixel difference) penalty on the mask
        w_fo = 0.01
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))
        weightsDA = self.model.netDA.trainable_weights
        weightsGA = self.model.netGA.trainable_weights
        weightsDB = self.model.netDB.trainable_weights
        weightsGB = self.model.netGB.trainable_weights
        # Adam(..).get_updates(...)
        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA)
        self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA)
        self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)
        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB)
        self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB)
        self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)

    def first_order(self, x, axis=1):
        img_nrows = x.shape[1]
        img_ncols = x.shape[2]
        if axis == 1:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        elif axis == 2:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
        else:
            return None
    def train_one_step(self, iter, viewer):
        # ---------------------
        #  Train Discriminators
        # ---------------------
        # Fetch the next batch of warped and target images for each side
        epoch, warped_A, target_A = next(self.train_batchA)
        epoch, warped_B, target_B = next(self.train_batchB)
        # Train discriminators for one batch
        errDA = self.netDA_train([warped_A, target_A])
        errDB = self.netDB_train([warped_B, target_B])
        # Train generators for one batch
        errGA = self.netGA_train([warped_A, target_A])
        errGB = self.netGB_train([warped_B, target_B])
        # For calculating average losses
        self.errDA_sum += errDA[0]
        self.errDB_sum += errDB[0]
        self.errGA_sum += errGA[0]
        self.errGB_sum += errGB[0]
        self.avg_counter += 1
        print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter),
              end='\r')
        if viewer is not None:
            self.show_sample(viewer)

    def cycle_variables(self, netG):
        distorted_input = netG.inputs[0]
        fake_output = netG.outputs[0]
        fake_output64 = netG.outputs[1]
        # Channel 0 is the mask, the remaining channels are the colour image
        alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output)
        rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output)
        masked_fake_output = alpha * rgb + (1-alpha) * distorted_input
        fn_generate = K.function([distorted_input], [masked_fake_output])
        fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
        fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
        fn_bgr = K.function([distorted_input], [rgb])
        return distorted_input, fake_output, fake_output64, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr

    def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None):
        alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb)
        fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb)
        fake = alpha * fake_rgb + (1-alpha) * distorted
        if self.use_mixup:
            dist = Beta(self.mixup_alpha, self.mixup_alpha)
            lam = dist.sample()
            # ==========
            mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
            # ==========
            output_mixup = netD(mixup)
            loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
            #output_fake = netD(concatenate([fake, distorted])) # dummy
            loss_G = 1 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
        else:
            output_real = netD(concatenate([real, distorted]))  # positive sample
            output_fake = netD(concatenate([fake, distorted]))  # negative sample
            loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
            loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
            loss_D = loss_D_real + loss_D_fake
            loss_G = 1 * self.loss_fn(output_fake, K.ones_like(output_fake))
        # ==========
        if self.use_mask_refinement:
            loss_G += K.mean(K.abs(fake - real))
        else:
            loss_G += K.mean(K.abs(fake_rgb - real))
        loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))
        # ==========
        # Perceptual Loss
        if vggface_feat is not None:
            def preprocess_vggface(x):
                x = (x + 1)/2 * 255  # channel order: BGR
                x -= [93.5940, 104.7624, 129.]
                return x
            pl_params = (0.02, 0.3, 0.5)
            real_sz224 = tf.image.resize_images(real, [224, 224])
            real_sz224 = Lambda(preprocess_vggface)(real_sz224)
            # ==========
            if self.use_mask_refinement:
                fake_sz224 = tf.image.resize_images(fake, [224, 224])
            else:
                fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
            fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
            # ==========
            real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
            fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
            loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
            loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
            loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
        return loss_D, loss_G

    def show_sample(self, display_fn):
        _, wA, tA = next(self.train_batchA)
        _, wB, tB = next(self.train_batchB)
        display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked")
        display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw")
        display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
        # Reset the averages
        self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
        self.avg_counter = 0

    def showG(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
            np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        ], axis=1 )
        figure_B = np.stack([
            test_B,
            np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
            np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        ], axis=1 )
        figure = np.concatenate([figure_A, figure_B], axis=0 )
        figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        # Map [-1, 1] back to 8-bit pixel values
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure

    def showG_mask(self, test_A, test_B, path_A, path_B):
        # Masks are in [0, 1]; rescale them to [-1, 1] to match the images
        figure_A = np.stack([
            test_A,
            (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
            (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
        ], axis=1 )
        figure_B = np.stack([
            test_B,
            (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
            (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
        ], axis=1 )
        figure = np.concatenate([figure_A, figure_B], axis=0 )
        figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure
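
Two pieces of the trainer above are easy to check outside a Keras graph: the alpha compositing used in `cycle_variables` and `define_loss` (`alpha * rgb + (1 - alpha) * distorted`) and the neighbouring-pixel difference computed by `first_order`. The following is a NumPy re-statement of those two expressions only, offered as a sketch; it assumes channels-last arrays with the mask in channel 0, as the slicing in the trainer suggests, and the function name `composite` is hypothetical.

```python
import numpy as np


def composite(fake_argb, distorted):
    """Blend the generated colour channels over the input using the mask."""
    alpha = fake_argb[..., :1]   # channel 0: mask in [0, 1]
    rgb = fake_argb[..., 1:]     # remaining channels: generated image
    return alpha * rgb + (1 - alpha) * distorted


def first_order(x, axis=1):
    """Absolute difference between neighbouring pixels along rows or columns."""
    rows, cols = x.shape[1], x.shape[2]
    if axis == 1:
        return np.abs(x[:, :rows - 1, :cols - 1, :] - x[:, 1:, :cols - 1, :])
    if axis == 2:
        return np.abs(x[:, :rows - 1, :cols - 1, :] - x[:, :rows - 1, 1:, :])
    raise ValueError("axis must be 1 or 2")


if __name__ == "__main__":
    rgb = np.full((1, 4, 4, 3), 0.8)
    distorted = np.full((1, 4, 4, 3), -0.2)
    half_mask = np.full((1, 4, 4, 1), 0.5)
    blended = composite(np.concatenate([half_mask, rgb], axis=-1), distorted)
    print(blended[0, 0, 0])                 # halfway between rgb and distorted
    print(first_order(half_mask).max())     # a constant mask has no gradient
```

A mask of all ones returns the generated image unchanged and a mask of all zeros returns the distorted input, which is why the trainer also penalises the mask magnitude and its first-order differences.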

@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'
from .Model import GANModel as Model
from .Trainer import Trainer

@ -1,145 +0,0 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if (self.axis is not None):
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
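
The axis handling in `call` above is compact, so a NumPy re-statement may help. This is a sketch of the normalisation step only (without the learned `gamma` and `beta`), mirroring the layer's logic: start from all axes, drop the feature axis if one was given, then drop the batch axis, and normalise over whatever remains. The function name `instance_norm` is hypothetical; as in the layer, `epsilon` is added to the standard deviation rather than to the variance.

```python
import numpy as np


def instance_norm(x, axis=None, epsilon=1e-3):
    """Normalise each sample (and each feature, if axis is given) separately."""
    if axis == 0:
        raise ValueError("Axis cannot be zero")
    reduction_axes = list(range(x.ndim))
    if axis is not None:
        del reduction_axes[axis]   # keep the feature axis un-reduced
    del reduction_axes[0]          # never reduce over the batch axis
    axes = tuple(reduction_axes)
    mean = x.mean(axis=axes, keepdims=True)
    stddev = x.std(axis=axes, keepdims=True) + epsilon
    return (x - mean) / stddev


if __name__ == "__main__":
    batch = np.arange(96, dtype=float).reshape(2, 4, 4, 3)
    per_channel = instance_norm(batch, axis=3)   # statistics over height and width
    per_sample = instance_norm(batch)            # statistics over all non-batch axes
    print(per_channel.mean(axis=(1, 2)))
    print(per_sample.mean(axis=(1, 2, 3)))
```

Note that the GAN model in this diff constructs `InstanceNormalization()` with the default `axis=None`, which in this sketch corresponds to the `per_sample` case.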

@ -1,53 +0,0 @@
# Improved-AutoEncoder base classes
import logging

from lib.utils import backup_file

hdf = {'encoderH5': 'IAE_encoder.h5',
       'decoderH5': 'IAE_decoder.h5',
       'inter_AH5': 'IAE_inter_A.h5',
       'inter_BH5': 'IAE_inter_B.h5',
       'inter_bothH5': 'IAE_inter_both.h5'}

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class AutoEncoder:
    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        self.encoder = self.Encoder()
        self.decoder = self.Decoder()
        self.inter_A = self.Intermidiate()
        self.inter_B = self.Intermidiate()
        self.inter_both = self.Intermidiate()

        self.initModel()

    def load(self, swapped):
        (face_A,face_B) = (hdf['inter_AH5'], hdf['inter_BH5']) if not swapped else (hdf['inter_BH5'], hdf['inter_AH5'])

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder.load_weights(str(self.model_dir / hdf['decoderH5']))
            self.inter_both.load_weights(str(self.model_dir / hdf['inter_bothH5']))
            self.inter_A.load_weights(str(self.model_dir / face_A))
            self.inter_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception:
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            return False

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder.save_weights(str(self.model_dir / hdf['decoderH5']))
        self.inter_both.save_weights(str(self.model_dir / hdf['inter_bothH5']))
        self.inter_A.save_weights(str(self.model_dir / hdf['inter_AH5']))
        self.inter_B.save_weights(str(self.model_dir / hdf['inter_BH5']))
        logger.info('saved model weights')

@ -1,77 +0,0 @@
# Improved autoencoder for faceswap.

from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam

from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler

from keras.utils import multi_gpu_model

IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024


class Model(AutoEncoder):
    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

        x = Input(shape=IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder(Concatenate()([self.inter_A(self.encoder(x)), self.inter_both(self.encoder(x))])))
        self.autoencoder_B = KerasModel(x, self.decoder(Concatenate()([self.inter_B(self.encoder(x)), self.inter_both(self.encoder(x))])))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus)
            self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return lambda img: autoencoder.predict(img)

    def conv(self, filters):
        def block(x):
            x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def upscale(self, filters):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder(self):
        input_ = Input(shape=IMAGE_SHAPE)
        x = input_
        x = self.conv(128)(x)
        x = self.conv(256)(x)
        x = self.conv(512)(x)
        x = self.conv(1024)(x)
        x = Flatten()(x)
        return KerasModel(input_, x)

    def Intermidiate(self):
        input_ = Input(shape=(None, 4 * 4 * 1024))
        x = input_
        x = Dense(ENCODER_DIM)(x)
        x = Dense(4 * 4 * int(ENCODER_DIM/2))(x)
        x = Reshape((4, 4, int(ENCODER_DIM/2)))(x)
        return KerasModel(input_, x)

    def Decoder(self):
        input_ = Input(shape=(4, 4, ENCODER_DIM))
        x = input_
        x = self.upscale(512)(x)
        x = self.upscale(256)(x)
        x = self.upscale(128)(x)
        x = self.upscale(64)(x)
        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
        return KerasModel(input_, x)
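
The IAE model above relies on three numbers lining up: the flattened encoder output, the `4 * 4 * 1024` input size declared in `Intermidiate`, and the `(4, 4, ENCODER_DIM)` decoder input obtained by concatenating two intermediate outputs. The sketch below checks that bookkeeping with plain integers. It assumes each stride-2 `padding='same'` convolution halves the side length (rounding up) and that `Concatenate()` joins along the last axis, which is the Keras default; the function names are hypothetical.

```python
ENCODER_DIM = 1024
IMAGE_SIZE = 64


def encoder_flat_size(image_size=IMAGE_SIZE, filters=(128, 256, 512, 1024)):
    """Length of the vector produced by Flatten() at the end of Encoder()."""
    side = image_size
    for _ in filters:
        side = (side + 1) // 2          # stride-2 'same' convolution
    return side * side * filters[-1]


def intermediate_shape(encoder_dim=ENCODER_DIM):
    """Output shape of one Intermidiate() network after its Reshape layer."""
    return (4, 4, encoder_dim // 2)


def decoder_input_shape(encoder_dim=ENCODER_DIM):
    """Shape after concatenating a per-face and a shared intermediate output."""
    height, width, channels = intermediate_shape(encoder_dim)
    return (height, width, channels * 2)


if __name__ == "__main__":
    print(encoder_flat_size())       # compare with 4 * 4 * 1024
    print(decoder_input_shape())     # compare with (4, 4, ENCODER_DIM)
```

Each face therefore contributes half of the decoder's input channels through its own intermediate network, and the shared `inter_both` network contributes the other half.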

@ -1,51 +0,0 @@
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images


class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, 7) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')

@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """acsaga"""
__version__ = '0.1.0'
from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder

@ -1,61 +0,0 @@
# AutoEncoder base classes
import logging

from lib.utils import backup_file

hdf = {'encoderH5': 'lowmem_encoder.h5',
       'decoder_AH5': 'lowmem_decoder_A.h5',
       'decoder_BH5': 'lowmem_decoder_B.h5'}

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Part of filename migration; should be removed some reasonable time after first added
import os.path
old_encoderH5 = 'encoder.h5'
old_decoder_AH5 = 'decoder_A.h5'
old_decoder_BH5 = 'decoder_B.h5'
# End filename migration


class AutoEncoder:
    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        (face_A,face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])

        try:
            # Part of filename migration; should be removed some reasonable time after first added
            if os.path.isfile(str(self.model_dir / old_encoderH5)):
                logger.info('Migrating to new filenames:')
                if not os.path.isfile(str(self.model_dir / hdf['encoderH5'])):
                    os.rename(str(self.model_dir / old_decoder_AH5), str(self.model_dir / hdf['decoder_AH5']))
                    os.rename(str(self.model_dir / old_decoder_BH5), str(self.model_dir / hdf['decoder_BH5']))
                    os.rename(str(self.model_dir / old_encoderH5), str(self.model_dir / hdf['encoderH5']))
                    logger.info('Complete')
                else:
                    logger.warning('Failed due to existing files in folder. Loading already migrated files')
            # End filename migration
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception:
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            return False

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
        self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))
        logger.info('saved model weights')
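
The filename migration inside `load` above can be exercised on its own with empty placeholder files. The sketch below re-states that logic using `pathlib` and a temporary directory; it is an illustration rather than the removed code, and the function `migrate_filenames`, its return strings, and the `demo` helper are hypothetical. The old and new file names are the ones listed in the module above.

```python
import tempfile
from pathlib import Path

# Old name -> new name; decoders first, encoder last, as in load().
OLD_TO_NEW = {
    'decoder_A.h5': 'lowmem_decoder_A.h5',
    'decoder_B.h5': 'lowmem_decoder_B.h5',
    'encoder.h5': 'lowmem_encoder.h5',
}


def migrate_filenames(model_dir):
    """Rename old weight files unless a migrated encoder file already exists."""
    model_dir = Path(model_dir)
    if not (model_dir / 'encoder.h5').is_file():
        return "nothing to migrate"
    if (model_dir / OLD_TO_NEW['encoder.h5']).is_file():
        return "skipped: migrated files already present"
    for old_name, new_name in OLD_TO_NEW.items():
        (model_dir / old_name).rename(model_dir / new_name)
    return "migrated"


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        for old_name in OLD_TO_NEW:
            (tmp / old_name).write_bytes(b"")   # empty stand-ins for weight files
        first = migrate_filenames(tmp)
        second = migrate_filenames(tmp)
        names = sorted(path.name for path in tmp.iterdir())
    return first, second, names


if __name__ == "__main__":
    print(demo())
```

Because the old encoder file is renamed last, its presence marks a migration that has not yet completed, which is the condition `load` checks first.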

@ -1,70 +0,0 @@
# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs

from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam

from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler

from keras.utils import multi_gpu_model

IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 512


class Model(AutoEncoder):
    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
        x = Input(shape=IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
        self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus)
            self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return lambda img: autoencoder.predict(img)

    def conv(self, filters):
        def block(x):
            x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def upscale(self, filters):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder(self):
        input_ = Input(shape=IMAGE_SHAPE)
        x = input_
        x = self.conv(128)(x)
        x = self.conv(256)(x)
        x = self.conv(512)(x)
        x = Dense(ENCODER_DIM)(Flatten()(x))
        x = Dense(4 * 4 * 1024)(x)
        x = Reshape((4, 4, 1024))(x)
        x = self.upscale(512)(x)
        return KerasModel(input_, x)

    def Decoder(self):
        input_ = Input(shape=(8, 8, 512))
        x = input_
        x = self.upscale(256)(x)
        x = self.upscale(128)(x)
        x = self.upscale(64)(x)
        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
        return KerasModel(input_, x)

@ -1,56 +0,0 @@
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images


class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        # Pad an odd number of samples so the preview grid can be reshaped evenly
        if test_A.shape[0] % 2 == 1:
            figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ])
            figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ])

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        w = 4
        h = int( figure.shape[0] / w)
        figure = figure.reshape((w, h) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
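
The padding step in `show_sample` above exists because the preview is reshaped into four groups, which requires the combined number of A and B samples to be a multiple of four. The sketch below isolates that step with dummy arrays in place of model predictions; the function name `preview_grid` and the tiny image size are assumptions made for the illustration, and `stack_images` is left out because its implementation is not part of this diff.

```python
import numpy as np


def preview_grid(figure_A, figure_B, groups=4):
    """Pad odd sample counts, join both sides, and split into `groups` blocks."""
    if figure_A.shape[0] % 2 == 1:
        # Repeat the first sample so each side has an even number of entries.
        figure_A = np.concatenate([figure_A, np.expand_dims(figure_A[0], 0)])
        figure_B = np.concatenate([figure_B, np.expand_dims(figure_B[0], 0)])
    figure = np.concatenate([figure_A, figure_B], axis=0)
    per_group = figure.shape[0] // groups
    return figure.reshape((groups, per_group) + figure.shape[1:])


if __name__ == "__main__":
    # 7 samples per side, each a stack of (original, own decoder, other decoder).
    odd_A = np.zeros((7, 3, 2, 2, 3))
    odd_B = np.ones((7, 3, 2, 2, 3))
    print(preview_grid(odd_A, odd_B).shape)

    # 14 samples per side is the fixed case handled by reshape((4, 7) + ...).
    even_A = np.zeros((14, 3, 2, 2, 3))
    even_B = np.ones((14, 3, 2, 2, 3))
    print(preview_grid(even_A, even_B).shape)
```

With an even count per side the combined total is always divisible by four, so the reshape succeeds for any batch size once the padding has been applied.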

@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'
from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder

@ -1,77 +0,0 @@
# AutoEncoder base classes
import logging
from lib.utils import backup_file
from lib import Serializer
from json import JSONDecodeError
hdf = {'encoderH5': 'encoder.h5',
'decoder_AH5': 'decoder_A.h5',
'decoder_BH5': 'decoder_B.h5',
'state': 'state'}
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class AutoEncoder:
def __init__(self, model_dir, gpus):
self.model_dir = model_dir
self.gpus = gpus
self.encoder = self.Encoder()
self.decoder_A = self.Decoder()
self.decoder_B = self.Decoder()
self.initModel()
def load(self, swapped):
serializer = Serializer.get_serializer('json')
state_fn = ".".join([hdf['state'], serializer.ext])
try:
with open(str(self.model_dir / state_fn), 'rb') as fp:
state = serializer.unmarshal(fp.read().decode('utf-8'))
self._epoch_no = state['epoch_no']
except IOError as e:
logger.warning('Error loading training info: %s', str(e.strerror))
self._epoch_no = 0
except JSONDecodeError as e:
logger.warning('Error loading training info: %s', str(e.msg))
self._epoch_no = 0
(face_A,face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])
try:
self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
self.decoder_A.load_weights(str(self.model_dir / face_A))
self.decoder_B.load_weights(str(self.model_dir / face_B))
logger.info('loaded model weights')
return True
except Exception as e:
logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
return False
def save_weights(self):
model_dir = str(self.model_dir)
for model in hdf.values():
backup_file(model_dir, model)
self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))
logger.info('saved model weights')
serializer = Serializer.get_serializer('json')
state_fn = ".".join([hdf['state'], serializer.ext])
state_dir = str(self.model_dir / state_fn)
try:
with open(state_dir, 'wb') as fp:
state_json = serializer.marshal({
'epoch_no' : self.epoch_no
})
fp.write(state_json.encode('utf-8'))
except IOError as e:
logger.error(e.strerror)
@property
def epoch_no(self):
"Get current training epoch number"
return self._epoch_no
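
The load/save pair above keeps the iteration counter in a small JSON state file and falls back to zero when that file is missing or corrupt. Below is a minimal sketch of that pattern, using the standard `json` module in place of `lib.Serializer`. The function names `load_epoch` and `save_epoch` and the file name `state.json` are assumptions made for illustration.

```python
import json
import tempfile
from json import JSONDecodeError
from pathlib import Path


def load_epoch(model_dir):
    # Read the saved counter; a missing or unreadable file means a fresh model.
    state_file = Path(model_dir) / "state.json"
    try:
        with open(state_file, "rb") as handle:
            state = json.loads(handle.read().decode("utf-8"))
        return state["epoch_no"]
    except (IOError, JSONDecodeError):
        return 0


def save_epoch(model_dir, epoch_no):
    # Write the counter as UTF-8 encoded JSON, as the class above does.
    state_file = Path(model_dir) / "state.json"
    with open(state_file, "wb") as handle:
        handle.write(json.dumps({"epoch_no": epoch_no}).encode("utf-8"))


with tempfile.TemporaryDirectory() as tmp:
    first = load_epoch(tmp)   # no state file exists yet
    save_epoch(tmp, 1200)
    second = load_epoch(tmp)
print(first, second)
```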


@ -1,71 +0,0 @@
# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler
from keras.utils import multi_gpu_model
IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024
class Model(AutoEncoder):
def initModel(self):
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
x = Input(shape=IMAGE_SHAPE)
self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))
if self.gpus > 1:
self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus)
self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus)
self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')
def converter(self, swap):
autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
return lambda img: autoencoder.predict(img)
def conv(self, filters):
def block(x):
x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(0.1)(x)
return x
return block
def upscale(self, filters):
def block(x):
x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def Encoder(self):
input_ = Input(shape=IMAGE_SHAPE)
x = input_
x = self.conv(128)(x)
x = self.conv(256)(x)
x = self.conv(512)(x)
x = self.conv(1024)(x)
x = Dense(ENCODER_DIM)(Flatten()(x))
x = Dense(4 * 4 * 1024)(x)
x = Reshape((4, 4, 1024))(x)
x = self.upscale(512)(x)
return KerasModel(input_, x)
def Decoder(self):
input_ = Input(shape=(8, 8, 512))
x = input_
x = self.upscale(256)(x)
x = self.upscale(128)(x)
x = self.upscale(64)(x)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return KerasModel(input_, x)
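
The layer sizes in this model follow from simple arithmetic on the spatial dimensions. The sketch below traces them without Keras, assuming that each stride-2 `conv` block with 'same' padding halves height and width, and that each `upscale` block (through `PixelShuffler`) doubles them. The helper names are hypothetical.

```python
def conv_out(size, strides=2):
    # 'same' padding with a stride: output is ceil(size / strides).
    return -(-size // strides)


def trace_original(input_size=64):
    size = input_size
    for _ in range(4):      # four conv blocks in the encoder
        size = conv_out(size)
    bottleneck = size       # side length fed to Reshape
    size *= 2               # the encoder's final upscale
    encoder_out = size      # side length the decoder expects as input
    for _ in range(3):      # three upscales in the decoder
        size *= 2
    return bottleneck, encoder_out, size


print(trace_original())
```

With the 64 pixel input this gives a 4x4 bottleneck, an 8x8 encoder output and a 64x64 decoder output, matching `Reshape((4, 4, 1024))` and `Input(shape=(8, 8, 512))` above.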


@ -1,59 +0,0 @@
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images
class Trainer():
random_transform_args = {
'rotation_range': 10,
'zoom_range': 0.05,
'shift_range': 0.05,
'random_flip': 0.4,
}
def __init__(self, model, fn_A, fn_B, batch_size, *args):
self.batch_size = batch_size
self.model = model
generator = TrainingDataGenerator(self.random_transform_args, 160)
self.images_A = generator.minibatchAB(fn_A, self.batch_size)
self.images_B = generator.minibatchAB(fn_B, self.batch_size)
def train_one_step(self, iter, viewer):
epoch, warped_A, target_A = next(self.images_A)
epoch, warped_B, target_B = next(self.images_B)
loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
self.model._epoch_no += 1
print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), self.model.epoch_no, loss_A, loss_B),
end='\r')
if viewer is not None:
viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")
def show_sample(self, test_A, test_B):
figure_A = numpy.stack([
test_A,
self.model.autoencoder_A.predict(test_A),
self.model.autoencoder_B.predict(test_A),
], axis=1)
figure_B = numpy.stack([
test_B,
self.model.autoencoder_B.predict(test_B),
self.model.autoencoder_A.predict(test_B),
], axis=1)
if test_A.shape[0] % 2 == 1:
figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ])
figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ])
figure = numpy.concatenate([figure_A, figure_B], axis=0)
w = 4
h = int( figure.shape[0] / w)
figure = figure.reshape((w, h) + figure.shape[1:])
figure = stack_images(figure)
return numpy.clip(figure * 255, 0, 255).astype('uint8')


@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'
from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder


@ -1,312 +0,0 @@
#!/usr/bin/python3
# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
# Based on https://github.com/iperov/OpenDeepFaceSwap for Decoder multiple res block chain
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb
import enum
import logging
import os
import sys
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from keras.initializers import RandomNormal
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers import SeparableConv2D, add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation
from keras.models import Model as KerasModel
from keras.optimizers import Adam
from keras.utils import multi_gpu_model
from lib.PixelShuffler import PixelShuffler
import lib.Serializer
from lib.utils import backup_file
from . import __version__
from .instance_normalization import InstanceNormalization
if isinstance(__version__, (list, tuple)):
version_str = ".".join([str(n) for n in __version__[1:]])
else:
version_str = __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
mswindows = sys.platform=="win32"
class EncoderType(enum.Enum):
ORIGINAL = "original"
SHAOANLU = "shaoanlu"
_kern_init = RandomNormal(0, 0.02)
def inst_norm():
return InstanceNormalization()
ENCODER = EncoderType.ORIGINAL
hdf = {'encoderH5': 'encoder_{version_str}{ENCODER.value}.h5'.format(**vars()),
'decoder_AH5': 'decoder_A_{version_str}{ENCODER.value}.h5'.format(**vars()),
'decoder_BH5': 'decoder_B_{version_str}{ENCODER.value}.h5'.format(**vars())}
class Model():
ENCODER_DIM = 1024 # dense layer size
IMAGE_SHAPE = 128, 128 # image shape
assert all(n >= 16 for n in IMAGE_SHAPE)  # every dimension must be at least 16
IMAGE_WIDTH = max(IMAGE_SHAPE)
IMAGE_WIDTH = (IMAGE_WIDTH//16 + (1 if (IMAGE_WIDTH%16)>=8 else 0))*16
IMAGE_SHAPE = IMAGE_WIDTH, IMAGE_WIDTH, len('BGR') # three colour channels
def __init__(self, model_dir, gpus, encoder_type=ENCODER):
if mswindows:
from ctypes import cdll
mydll = cdll.LoadLibrary("user32.dll")
mydll.SetProcessDPIAware(True)
self._encoder_type = encoder_type
self.model_dir = model_dir
# the GPU count can't change once the model is initialized, so there is no point in making it r/w
self._gpus = gpus
Encoder = getattr(self, "Encoder_{}".format(self._encoder_type.value))
Decoder = getattr(self, "Decoder_{}".format(self._encoder_type.value))
self.encoder = Encoder()
self.decoder_A = Decoder()
self.decoder_B = Decoder()
self.initModel()
def initModel(self):
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
x = Input(shape=self.IMAGE_SHAPE)
self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))
if self.gpus > 1:
self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus)
self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus)
self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')
def load(self, swapped):
model_dir = str(self.model_dir)
from json import JSONDecodeError
face_A, face_B = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])
state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals()))
ser = lib.Serializer.get_serializer('json')
try:
with open(state_dir, 'rb') as fp:
state = ser.unmarshal(fp.read().decode('utf-8'))
self._epoch_no = state['epoch_no']
except IOError as e:
logger.warning('Error loading training info: %s', str(e.strerror))
self._epoch_no = 0
except JSONDecodeError as e:
logger.warning('Error loading training info: %s', str(e.msg))
self._epoch_no = 0
try:
self.encoder.load_weights(os.path.join(model_dir, hdf['encoderH5']))
self.decoder_A.load_weights(os.path.join(model_dir, face_A))
self.decoder_B.load_weights(os.path.join(model_dir, face_B))
logger.info('loaded model weights')
return True
except IOError as e:
logger.warning('Error loading training info: %s', str(e.strerror))
except Exception as e:
logger.warning('Error loading training info: %s', str(e))
return False
def converter(self, swap):
autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
return autoencoder.predict
def conv(self, filters, kernel_size=5, strides=2, **kwargs):
def block(x):
x = Conv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
x = LeakyReLU(0.1)(x)
return x
return block
def conv_sep(self, filters, kernel_size=5, strides=2, use_instance_norm=True, **kwargs):
def block(x):
x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
x = Activation("relu")(x)
return x
return block
def conv_inst_norm(self, filters, kernel_size=3, strides=2, use_instance_norm=True, **kwargs):
def block(x):
x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
if use_instance_norm:
x = inst_norm()(x)
x = Activation("relu")(x)
return x
return block
def upscale(self, filters, **kwargs):
def block(x):
x = Conv2D(filters * 4, kernel_size=3, padding='same',
kernel_initializer=_kern_init)(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def upscale_inst_norm(self, filters, use_instance_norm=True, **kwargs):
def block(x):
x = Conv2D(filters*4, kernel_size=3, use_bias=False,
kernel_initializer=_kern_init, padding='same', **kwargs)(x)
if use_instance_norm:
x = inst_norm()(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def Encoder_original(self, **kwargs):
impt = Input(shape=self.IMAGE_SHAPE)
in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0]-128)//4
x = self.conv(in_conv_filters)(impt)
x = self.conv_sep(256)(x)
x = self.conv(512)(x)
x = self.conv_sep(1024)(x)
dense_shape = self.IMAGE_SHAPE[0] // 16
x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x))
x = Dense(dense_shape * dense_shape * 512, kernel_initializer=_kern_init)(x)
x = Reshape((dense_shape, dense_shape, 512))(x)
x = self.upscale(512)(x)
return KerasModel(impt, x, **kwargs)
def Encoder_shaoanlu(self, **kwargs):
impt = Input(shape=self.IMAGE_SHAPE)
in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0]-128)//4
x = Conv2D(in_conv_filters, kernel_size=5, use_bias=False, padding="same")(impt)
x = self.conv_inst_norm(in_conv_filters+32, use_instance_norm=False)(x)
x = self.conv_inst_norm(256)(x)
x = self.conv_inst_norm(512)(x)
x = self.conv_inst_norm(1024)(x)
dense_shape = self.IMAGE_SHAPE[0] // 16
x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x))
x = Dense(dense_shape * dense_shape * 768, kernel_initializer=_kern_init)(x)
x = Reshape((dense_shape, dense_shape, 768))(x)
x = self.upscale(512)(x)
return KerasModel(impt, x, **kwargs)
def Decoder_original(self):
decoder_shape = self.IMAGE_SHAPE[0]//8
inpt = Input(shape=(decoder_shape, decoder_shape, 512))
x = self.upscale(384)(inpt)
x = self.upscale(256-32)(x)
x = self.upscale(self.IMAGE_SHAPE[0])(x)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return KerasModel(inpt, x)
def Decoder_shaoanlu(self):
decoder_shape = self.IMAGE_SHAPE[0]//8
inpt = Input(shape=(decoder_shape, decoder_shape, 512))
x = self.upscale_inst_norm(512)(inpt)
x = self.upscale_inst_norm(256)(x)
x = self.upscale_inst_norm(self.IMAGE_SHAPE[0])(x)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return KerasModel(inpt, x)
def save_weights(self):
model_dir = str(self.model_dir)
try:
for model in hdf.values():
backup_file(model_dir, model)
except NameError:
logger.error('backup functionality not available\n')
state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals()))
ser = lib.Serializer.get_serializer('json')
try:
with open(state_dir, 'wb') as fp:
state_json = ser.marshal({
'epoch_no' : self._epoch_no
})
fp.write(state_json.encode('utf-8'))
except IOError as e:
logger.error(e.strerror)
logger.info('saving model weights')
from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(getattr(self, mdl_name.rstrip('H5')).save_weights, str(self.model_dir / mdl_H5_fn)) for mdl_name, mdl_H5_fn in hdf.items()]
for future in as_completed(futures):
future.result()
print('.', end='', flush=True)
logger.info('done')
@property
def gpus(self):
return self._gpus
@property
def model_name(self):
try:
return self._model_name
except AttributeError:
import inspect
self._model_name = os.path.dirname(inspect.getmodule(self).__file__).rsplit("_", 1)[1]
return self._model_name
def __str__(self):
return "<{}: ver={}, dense_dim={}, img_size={}>".format(self.model_name,
version_str,
self.ENCODER_DIM,
"x".join([str(n) for n in self.IMAGE_SHAPE[:2]]))


@ -1,84 +0,0 @@
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images
TRANSFORM_PRC = 115.
class Trainer():
_random_transform_args = {
'rotation_range': 10 * (TRANSFORM_PRC * .01),
'zoom_range': 0.05 * (TRANSFORM_PRC * .01),
'shift_range': 0.05 * (TRANSFORM_PRC * .01),
'random_flip': 0.4 * (TRANSFORM_PRC * .01),
}
def __init__(self, model, fn_A, fn_B, batch_size, *args):
self.batch_size = batch_size
self.model = model
from timeit import default_timer as clock
self._clock = clock
generator = TrainingDataGenerator(self.random_transform_args, 160, 5, zoom=self.model.IMAGE_SHAPE[0]//64)
self.images_A = generator.minibatchAB(fn_A, self.batch_size)
self.images_B = generator.minibatchAB(fn_B, self.batch_size)
self.generator = generator
def train_one_step(self, iter_no, viewer):
when = self._clock()
_, warped_A, target_A = next(self.images_A)
_, warped_B, target_B = next(self.images_B)
loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
self.model._epoch_no += 1
if isinstance(loss_A, (list, tuple)):
print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format(
time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock()-when, loss_A[1], loss_B[1]),
end='\r')
else:
print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format(
time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock()-when, loss_A, loss_B),
end='\r')
if viewer is not None:
viewer(self.show_sample(target_A[0:8], target_B[0:8]), "training using {}, bs={}".format(self.model, self.batch_size))
def show_sample(self, test_A, test_B):
figure_A = numpy.stack([
test_A,
self.model.autoencoder_A.predict(test_A),
self.model.autoencoder_B.predict(test_A),
], axis=1)
figure_B = numpy.stack([
test_B,
self.model.autoencoder_B.predict(test_B),
self.model.autoencoder_A.predict(test_B),
], axis=1)
if (test_A.shape[0] % 2)!=0:
figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ])
figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ])
figure = numpy.concatenate([figure_A, figure_B], axis=0)
w = 4
h = int( figure.shape[0] / w)
figure = figure.reshape((w, h) + figure.shape[1:])
figure = stack_images(figure)
return numpy.clip(figure * 255, 0, 255).astype('uint8')
@property
def random_transform_args(self):
return self._random_transform_args
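
The model class of this plugin snaps its image width to the nearest multiple of 16 and derives the dense, decoder and trainer zoom sizes from it. The sketch below reproduces that arithmetic on its own; `round_to_multiple_of_16` and `derived_sizes` are hypothetical names, not part of the plugin.

```python
def round_to_multiple_of_16(width):
    # Mirrors (w // 16 + (1 if w % 16 >= 8 else 0)) * 16 from the model class.
    return (width // 16 + (1 if width % 16 >= 8 else 0)) * 16


def derived_sizes(requested_width):
    width = round_to_multiple_of_16(requested_width)
    return {
        "image_width": width,
        "dense_shape": width // 16,    # side length used before Reshape in the encoder
        "decoder_shape": width // 8,   # side length of the decoder input
        "zoom": width // 64,           # value the trainer passes to TrainingDataGenerator
    }


for requested in (128, 135, 136, 200):
    print(requested, derived_sizes(requested))
```

A remainder of 8 or more rounds up, so 136 becomes 144 while 135 stays at 128.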


@ -1,8 +0,0 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://reddit.com/u/deepfakes/"""
from ._version import __version__
from .Model import Model
from .Trainer import Trainer


@ -1 +0,0 @@
__version__ = 0, 2, 7


@ -1,145 +0,0 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
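
The normalization performed in `call` can be checked outside Keras. The following is a minimal NumPy sketch of the same steps, assuming the default `gamma` of ones and `beta` of zeros so that scaling and centering are omitted. The function name `instance_norm` is illustrative only.

```python
import numpy as np


def instance_norm(inputs, axis=None, epsilon=1e-3):
    # Reduce over every axis except the batch axis (0) and, if given, `axis`.
    reduction_axes = list(range(inputs.ndim))
    if axis is not None:
        del reduction_axes[axis]
    del reduction_axes[0]
    reduction_axes = tuple(reduction_axes)
    mean = inputs.mean(axis=reduction_axes, keepdims=True)
    # As in the layer above, epsilon is added to the standard deviation itself.
    stddev = inputs.std(axis=reduction_axes, keepdims=True) + epsilon
    return (inputs - mean) / stddev


rng = np.random.default_rng(0)
batch = rng.normal(5.0, 2.0, size=(2, 8, 8, 3))
normed = instance_norm(batch, axis=3)
print(normed.mean(axis=(1, 2)))
print(normed.std(axis=(1, 2)))
```

With `axis=3` each channel of each sample is normalized independently over its spatial positions, so the per-channel means come out near zero and the standard deviations just below one.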


@ -23,21 +23,22 @@ class PluginLoader():
     @staticmethod
     def get_converter(name):
         """ Return requested converter plugin """
-        return PluginLoader._import("Convert", "Convert_{0}".format(name))
+        return PluginLoader._import("convert", name)
     @staticmethod
     def get_model(name):
         """ Return requested model plugin """
-        return PluginLoader._import("Model", "Model_{0}".format(name))
+        return PluginLoader._import("train.model", name)
     @staticmethod
     def get_trainer(name):
         """ Return requested trainer plugin """
-        return PluginLoader._import("Trainer", "Model_{0}".format(name))
+        return PluginLoader._import("train.trainer", name)
     @staticmethod
     def _import(attr, name):
         """ Import the plugin's module """
+        name = name.replace("-", "_")
         ttl = attr.split(".")[-1].title()
         logger.info("Loading %s from %s plugin...", ttl, name.title())
         attr = "model" if attr == "Trainer" else attr.lower()
@ -48,13 +49,23 @@ class PluginLoader():
     @staticmethod
     def get_available_models():
         """ Return a list of available models """
-        models = ()
-        modelpath = os.path.join(os.path.dirname(__file__), "model")
-        for modeldir in next(os.walk(modelpath))[1]:
-            if modeldir[0:6].lower() == 'model_':
-                models += (modeldir[6:],)
+        modelpath = os.path.join(os.path.dirname(__file__), "train", "model")
+        models = sorted(item.name.replace(".py", "").replace("_", "-")
+                        for item in os.scandir(modelpath)
+                        if not item.name.startswith("_")
+                        and item.name.endswith(".py"))
         return models
+    @staticmethod
+    def get_available_converters():
+        """ Return a list of available converters """
+        converter_path = os.path.join(os.path.dirname(__file__), "convert")
+        converters = sorted(item.name.replace(".py", "").replace("_", "-")
+                            for item in os.scandir(converter_path)
+                            if not item.name.startswith("_")
+                            and item.name.endswith(".py"))
+        return converters
     @staticmethod
     def get_available_extractors(extractor_type):
         """ Return a list of available models """
@ -72,4 +83,4 @@ class PluginLoader():
     def get_default_model():
         """ Return the default model """
         models = PluginLoader.get_available_models()
-        return 'Original' if 'Original' in models else models[0]
+        return 'original' if 'original' in models else models[0]
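
The new discovery scheme lists plugin modules by file name and shows them with hyphens, while `_import` maps the hyphens back to underscores. Below is a small self-contained sketch of that round trip against a temporary directory. The file names are made up for the example, and `available_plugins` and `module_name` are hypothetical helpers that mirror the expressions in the diff.

```python
import os
import tempfile


def available_plugins(plugin_dir):
    # Public .py files only; underscores become hyphens in the user-facing name.
    return sorted(item.name.replace(".py", "").replace("_", "-")
                  for item in os.scandir(plugin_dir)
                  if not item.name.startswith("_") and item.name.endswith(".py"))


def module_name(plugin_name):
    # Reverse step: hyphens in the user-facing name map back to underscores.
    return plugin_name.replace("-", "_")


with tempfile.TemporaryDirectory() as tmp:
    for filename in ("_base.py", "__init__.py", "original.py", "dfl_h128.py", "notes.txt"):
        open(os.path.join(tmp, filename), "w").close()
    names = available_plugins(tmp)

print(names)
print([module_name(name) for name in names])
```

Files starting with an underscore and files that are not Python sources are skipped, so only the two model modules are listed.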

180
plugins/train/_config.py Normal file

@ -0,0 +1,180 @@
#!/usr/bin/env python3
""" Default configurations for models """
import logging
from lib.config import FaceswapConfig
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
MASK_TYPES = ["none", "dfaker", "dfl_full"]
MASK_INFO = "The mask to be used for training. Select none to not use a mask"
COVERAGE_INFO = ("How much of the extracted image to train on. Generally the model is optimized\n"
"to the default value. Sensible values to use are:"
"\n\t62.5%% spans from eyebrow to eyebrow."
"\n\t75.0%% spans from temple to temple."
"\n\t87.5%% spans from ear to ear."
"\n\t100.0%% is a mugshot.")
class Config(FaceswapConfig):
""" Config File for Models """
def set_defaults(self):
""" Set the default values for config """
logger.debug("Setting defaults")
# << GLOBAL OPTIONS >> #
section = "global"
self.add_section(title=section,
info="Options that apply to all models")
self.add_item(
section=section, title="icnr_init", datatype=bool, default=False,
info="Use ICNR Kernel Initializer for upscaling.\nThis can help reduce the "
"'checkerboard effect' when upscaling the image.")
self.add_item(
section=section, title="subpixel_upscaling", datatype=bool, default=False,
info="Use subpixel upscaling rather than pixel shuffler.\n"
"Might increase speed at cost of VRAM")
self.add_item(
section=section, title="reflect_padding", datatype=bool, default=False,
info="Use reflect padding rather than zero padding.")
self.add_item(
section=section, title="dssim_mask_loss", datatype=bool, default=True,
info="If using a mask, Use DSSIM loss for Mask training rather than Mean Absolute "
"Error\nMay increase overall quality.")
self.add_item(
section=section, title="penalized_mask_loss", datatype=bool, default=True,
info="If using a mask, Use Penalized loss for Mask training. Can stack with DSSIM.\n"
"May increase overall quality.")
# << DFAKER OPTIONS >> #
section = "model.dfaker"
self.add_section(title=section,
info="Dfaker Model (Adapted from https://github.com/dfaker/df)")
self.add_item(
section=section, title="mask_type", datatype=str, default="dfaker",
choices=MASK_TYPES, info=MASK_INFO)
self.add_item(
section=section, title="coverage", datatype=float, default=100.0, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
# << DFL MODEL OPTIONS >> #
section = "model.dfl_h128"
self.add_section(title=section,
info="DFL H128 Model (Adapted from "
"https://github.com/iperov/DeepFaceLab)")
self.add_item(
section=section, title="lowmem", datatype=bool, default=False,
info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
"with a changed lowmem mode are not compatible with each other.")
self.add_item(
section=section, title="mask_type", datatype=str, default="dfl_full",
choices=MASK_TYPES, info=MASK_INFO)
self.add_item(
section=section, title="coverage", datatype=float, default=62.5, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
# << IAE MODEL OPTIONS >> #
section = "model.iae"
self.add_section(title=section,
info="Intermediate Auto Encoder. Based on Original Model, uses "
"intermediate layers to try to better get details")
self.add_item(
section=section, title="dssim_loss", datatype=bool, default=False,
info="Use DSSIM for Loss rather than Mean Absolute Error\n"
"May increase overall quality.")
self.add_item(
section=section, title="mask_type", datatype=str, default="none",
choices=MASK_TYPES, info=MASK_INFO)
self.add_item(
section=section, title="coverage", datatype=float, default=62.5, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
# << ORIGINAL MODEL OPTIONS >> #
section = "model.original"
self.add_section(title=section,
info="Original Faceswap Model")
self.add_item(
section=section, title="lowmem", datatype=bool, default=False,
info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
"with a changed lowmem mode are not compatible with each other.")
self.add_item(
section=section, title="dssim_loss", datatype=bool, default=False,
info="Use DSSIM for Loss rather than Mean Absolute Error\n"
"May increase overall quality.")
self.add_item(
section=section, title="mask_type", datatype=str, default="none",
choices=MASK_TYPES, info=MASK_INFO)
self.add_item(
section=section, title="coverage", datatype=float, default=62.5, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
# << UNBALANCED MODEL OPTIONS >> #
section = "model.unbalanced"
self.add_section(title=section,
info="An unbalanced model with adjustable input size options.\n"
"This is an unbalanced model so b>a swaps may not work well")
self.add_item(
section=section, title="lowmem", datatype=bool, default=False,
info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
"with a changed lowmem mode are not compatible with each other. NB: lowmem will "
"override custom nodes and complexity settings.")
self.add_item(
section=section, title="dssim_loss", datatype=bool, default=False,
info="Use DSSIM for Loss rather than Mean Absolute Error\n"
"May increase overall quality.")
self.add_item(
section=section, title="mask_type", datatype=str, default="none",
choices=MASK_TYPES, info=MASK_INFO)
self.add_item(
section=section, title="nodes", datatype=int, default=1024, rounding=64,
min_max=(512, 4096),
info="Number of nodes for decoder. Don't change this unless you "
"know what you are doing!")
self.add_item(
section=section, title="complexity_encoder", datatype=int, default=128,
rounding=16, min_max=(64, 1024),
info="Encoder Convolution Layer Complexity. Sensible range: "
"128 to 160")
self.add_item(
section=section, title="complexity_decoder_a", datatype=int, default=384,
rounding=16, min_max=(64, 1024),
info="Decoder A Complexity.")
self.add_item(
section=section, title="complexity_decoder_b", datatype=int, default=512,
rounding=16, min_max=(64, 1024),
info="Decoder B Complexity.")
self.add_item(
section=section, title="input_size", datatype=int, default=128,
rounding=64, min_max=(64, 512),
info="Resolution (in pixels) of the image to train on.\n"
"BE AWARE Larger resolution will dramatically increase "
"VRAM requirements.\n"
"Make sure your resolution is divisible by 64 (e.g. 64, 128, 256 etc.).\n"
"NB: Your faceset must be at least 1.6x larger than your required input size.\n"
" (e.g. 160 is the maximum input size for a 256x256 faceset)")
self.add_item(
section=section, title="coverage", datatype=float, default=62.5, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
# << VILLAIN MODEL OPTIONS >> #
section = "model.villain"
self.add_section(title=section,
info="A Higher resolution version of the Original "
"Model by VillainGuy.\n"
"Extremely VRAM heavy. Full model requires 9GB+ for batchsize 16")
self.add_item(
section=section, title="lowmem", datatype=bool, default=False,
info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
"with a changed lowmem mode are not compatible with each other.")
self.add_item(
section=section, title="dssim_loss", datatype=bool, default=False,
info="Use DSSIM for Loss rather than Mean Absolute Error\n"
"May increase overall quality.")
self.add_item(
section=section, title="mask_type", datatype=str, default="none",
choices=["none", "dfaker", "dfl_full"],
info="The mask to be used for training. Select none to not use a mask")
self.add_item(
section=section, title="coverage", datatype=float, default=62.5, rounding=1,
min_max=(62.5, 100.0), info=COVERAGE_INFO)
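The input_size help text above states two constraints: the value must be divisible by 64, and the faceset must be at least 1.6x larger than the input size. A minimal sketch of that arithmetic follows. `max_input_size` is a hypothetical helper, not part of this commit; the 1.6 ratio and the step of 64 are taken from the help text.

```python
def max_input_size(faceset_size, step=64):
    """Return (raw_limit, usable_limit) for a square faceset (hypothetical helper).

    raw_limit is floor(faceset_size / 1.6), computed with integers to avoid
    floating point error. usable_limit is raw_limit rounded down to a
    multiple of `step`.
    """
    raw_limit = faceset_size * 10 // 16
    usable_limit = (raw_limit // step) * step
    return raw_limit, usable_limit


if __name__ == "__main__":
    for size in (256, 512):
        print(size, max_input_size(size))
```

Under these assumptions a 256px faceset gives a raw limit of 160 (the figure quoted in the help text), but only 128 also satisfies the divisible-by-64 rule.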


@ -0,0 +1,586 @@
#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class
When inheriting model_data should be a list of NNMeta objects.
See the class for details.
"""
import logging
import os
import sys
import time
from json import JSONDecodeError
from keras import losses
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils import get_custom_objects, multi_gpu_model
from lib import Serializer
from lib.model.losses import DSSIMObjective, PenalizedLoss
from lib.model.nn_blocks import NNBlocks
from lib.multithreading import MultiThread
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG = None
class ModelBase():
""" Base class that all models should inherit from """
def __init__(self,
model_dir,
gpus,
no_logs=False,
warp_to_landmarks=False,
no_flip=False,
training_image_size=256,
alignments_paths=None,
preview_scale=100,
input_shape=None,
encoder_dim=None,
trainer="original",
predict=False):
logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
"training_image_size: %s, alignments_paths: %s, preview_scale: %s, "
"input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus,
training_image_size, alignments_paths, preview_scale, input_shape,
encoder_dim)
self.predict = predict
self.model_dir = model_dir
self.gpus = gpus
self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"],
use_icnr_init=self.config["icnr_init"],
use_reflect_padding=self.config["reflect_padding"])
self.input_shape = input_shape
self.output_shape = None # set after model is compiled
self.encoder_dim = encoder_dim
self.trainer = trainer
self.state = State(self.model_dir, self.name, no_logs, training_image_size)
self.load_state_info()
self.networks = dict() # Networks for the model
self.predictors = dict() # Predictors for model
self.history = dict() # Loss history per save iteration
# Training information specific to the model should be placed in this
# dict for reference by the trainer.
self.training_opts = {"alignments": alignments_paths,
"preview_scaling": preview_scale / 100,
"warp_to_landmarks": warp_to_landmarks,
"no_flip": no_flip}
self.build()
self.set_training_data()
logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
@property
def config(self):
""" Return config dict for current plugin """
global _CONFIG # pylint: disable=global-statement
if not _CONFIG:
model_name = ".".join(self.__module__.split(".")[-2:])
logger.debug("Loading config for: %s", model_name)
_CONFIG = Config(model_name).config_dict
return _CONFIG
@property
def name(self):
""" Set the model name based on the subclass """
basename = os.path.basename(sys.modules[self.__module__].__file__)
retval = os.path.splitext(basename)[0].lower()
logger.debug("model name: '%s'", retval)
return retval
def set_training_data(self):
""" Override to set model specific training data.
super() this method for defaults otherwise be sure to add """
logger.debug("Setting training data")
self.training_opts["training_size"] = self.state.training_size
self.training_opts["no_logs"] = self.state.current_session["no_logs"]
self.training_opts["mask_type"] = self.config.get("mask_type", None)
self.training_opts["coverage_ratio"] = self.config.get("coverage", 62.5) / 100
self.training_opts["preview_images"] = 14
logger.debug("Set training data: %s", self.training_opts)
def build(self):
""" Build the model. Override for custom build methods """
self.add_networks()
self.load_models(swapped=False)
self.build_autoencoders()
self.log_summary()
self.compile_predictors()
def build_autoencoders(self):
""" Override for Model Specific autoencoder builds
NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names
are expected:
face (the input for image)
mask (the input for mask if it is used)
"""
raise NotImplementedError
def add_networks(self):
""" Override to add neural networks """
raise NotImplementedError
def load_state_info(self):
""" Load the input shape from state file if it exists """
logger.debug("Loading Input Shape from State file")
if not self.state.inputs:
logger.debug("No input shapes saved. Using model config")
return
if not self.state.face_shapes:
logger.warning("Input shapes stored in State file, but no matches for 'face'. "
"Using model config")
return
input_shape = self.state.face_shapes[0]
logger.debug("Setting input shape from state file: %s", input_shape)
self.input_shape = input_shape
def add_network(self, network_type, side, network):
""" Add a NNMeta object """
logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network)
filename = "{}_{}".format(self.name, network_type.lower())
name = network_type.lower()
if side:
side = side.lower()
filename += "_{}".format(side.upper())
name += "_{}".format(side)
filename += ".h5"
logger.debug("name: '%s', filename: '%s'", name, filename)
self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network)
def add_predictor(self, side, model):
""" Add a predictor to the predictors dictionary """
logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
if self.gpus > 1:
logger.debug("Converting to multi-gpu: side %s", side)
model = multi_gpu_model(model, self.gpus)
self.predictors[side] = model
if not self.state.inputs:
self.store_input_shapes(model)
if not self.output_shape:
self.set_output_shape(model)
def store_input_shapes(self, model):
""" Store the input shapes to state """
logger.debug("Adding input shapes to state for model")
inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs}
if not any(inp for inp in inputs.keys() if inp.startswith("face")):
raise ValueError("No input named 'face' was found. Check your input naming. "
"Current input names: {}".format(inputs))
self.state.inputs = inputs
logger.debug("Added input shapes: %s", self.state.inputs)
def set_output_shape(self, model):
""" Set the output shape for use in training and convert """
logger.debug("Setting output shape")
out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs]
if not out:
raise ValueError("No outputs found! Check your model.")
self.output_shape = tuple(out[0])
logger.debug("Added output shape: %s", self.output_shape)
def compile_predictors(self):
""" Compile the predictors """
logger.debug("Compiling Predictors")
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0)
for side, model in self.predictors.items():
loss_names = ["loss"]
loss_funcs = [self.loss_function(side)]
mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
if mask:
loss_names.insert(0, "mask_loss")
loss_funcs.insert(0, self.mask_loss_function(mask[0], side))
model.compile(optimizer=optimizer, loss=loss_funcs)
if len(loss_names) > 1:
loss_names.insert(0, "total_loss")
self.state.add_session_loss_names(side, loss_names)
self.history[side] = list()
logger.debug("Compiled Predictors. Losses: %s", loss_names)
def loss_function(self, side):
""" Set the loss function """
if self.config.get("dssim_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using DSSIM Loss")
loss_func = DSSIMObjective()
else:
if side == "a" and not self.predict:
logger.verbose("Using Mean Absolute Error Loss")
loss_func = losses.mean_absolute_error
logger.debug(loss_func)
return loss_func
def mask_loss_function(self, mask, side):
""" Set the loss function for masks
Side is input so we only log once """
if self.config.get("dssim_mask_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using DSSIM Loss for mask")
mask_loss_func = DSSIMObjective()
else:
if side == "a" and not self.predict:
logger.verbose("Using Mean Absolute Error Loss for mask")
mask_loss_func = losses.mean_absolute_error
if self.config.get("penalized_mask_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using Penalized Loss for mask")
mask_loss_func = PenalizedLoss(mask, mask_loss_func)
logger.debug(mask_loss_func)
return mask_loss_func
def converter(self, swap):
""" Converter for autoencoder models """
logger.debug("Getting Converter: (swap: %s)", swap)
if swap:
retval = self.predictors["a"].predict
else:
retval = self.predictors["b"].predict
logger.debug("Got Converter: %s", retval)
return retval
@property
def iterations(self):
""" Get current training iteration number """
return self.state.iterations
def map_models(self, swapped):
""" Map the models for A/B side for swapping """
logger.debug("Map models: (swapped: %s)", swapped)
models_map = {"a": dict(), "b": dict()}
sides = ("a", "b") if not swapped else ("b", "a")
for network in self.networks.values():
if network.side == sides[0]:
models_map["a"][network.type] = network.filename
if network.side == sides[1]:
models_map["b"][network.type] = network.filename
logger.debug("Mapped models: (models_map: %s)", models_map)
return models_map
def log_summary(self):
""" Verbose log the model summaries """
if self.predict:
return
for side in sorted(list(self.predictors.keys())):
logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x))
for name, nnmeta in self.networks.items():
if nnmeta.side is not None and nnmeta.side != side:
continue
logger.verbose("%s:", name.title())
nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x))
def load_models(self, swapped):
""" Load models from file """
logger.debug("Load model: (swapped: %s)", swapped)
model_mapping = self.map_models(swapped)
for network in self.networks.values():
if not network.side:
is_loaded = network.load(predict=self.predict)
else:
is_loaded = network.load(fullpath=model_mapping[network.side][network.type],
predict=self.predict)
if not is_loaded:
break
if is_loaded:
logger.info("Loaded model from disk: '%s'", self.model_dir)
return is_loaded
def save_models(self):
""" Backup and save the models """
logger.debug("Backing up and saving models")
should_backup = self.get_save_averages()
save_threads = list()
for network in self.networks.values():
name = "save_{}".format(network.name)
save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup))
save_threads.append(MultiThread(self.state.save,
name="save_state", should_backup=should_backup))
for thread in save_threads:
thread.start()
for thread in save_threads:
if thread.has_error:
logger.error(thread.errors[0])
thread.join()
# Put in a line break to avoid jumbled console
print("\n")
logger.info("saved models")
def get_save_averages(self):
""" Return the loss averages since last save and reset historical losses
This protects against model corruption by only backing up the model
if all of the loss values have fallen.
TODO This is not a perfect system. If the model corrupts on save_iteration - 1
then model may still backup
"""
logger.debug("Getting Average loss since last save")
avgs = dict()
backup = True
for side, loss in self.history.items():
if not loss:
backup = False
break
avgs[side] = sum(loss) / len(loss)
self.history[side] = list() # Reset historical loss
if not self.state.lowest_avg_loss.get(side, None):
logger.debug("Setting initial save iteration loss average for '%s': %s",
side, avgs[side])
self.state.lowest_avg_loss[side] = avgs[side]
continue
if backup:
# Only run this if backup is true. All losses must have dropped for a valid backup
backup = self.check_loss_drop(side, avgs[side])
logger.debug("Lowest historical save iteration loss average: %s",
self.state.lowest_avg_loss)
logger.debug("Average loss since last save: %s", avgs)
if backup: # Update lowest loss values to the state
for side, avg_loss in avgs.items():
logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
self.state.lowest_avg_loss[side] = avg_loss
logger.debug("Backing up: %s", backup)
return backup
def check_loss_drop(self, side, avg):
""" Check whether total loss has dropped since lowest loss """
if avg < self.state.lowest_avg_loss[side]:
logger.debug("Loss for '%s' has dropped", side)
return True
logger.debug("Loss for '%s' has not dropped", side)
return False
class NNMeta():
""" Class to hold a neural network and its metadata
filename: The full path and filename of the model file for this network.
type: The type of network. For networks that can be swapped
The type should be identical for the corresponding
A and B networks, and should be unique for every A/B pair.
Otherwise the type should be completely unique.
side: A, B or None. Used to identify which networks can
be swapped.
network: Define network to this.
"""
def __init__(self, filename, network_type, side, network):
logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
"network: %s)", self.__class__.__name__, filename, network_type,
side, network)
self.filename = filename
self.type = network_type.lower()
self.side = side
self.name = self.set_name()
self.network = network
self.network.name = self.name
logger.debug("Initialized %s", self.__class__.__name__)
def set_name(self):
""" Set the network name """
name = self.type
if self.side:
name += "_{}".format(self.side)
return name
def load(self, fullpath=None, predict=False):
""" Load model """
fullpath = fullpath if fullpath else self.filename
logger.debug("Loading model: '%s'", fullpath)
try:
network = load_model(fullpath, custom_objects=get_custom_objects())
except ValueError as err:
if str(err).lower().startswith("cannot create group in read only mode"):
self.convert_legacy_weights()
return True
if predict:
raise ValueError("Unable to load training data. Error: {}".format(str(err)))
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
except OSError as err: # pylint: disable=broad-except
if predict:
raise ValueError("Unable to load training data. Error: {}".format(str(err)))
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
self.network = network # Update network with saved model
self.network.name = self.type
return True
def save(self, fullpath=None, should_backup=False):
""" Save model """
fullpath = fullpath if fullpath else self.filename
if should_backup:
self.backup(fullpath=fullpath)
logger.debug("Saving model: '%s'", fullpath)
self.network.save(fullpath)
def backup(self, fullpath=None):
""" Backup Model """
origfile = fullpath if fullpath else self.filename
backupfile = origfile + ".bk"
logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
if os.path.exists(backupfile):
os.remove(backupfile)
if os.path.exists(origfile):
os.rename(origfile, backupfile)
def convert_legacy_weights(self):
""" Convert legacy weights files to hold the model topology """
logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
self.network.load_weights(self.filename)
self.save(should_backup=False)
self.network.name = self.type
class State():
""" Class to hold the model's current state and autoencoder structure """
def __init__(self, model_dir, model_name, no_logs, training_image_size):
logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
"training_image_size: '%s')", self.__class__.__name__, model_dir,
model_name, no_logs, training_image_size)
self.serializer = Serializer.get_serializer("json")
filename = "{}_state.{}".format(model_name, self.serializer.ext)
self.filename = str(model_dir / filename)
self.iterations = 0
self.session_iterations = 0
self.training_size = training_image_size
self.sessions = dict()
self.lowest_avg_loss = dict()
self.inputs = dict()
self.config = dict()
self.load()
self.session_id = self.new_session_id()
self.create_new_session(no_logs)
logger.debug("Initialized %s:", self.__class__.__name__)
@property
def face_shapes(self):
""" Return a list of stored face shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]
@property
def mask_shapes(self):
""" Return a list of stored mask shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]
@property
def loss_names(self):
""" Return the loss names for this session """
return self.sessions[self.session_id]["loss_names"]
@property
def current_session(self):
""" Return the current session dict """
return self.sessions[self.session_id]
def new_session_id(self):
""" Return new session_id """
if not self.sessions:
session_id = 1
else:
session_id = max(int(key) for key in self.sessions.keys()) + 1
logger.debug(session_id)
return session_id
def create_new_session(self, no_logs):
""" Create a new session """
logger.debug("Creating new session. id: %s", self.session_id)
self.sessions[self.session_id] = {"timestamp": time.time(),
"no_logs": no_logs,
"loss_names": dict(),
"batchsize": 0,
"iterations": 0}
def add_session_loss_names(self, side, loss_names):
""" Add the session loss names to the sessions dictionary """
logger.debug("Adding session loss_names. (side: '%s', loss_names: %s)", side, loss_names)
self.sessions[self.session_id]["loss_names"][side] = loss_names
def add_session_batchsize(self, batchsize):
""" Add the session batchsize to the sessions dictionary """
logger.debug("Adding session batchsize: %s", batchsize)
self.sessions[self.session_id]["batchsize"] = batchsize
def increment_iterations(self):
""" Increment total and session iterations """
self.iterations += 1
self.sessions[self.session_id]["iterations"] += 1
def load(self):
""" Load state file """
logger.debug("Loading State")
try:
with open(self.filename, "rb") as inp:
state = self.serializer.unmarshal(inp.read().decode("utf-8"))
self.sessions = state.get("sessions", dict())
self.lowest_avg_loss = state.get("lowest_avg_loss", dict())
self.iterations = state.get("iterations", 0)
self.training_size = state.get("training_size", 256)
self.inputs = state.get("inputs", dict())
self.config = state.get("config", dict())
logger.debug("Loaded state: %s", state)
self.replace_config()
except IOError as err:
logger.warning("No existing state file found. Generating.")
logger.debug("IOError: %s", str(err))
except JSONDecodeError as err:
logger.debug("JSONDecodeError: %s:", str(err))
def save(self, should_backup=False):
""" Save the model state to the state file """
logger.debug("Saving State")
if should_backup:
self.backup()
try:
with open(self.filename, "wb") as out:
state = {"sessions": self.sessions,
"lowest_avg_loss": self.lowest_avg_loss,
"iterations": self.iterations,
"inputs": self.inputs,
"training_size": self.training_size,
"config": _CONFIG}
state_json = self.serializer.marshal(state)
out.write(state_json.encode("utf-8"))
except IOError as err:
logger.error("Unable to save model state: %s", str(err.strerror))
logger.debug("Saved State")
def backup(self):
""" Backup state file """
origfile = self.filename
backupfile = origfile + ".bk"
logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
if os.path.exists(backupfile):
os.remove(backupfile)
if os.path.exists(origfile):
os.rename(origfile, backupfile)
def replace_config(self):
""" Replace the loaded config with the one contained within the state file """
global _CONFIG # pylint: disable=global-statement
# Add any new items to state config for legacy purposes
for key, val in _CONFIG.items():
if key not in self.config.keys():
logger.info("Adding new config item to state file: '%s': '%s'", key, val)
self.config[key] = val
logger.debug("Replacing config. Old config: %s", _CONFIG)
_CONFIG = self.config
logger.debug("Replaced config. New config: %s", _CONFIG)
logger.info("Using configuration saved in state file")
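The backup gate in `get_save_averages` above can be read in isolation from Keras. The sketch below restates it as a standalone function over plain dicts; `should_backup` is a hypothetical name, and logging and the `State` object are left out. It only illustrates the control flow and is not a replacement for the method.

```python
def should_backup(history, lowest_avg_loss):
    """Return True only if every side's average loss has dropped (hypothetical helper).

    history: dict of side -> list of losses since the last save (reset in place).
    lowest_avg_loss: dict of side -> lowest average seen so far (updated in place).
    """
    avgs = {}
    backup = True
    for side, losses in history.items():
        if not losses:
            # No iterations recorded for this side: never back up
            backup = False
            break
        avgs[side] = sum(losses) / len(losses)
        history[side] = []  # reset historical loss
        if not lowest_avg_loss.get(side):
            # First save for this side: record the baseline and move on
            lowest_avg_loss[side] = avgs[side]
            continue
        if backup:
            backup = avgs[side] < lowest_avg_loss[side]
    if backup:
        lowest_avg_loss.update(avgs)
    return backup


if __name__ == "__main__":
    lowest = {}
    print(should_backup({"a": [0.5, 0.25], "b": [0.5, 0.5]}, lowest), lowest)
    print(should_backup({"a": [0.25], "b": [0.75]}, lowest), lowest)
```

As in the original, a single side whose average has not dropped blocks the backup for both sides, and the stored minimums are only updated when a backup is taken.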


@ -0,0 +1,62 @@
#!/usr/bin/env python3
""" DFaker Model
Based on the dfaker model: https://github.com/dfaker """
from keras.initializers import RandomNormal
from keras.layers import Conv2D, Input
from keras.models import Model as KerasModel
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
""" Dfaker Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
kwargs["input_shape"] = (64, 64, 3)
kwargs["encoder_dim"] = 1024
self.kernel_initializer = RandomNormal(0, 0.02)
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def build_autoencoders(self):
""" Initialize Dfaker model """
logger.debug("Initializing model")
inputs = [Input(shape=self.input_shape, name="face")]
if self.config.get("mask_type", None):
mask_shape = (self.input_shape[0] * 2, self.input_shape[1] * 2, 1)
inputs.append(Input(shape=mask_shape, name="mask"))
for side in ("a", "b"):
decoder = self.networks["decoder_{}".format(side)].network
output = decoder(self.networks["encoder"].network(inputs[0]))
autoencoder = KerasModel(inputs, output)
self.add_predictor(side, autoencoder)
logger.debug("Initialized model")
def decoder(self):
""" Decoder Network """
input_ = Input(shape=(8, 8, 512))
var_x = input_
var_x = self.blocks.upscale(var_x, 512, res_block_follows=True)
var_x = self.blocks.res_block(var_x, 512, kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, 256, res_block_follows=True)
var_x = self.blocks.res_block(var_x, 256, kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, 128, res_block_follows=True)
var_x = self.blocks.res_block(var_x, 128, kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, 64)
var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = input_
var_y = self.blocks.upscale(var_y, 512)
var_y = self.blocks.upscale(var_y, 256)
var_y = self.blocks.upscale(var_y, 128)
var_y = self.blocks.upscale(var_y, 64)
var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
outputs.append(var_y)
return KerasModel([input_], outputs=outputs)
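The Dfaker `build_autoencoders` above builds its mask input at twice the face input's height and width, whereas the Original model it inherits from uses the same height and width as the face. A small sketch of the two shape computations, using hypothetical function names that do not appear in the commit:

```python
def mask_shape_same(input_shape):
    """Mask shape with the same height/width as the face input, one channel."""
    return tuple(input_shape[:2]) + (1,)


def mask_shape_doubled(input_shape):
    """Mask shape with doubled height/width, one channel (as in the Dfaker model)."""
    return (input_shape[0] * 2, input_shape[1] * 2, 1)


if __name__ == "__main__":
    face = (64, 64, 3)
    print(mask_shape_same(face))
    print(mask_shape_doubled(face))
```

The doubled shape is consistent with the Dfaker decoder's four `upscale` calls on an 8x8 input, assuming each `upscale` doubles the spatial size (the `NNBlocks` implementation is not shown in this section).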


@ -0,0 +1,53 @@
#!/usr/bin/env python3
""" DeepFaceLab H128 Model
Based on https://github.com/iperov/DeepFaceLab
"""
from keras.layers import Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
""" DeepFaceLab H128 Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
kwargs["input_shape"] = (128, 128, 3)
kwargs["encoder_dim"] = 256 if self.config["lowmem"] else 512
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def encoder(self):
""" DFL H128 Encoder """
input_ = Input(shape=self.input_shape)
var_x = input_
var_x = self.blocks.conv(var_x, 128)
var_x = self.blocks.conv(var_x, 256)
var_x = self.blocks.conv(var_x, 512)
var_x = self.blocks.conv(var_x, 1024)
var_x = Dense(self.encoder_dim)(Flatten()(var_x))
var_x = Dense(8 * 8 * self.encoder_dim)(var_x)
var_x = Reshape((8, 8, self.encoder_dim))(var_x)
var_x = self.blocks.upscale(var_x, self.encoder_dim)
return KerasModel(input_, var_x)
def decoder(self):
""" DFL H128 Decoder """
input_ = Input(shape=(16, 16, self.encoder_dim))
var = input_
var = self.blocks.upscale(var, self.encoder_dim)
var = self.blocks.upscale(var, self.encoder_dim // 2)
var = self.blocks.upscale(var, self.encoder_dim // 4)
# Face
var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var)
outputs = [var_x]
# Mask
if self.config.get("mask_type", None):
var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
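The spatial sizes in the H128 encoder and decoder above can be traced by hand. The sketch below does so under the assumption that each `blocks.conv` halves and each `blocks.upscale` doubles the height and width; `NNBlocks` is not shown in this section, so that behaviour is assumed rather than confirmed. `trace_sizes` is a hypothetical helper.

```python
def trace_sizes(start, steps):
    """Apply 'conv' (halve) and 'upscale' (double) steps to a spatial size."""
    sizes = [start]
    for step in steps:
        if step == "conv":
            sizes.append(sizes[-1] // 2)
        elif step == "upscale":
            sizes.append(sizes[-1] * 2)
        else:
            raise ValueError("unknown step: {}".format(step))
    return sizes


if __name__ == "__main__":
    print(trace_sizes(128, ["conv"] * 4))     # encoder convolutions
    print(trace_sizes(8, ["upscale"]))        # after Reshape((8, 8, encoder_dim))
    print(trace_sizes(16, ["upscale"] * 3))   # decoder upscales
```

Under that assumption the encoder ends at 16x16, which matches the decoder's declared input, and the decoder returns to the 128x128 input resolution.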


@ -0,0 +1,84 @@
#!/usr/bin/env python3
""" Improved autoencoder for faceswap """
from keras.layers import Concatenate, Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from ._base import ModelBase, logger
class Model(ModelBase):
""" Improved Autoencoder Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
kwargs["input_shape"] = (64, 64, 3)
kwargs["encoder_dim"] = 1024
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def add_networks(self):
""" Add the IAE model weights """
logger.debug("Adding networks")
self.add_network("encoder", None, self.encoder())
self.add_network("decoder", None, self.decoder())
self.add_network("intermediate", "a", self.intermediate())
self.add_network("intermediate", "b", self.intermediate())
self.add_network("inter", None, self.intermediate())
logger.debug("Added networks")
def build_autoencoders(self):
""" Initialize IAE model """
logger.debug("Initializing model")
inputs = [Input(shape=self.input_shape, name="face")]
if self.config.get("mask_type", "none") != "none":
mask_shape = (self.input_shape[:2] + (1, ))
inputs.append(Input(shape=mask_shape, name="mask"))
decoder = self.networks["decoder"].network
encoder = self.networks["encoder"].network
inter_both = self.networks["inter"].network
for side in ("a", "b"):
inter_side = self.networks["intermediate_{}".format(side)].network
output = decoder(Concatenate()([inter_side(encoder(inputs[0])),
inter_both(encoder(inputs[0]))]))
autoencoder = KerasModel(inputs, output)
self.add_predictor(side, autoencoder)
logger.debug("Initialized model")
def encoder(self):
""" Encoder Network """
input_ = Input(shape=self.input_shape)
var_x = input_
var_x = self.blocks.conv(var_x, 128)
var_x = self.blocks.conv(var_x, 256)
var_x = self.blocks.conv(var_x, 512)
var_x = self.blocks.conv(var_x, 1024)
var_x = Flatten()(var_x)
return KerasModel(input_, var_x)
def intermediate(self):
""" Intermediate Network """
input_ = Input(shape=(None, 4 * 4 * 1024))
var_x = input_
var_x = Dense(self.encoder_dim)(var_x)
var_x = Dense(4 * 4 * int(self.encoder_dim/2))(var_x)
var_x = Reshape((4, 4, int(self.encoder_dim/2)))(var_x)
return KerasModel(input_, var_x)
def decoder(self):
""" Decoder Network """
input_ = Input(shape=(4, 4, self.encoder_dim))
var_x = input_
var_x = self.blocks.upscale(var_x, 512)
var_x = self.blocks.upscale(var_x, 256)
var_x = self.blocks.upscale(var_x, 128)
var_x = self.blocks.upscale(var_x, 64)
var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var_x)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
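In the IAE model above, each intermediate network emits `encoder_dim / 2` channels and `build_autoencoders` concatenates the side-specific and shared intermediates before the decoder. The sketch below checks that the channel counts line up with the decoder's `(4, 4, encoder_dim)` input. The helper names are hypothetical, and concatenation on the last axis is assumed because that is the Keras `Concatenate` default.

```python
def intermediate_output_shape(encoder_dim):
    """Shape produced by one intermediate network: (4, 4, encoder_dim / 2)."""
    return (4, 4, encoder_dim // 2)


def concat_last_axis(shape_a, shape_b):
    """Shape after concatenating two tensors on their last axis."""
    if shape_a[:-1] != shape_b[:-1]:
        raise ValueError("leading dimensions must match")
    return shape_a[:-1] + (shape_a[-1] + shape_b[-1],)


if __name__ == "__main__":
    encoder_dim = 1024
    side = intermediate_output_shape(encoder_dim)
    shared = intermediate_output_shape(encoder_dim)
    print(concat_last_axis(side, shared))
```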


@ -0,0 +1,83 @@
#!/usr/bin/env python3
""" Original Model
Based on the original https://www.reddit.com/r/deepfakes/
code sample + contribs """
from keras.layers import Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from ._base import ModelBase, logger
class Model(ModelBase):
""" Original Faceswap Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
if "input_shape" not in kwargs:
kwargs["input_shape"] = (64, 64, 3)
if "encoder_dim" not in kwargs:
kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def add_networks(self):
""" Add the original model weights """
logger.debug("Adding networks")
self.add_network("decoder", "a", self.decoder())
self.add_network("decoder", "b", self.decoder())
self.add_network("encoder", None, self.encoder())
logger.debug("Added networks")
def build_autoencoders(self):
""" Initialize original model """
logger.debug("Initializing model")
inputs = [Input(shape=self.input_shape, name="face")]
if self.config.get("mask_type", None):
mask_shape = (self.input_shape[:2] + (1, ))
inputs.append(Input(shape=mask_shape, name="mask"))
for side in ("a", "b"):
logger.debug("Adding Autoencoder. Side: %s", side)
decoder = self.networks["decoder_{}".format(side)].network
output = decoder(self.networks["encoder"].network(inputs[0]))
autoencoder = KerasModel(inputs, output)
self.add_predictor(side, autoencoder)
logger.debug("Initialized model")
def encoder(self):
""" Encoder Network """
input_ = Input(shape=self.input_shape)
var_x = input_
var_x = self.blocks.conv(var_x, 128)
var_x = self.blocks.conv(var_x, 256)
var_x = self.blocks.conv(var_x, 512)
if not self.config.get("lowmem", False):
var_x = self.blocks.conv(var_x, 1024)
var_x = Dense(self.encoder_dim)(Flatten()(var_x))
var_x = Dense(4 * 4 * 1024)(var_x)
var_x = Reshape((4, 4, 1024))(var_x)
var_x = self.blocks.upscale(var_x, 512)
return KerasModel(input_, var_x)
def decoder(self):
""" Decoder Network """
input_ = Input(shape=(8, 8, 512))
var_x = input_
var_x = self.blocks.upscale(var_x, 256)
var_x = self.blocks.upscale(var_x, 128)
var_x = self.blocks.upscale(var_x, 64)
var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = input_
var_y = self.blocks.upscale(var_y, 256)
var_y = self.blocks.upscale(var_y, 128)
var_y = self.blocks.upscale(var_y, 64)
var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
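The Original model's `add_networks` registers two decoders and one shared encoder through `ModelBase.add_network`, which derives both the dictionary key and the `.h5` filename. The sketch below reproduces only that naming logic as a standalone function. `network_names` is a hypothetical name, and the model name "original" assumes the module file is `original.py`, as the `name` property implies.

```python
def network_names(model_name, network_type, side=None):
    """Return (dict_key, filename) following the rules in ModelBase.add_network."""
    name = network_type.lower()
    filename = "{}_{}".format(model_name, name)
    if side:
        filename += "_{}".format(side.upper())   # side is upper case in the filename
        name += "_{}".format(side.lower())       # and lower case in the dict key
    return name, filename + ".h5"


if __name__ == "__main__":
    for net_type, side in (("decoder", "a"), ("decoder", "b"), ("encoder", None)):
        print(network_names("original", net_type, side))
```

These keys are what `build_autoencoders` looks up with `"decoder_{}".format(side)` and `"encoder"`.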


@ -0,0 +1,130 @@
#!/usr/bin/env python3
""" Unbalanced Model
Based on the original https://www.reddit.com/r/deepfakes/
code sample + contribs """
from keras.initializers import RandomNormal
from keras.layers import Conv2D, Dense, Flatten, Input, Reshape, SpatialDropout2D
from keras.models import Model as KerasModel
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
""" Unbalanced Faceswap Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
self.lowmem = self.config.get("lowmem", False)
kwargs["input_shape"] = (self.config["input_size"], self.config["input_size"], 3)
kwargs["encoder_dim"] = 512 if self.lowmem else self.config["nodes"]
self.kernel_initializer = RandomNormal(0, 0.02)
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def add_networks(self):
""" Add the original model weights """
logger.debug("Adding networks")
self.add_network("decoder", "a", self.decoder_a())
self.add_network("decoder", "b", self.decoder_b())
self.add_network("encoder", None, self.encoder())
logger.debug("Added networks")
def encoder(self):
""" Unbalanced Encoder """
kwargs = dict(kernel_initializer=self.kernel_initializer)
encoder_complexity = 128 if self.lowmem else self.config["complexity_encoder"]
dense_dim = 384 if self.lowmem else 512
dense_shape = self.input_shape[0] // 16
input_ = Input(shape=self.input_shape)
var_x = input_
var_x = self.blocks.conv(var_x, encoder_complexity, use_instance_norm=True, **kwargs)
var_x = self.blocks.conv(var_x, encoder_complexity * 2, use_instance_norm=True, **kwargs)
var_x = self.blocks.conv(var_x, encoder_complexity * 4, **kwargs)
var_x = self.blocks.conv(var_x, encoder_complexity * 6, **kwargs)
var_x = self.blocks.conv(var_x, encoder_complexity * 8, **kwargs)
var_x = Dense(self.encoder_dim,
kernel_initializer=self.kernel_initializer)(Flatten()(var_x))
var_x = Dense(dense_shape * dense_shape * dense_dim,
kernel_initializer=self.kernel_initializer)(var_x)
var_x = Reshape((dense_shape, dense_shape, dense_dim))(var_x)
return KerasModel(input_, var_x)
def decoder_a(self):
""" Decoder for side A """
kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
decoder_complexity = 320 if self.lowmem else self.config["complexity_decoder_a"]
dense_dim = 384 if self.lowmem else 512
decoder_shape = self.input_shape[0] // 16
input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))
var_x = input_
var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
var_x = SpatialDropout2D(0.25)(var_x)
var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
if self.lowmem:
var_x = SpatialDropout2D(0.15)(var_x)
else:
var_x = SpatialDropout2D(0.25)(var_x)
var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = input_
var_y = self.blocks.upscale(var_y, decoder_complexity)
var_y = self.blocks.upscale(var_y, decoder_complexity)
var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
def decoder_b(self):
""" Decoder for side B """
kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
decoder_complexity = 384 if self.lowmem else self.config["complexity_decoder_b"]
dense_dim = 384 if self.lowmem else 512
decoder_shape = self.input_shape[0] // 16
input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))
var_x = input_
if self.lowmem:
var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
var_x = self.blocks.upscale(var_x, decoder_complexity // 8, **kwargs)
else:
var_x = self.blocks.upscale(var_x, decoder_complexity,
res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, decoder_complexity,
kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, decoder_complexity,
res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, decoder_complexity,
kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, decoder_complexity // 2,
res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, decoder_complexity // 2,
kernel_initializer=self.kernel_initializer)
var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = input_
var_y = self.blocks.upscale(var_y, decoder_complexity)
if not self.lowmem:
var_y = self.blocks.upscale(var_y, decoder_complexity)
var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
if self.lowmem:
var_y = self.blocks.upscale(var_y, decoder_complexity // 8)
var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
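
The size settings in the Unbalanced model above all depend on the `lowmem` flag. The following is a minimal sketch of that resolution logic on its own, assuming a plain dict stands in for the plugin config. The helper name `resolve_unbalanced_sizes` and the example values are hypothetical; only the config keys and the constants are taken from the code above.

```python
def resolve_unbalanced_sizes(config):
    """Mirror how the Unbalanced model derives its sizes from config (illustrative only)."""
    lowmem = config.get("lowmem", False)
    input_size = config["input_size"]
    # In lowmem mode the user-supplied complexity values are ignored
    encoder_complexity = 128 if lowmem else config["complexity_encoder"]
    dense_dim = 384 if lowmem else 512
    dense_shape = input_size // 16
    return {
        "input_shape": (input_size, input_size, 3),
        "encoder_dim": 512 if lowmem else config["nodes"],
        # Filter counts of the five encoder conv blocks
        "encoder_filters": [encoder_complexity * mult for mult in (1, 2, 4, 6, 8)],
        # Shape of the tensor the encoder hands to either decoder
        "latent_shape": (dense_shape, dense_shape, dense_dim),
    }


# Hypothetical config values, chosen only for this example
example = {"lowmem": False, "input_size": 128, "nodes": 1024, "complexity_encoder": 128}
sizes = resolve_unbalanced_sizes(example)
print(sizes["encoder_filters"])
print(sizes["latent_shape"])
```

Because both decoders take the encoder output as their input, their `Input` shape has to match `latent_shape` for whichever mode is active.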


@ -0,0 +1,83 @@
#!/usr/bin/env python3
""" Original - VillainGuy model
Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
Adapted from a model by VillainGuy (https://github.com/VillainGuy) """
from keras.initializers import RandomNormal
from keras.layers import add, Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel
from lib.model.layers import PixelShuffler
from .original import logger, Model as OriginalModel
class Model(OriginalModel):
""" Villain Faceswap Model """
def __init__(self, *args, **kwargs):
logger.debug("Initializing %s: (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
kwargs["input_shape"] = (128, 128, 3)
kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024
self.kernel_initializer = RandomNormal(0, 0.02)
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
def encoder(self):
""" Encoder Network """
kwargs = dict(kernel_initializer=self.kernel_initializer)
input_ = Input(shape=self.input_shape)
in_conv_filters = self.input_shape[0]
if self.input_shape[0] > 128:
in_conv_filters = 128 + (self.input_shape[0] - 128) // 4
dense_shape = self.input_shape[0] // 16
var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs)
tmp_x = var_x
res_cycles = 8 if self.config.get("lowmem", False) else 16
for _ in range(res_cycles):
nn_x = self.blocks.res_block(var_x, 128, **kwargs)
var_x = nn_x
# consider adding scale before this layer to scale the residual chain
var_x = add([var_x, tmp_x])
var_x = self.blocks.conv(var_x, 128, **kwargs)
var_x = PixelShuffler()(var_x)
var_x = self.blocks.conv(var_x, 128, **kwargs)
var_x = PixelShuffler()(var_x)
var_x = self.blocks.conv(var_x, 128, **kwargs)
var_x = self.blocks.conv_sep(var_x, 256, **kwargs)
var_x = self.blocks.conv(var_x, 512, **kwargs)
if not self.config.get("lowmem", False):
var_x = self.blocks.conv_sep(var_x, 1024, **kwargs)
var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x))
var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x)
var_x = Reshape((dense_shape, dense_shape, 1024))(var_x)
var_x = self.blocks.upscale(var_x, 512, **kwargs)
return KerasModel(input_, var_x)
def decoder(self):
""" Decoder Network """
kwargs = dict(kernel_initializer=self.kernel_initializer)
decoder_shape = self.input_shape[0] // 8
input_ = Input(shape=(decoder_shape, decoder_shape, 512))
var_x = input_
var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, 512, **kwargs)
var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, 256, **kwargs)
var_x = self.blocks.upscale(var_x, self.input_shape[0], res_block_follows=True, **kwargs)
var_x = self.blocks.res_block(var_x, self.input_shape[0], **kwargs)
var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
outputs = [var_x]
if self.config.get("mask_type", None):
var_y = input_
var_y = self.blocks.upscale(var_y, 512)
var_y = self.blocks.upscale(var_y, 256)
var_y = self.blocks.upscale(var_y, self.input_shape[0])
var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
outputs.append(var_y)
return KerasModel(input_, outputs=outputs)
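
The Villain encoder ends with a `Reshape` to `input_shape[0] // 16` followed by one `upscale`, while the decoder declares an input of `input_shape[0] // 8` and applies three `upscale` plus `res_block` pairs. A small sketch can check that these sizes line up. It assumes that `upscale` doubles the spatial edge and that `res_block` leaves it unchanged; the helper `trace_spatial_size` is hypothetical and not part of the repository.

```python
def trace_spatial_size(size, ops):
    """Follow the spatial edge length through a list of 'down'/'up'/'keep' operations."""
    for op in ops:
        if op == "down":      # assumed: a stride-2 conv halves the edge
            size //= 2
        elif op == "up":      # assumed: upscale doubles the edge
            size *= 2
        elif op != "keep":    # 'keep' models res_block, which preserves the edge
            raise ValueError("unknown op: {}".format(op))
    return size


input_size = 128
# Encoder tail: Reshape to (input_size // 16), then a single upscale
encoder_out = trace_spatial_size(input_size // 16, ["up"])
# Decoder: three upscale + res_block pairs
decoder_out = trace_spatial_size(encoder_out, ["up", "keep"] * 3)
print(encoder_out, decoder_out)  # 16 128
```

Under those assumptions the encoder output edge equals `input_size // 8`, which is the edge the decoder expects, and the decoder returns to the full input size.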


@ -0,0 +1,576 @@
#!/usr/bin/env python3
""" Base Trainer Class for Faceswap
Trainers should be inherited from this class.
A training_opts dictionary can be set in the corresponding model.
Accepted values:
alignments: dict containing paths to alignments files for keys 'a' and 'b'
preview_scaling: How much to scale the preview out by
training_size: Size of the training images
coverage_ratio: Ratio of face to be cropped out for training
mask_type: Type of mask to use. See lib.model.masks for valid mask names.
Set to None to train without a mask
no_logs: Disable tensorboard logging
warp_to_landmarks: Use random_warp_landmarks instead of random_warp
no_flip: Don't perform a random flip on the image
"""
import logging
import os
import time
import cv2
import numpy as np
from tensorflow import keras as tf_keras
from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.training_data import TrainingDataGenerator, stack_images
from lib.utils import get_folder, get_image_paths
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class TrainerBase():
""" Base Trainer """
def __init__(self, model, images, batch_size):
logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
self.__class__.__name__, model, batch_size)
self.batch_size = batch_size
self.model = model
self.model.state.add_session_batchsize(batch_size)
self.images = images
self.process_training_opts()
self.batchers = {side: Batcher(side,
images[side],
self.model,
self.use_mask,
batch_size)
for side in images.keys()}
self.tensorboard = self.set_tensorboard()
self.samples = Samples(self.model,
self.use_mask,
self.model.training_opts["coverage_ratio"],
self.model.training_opts["preview_scaling"])
self.timelapse = Timelapse(self.model,
self.use_mask,
self.model.training_opts["coverage_ratio"],
self.batchers)
logger.debug("Initialized %s", self.__class__.__name__)
@property
def timestamp(self):
""" Standardised timestamp for loss reporting """
return time.strftime("%H:%M:%S")
@property
def landmarks_required(self):
""" Return True if Landmarks are required """
opts = self.model.training_opts
retval = bool(opts.get("mask_type", None) or opts["warp_to_landmarks"])
logger.debug(retval)
return retval
@property
def use_mask(self):
""" Return True if a mask is requested """
retval = bool(self.model.training_opts.get("mask_type", None))
logger.debug(retval)
return retval
def process_training_opts(self):
""" Override for processing model specific training options """
logger.debug(self.model.training_opts)
if self.landmarks_required:
landmarks = Landmarks(self.model.training_opts).landmarks
self.model.training_opts["landmarks"] = landmarks
def set_tensorboard(self):
""" Set up tensorboard callback """
if self.model.training_opts["no_logs"]:
logger.verbose("TensorBoard logging disabled")
return None
logger.debug("Enabling TensorBoard Logging")
tensorboard = dict()
for side in self.images.keys():
logger.debug("Setting up TensorBoard Logging. Side: %s", side)
log_dir = os.path.join(str(self.model.model_dir),
"{}_logs".format(self.model.name),
side,
"session_{}".format(self.model.state.session_id))
tbs = tf_keras.callbacks.TensorBoard(log_dir=log_dir,
histogram_freq=0, # Must be 0 or hangs
batch_size=self.batch_size,
write_graph=True,
write_grads=True)
tbs.set_model(self.model.predictors[side])
tensorboard[side] = tbs
logger.info("Enabled TensorBoard Logging")
return tensorboard
def print_loss(self, loss):
""" Override for specific model loss formatting """
output = list()
for side in sorted(list(loss.keys())):
display = ", ".join(["{}_{}: {:.5f}".format(self.model.state.loss_names[side][idx],
side.capitalize(),
this_loss)
for idx, this_loss in enumerate(loss[side])])
output.append(display)
print("[{}] [#{:05d}] {}, {}".format(
self.timestamp, self.model.iterations, output[0], output[1]), end='\r')
def train_one_step(self, viewer, timelapse_kwargs):
""" Train a batch """
logger.trace("Training one step: (iteration: %s)", self.model.iterations)
is_preview_iteration = viewer is not None
loss = dict()
for side, batcher in self.batchers.items():
loss[side] = batcher.train_one_batch(is_preview_iteration)
if not is_preview_iteration:
continue
self.samples.images[side] = batcher.compile_sample(self.batch_size)
if timelapse_kwargs:
self.timelapse.get_sample(side, timelapse_kwargs)
self.model.state.increment_iterations()
for side, side_loss in loss.items():
self.store_history(side, side_loss)
self.log_tensorboard(side, side_loss)
self.print_loss(loss)
if viewer is not None:
viewer(self.samples.show_sample(),
"Training - 'S': Save Now. 'ENTER': Save and Quit")
if timelapse_kwargs is not None:
self.timelapse.output_timelapse()
def store_history(self, side, loss):
""" Store the history of this step """
logger.trace("Updating loss history: '%s'", side)
self.model.history[side].append(loss[0]) # Either only loss or total loss
logger.trace("Updated loss history: '%s'", side)
def log_tensorboard(self, side, loss):
""" Log loss to TensorBoard log """
if not self.tensorboard:
return
logger.trace("Updating TensorBoard log: '%s'", side)
logs = {log[0]: log[1]
for log in zip(self.model.state.loss_names[side], loss)}
self.tensorboard[side].on_batch_end(self.model.state.iterations, logs)
logger.trace("Updated TensorBoard log: '%s'", side)
def clear_tensorboard(self):
""" Indicate training end to Tensorboard """
if not self.tensorboard:
return
for side, tensorboard in self.tensorboard.items():
logger.debug("Ending Tensorboard. Side: '%s'", side)
tensorboard.on_train_end(None)
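
The progress line built by `print_loss` above can be reproduced outside the trainer. This is a minimal sketch, assuming the loss values and loss names are plain dicts keyed by side; the function name `format_loss` and the sample values are hypothetical.

```python
def format_loss(timestamp, iteration, loss, loss_names):
    """Build the same '[time] [#iteration] name_Side: value, ...' line as print_loss."""
    output = []
    for side in sorted(loss):
        display = ", ".join("{}_{}: {:.5f}".format(loss_names[side][idx],
                                                    side.capitalize(),
                                                    value)
                            for idx, value in enumerate(loss[side]))
        output.append(display)
    return "[{}] [#{:05d}] {}".format(timestamp, iteration, ", ".join(output))


# Hypothetical values for one training step
line = format_loss("12:00:00", 42,
                   {"a": [0.01234567], "b": [0.02]},
                   {"a": ["loss"], "b": ["loss"]})
print(line)  # [12:00:00] [#00042] loss_A: 0.01235, loss_B: 0.02000
```

The trainer prints this line with `end='\r'`, so each step overwrites the previous one on the console.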
class Batcher():
""" Batch images from a single side """
def __init__(self, side, images, model, use_mask, batch_size):
logger.debug("Initializing %s: (side: '%s', num_images: %s, batch_size: %s)",
self.__class__.__name__, side, len(images), batch_size)
self.model = model
self.use_mask = use_mask
self.side = side
self.target = None
self.samples = None
self.mask = None
self.feed = self.load_generator().minibatch_ab(images, batch_size, self.side)
self.timelapse_feed = None
def load_generator(self):
""" Pass arguments to TrainingDataGenerator and return object """
logger.debug("Loading generator: %s", self.side)
input_size = self.model.input_shape[0]
output_size = self.model.output_shape[0]
logger.debug("input_size: %s, output_size: %s", input_size, output_size)
generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts)
return generator
def train_one_batch(self, is_preview_iteration):
""" Train a batch """
logger.trace("Training one step: (side: %s)", self.side)
batch = self.get_next(is_preview_iteration)
loss = self.model.predictors[self.side].train_on_batch(*batch)
loss = loss if isinstance(loss, list) else [loss]
return loss
def get_next(self, is_preview_iteration):
""" Return the next batch from the generator
Items should come out as: (warped, target [, mask]) """
batch = next(self.feed)
self.samples = batch[0] if is_preview_iteration else None
batch = batch[1:] # Remove full size samples from batch
if self.use_mask:
batch = self.compile_mask(batch)
self.target = batch[1] if is_preview_iteration else None
return batch
def compile_mask(self, batch):
""" Compile the mask into training data """
logger.trace("Compiling Mask: (side: '%s')", self.side)
mask = batch[-1]
retval = list()
for idx in range(len(batch) - 1):
image = batch[idx]
retval.append([image, mask])
return retval
def compile_sample(self, batch_size, samples=None, images=None):
""" Training samples to display in the viewer """
num_images = self.model.training_opts.get("preview_images", 14)
num_images = min(batch_size, num_images)
logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
images = images if images is not None else self.target
samples = [samples[0:num_images]] if samples is not None else [self.samples[0:num_images]]
if self.use_mask:
retval = [tgt[0:num_images] for tgt in images]
else:
retval = [images[0:num_images]]
retval = samples + retval
return retval
def compile_timelapse_sample(self):
""" Timelapse samples """
batch = next(self.timelapse_feed)
samples = batch[0]
batch = batch[1:] # Remove full size samples from batch
batchsize = len(samples)
if self.use_mask:
batch = self.compile_mask(batch)
images = batch[1]
sample = self.compile_sample(batchsize, samples=samples, images=images)
return sample
def set_timelapse_feed(self, images, batchsize):
""" Set the timelapse feed generator for this side """
logger.debug("Setting timelapse feed: (side: '%s', input_images: '%s', batchsize: %s)",
self.side, images, batchsize)
self.timelapse_feed = self.load_generator().minibatch_ab(images[:batchsize],
batchsize, self.side,
do_shuffle=False,
is_timelapse=True)
logger.debug("Set timelapse feed")
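
`Batcher.compile_mask` above turns a `(warped, target, mask)` batch into two `[image, mask]` pairs, which `train_one_batch` then unpacks as the inputs and targets of `train_on_batch`. A minimal sketch of that regrouping, using strings as stand-ins for the image arrays:

```python
def compile_mask(batch):
    """Pair every image entry in the batch with the trailing mask entry."""
    mask = batch[-1]
    return [[image, mask] for image in batch[:-1]]


# Strings stand in for the numpy arrays yielded by the training generator
batch = ["warped", "target", "mask"]
paired = compile_mask(batch)
inputs, targets = paired
print(inputs)   # ['warped', 'mask']
print(targets)  # ['target', 'mask']
```

After this step `batch[1]` is the `[target, mask]` pair, which is why `compile_sample` iterates over `images` when `use_mask` is set.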
class Samples():
""" Display samples for preview and timelapse """
def __init__(self, model, use_mask, coverage_ratio, scaling=1.0):
logger.debug("Initializing %s: (model: '%s', use_mask: %s, coverage_ratio: %s)",
self.__class__.__name__, model, use_mask, coverage_ratio)
self.model = model
self.use_mask = use_mask
self.images = dict()
self.coverage_ratio = coverage_ratio
self.scaling = scaling
logger.debug("Initialized %s", self.__class__.__name__)
def show_sample(self):
""" Display preview data """
logger.debug("Showing sample")
feeds = dict()
figures = dict()
headers = dict()
for side, samples in self.images.items():
faces = samples[1]
if self.model.input_shape[0] / faces.shape[1] != 1.0:
feeds[side] = self.resize_sample(side, faces, self.model.input_shape[0])
feeds[side] = feeds[side].reshape((-1, ) + self.model.input_shape)
else:
feeds[side] = faces
if self.use_mask:
mask = samples[-1]
feeds[side] = [feeds[side], mask]
preds = self.get_predictions(feeds["a"], feeds["b"])
for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
predictions = [preds["{}_{}".format(side, side)],
preds["{}_{}".format(other_side, side)]]
display = self.to_full_frame(side, samples, predictions)
headers[side] = self.get_headers(side, other_side, display[0].shape[1])
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
if self.images[side][0].shape[0] % 2 == 1:
figures[side] = np.concatenate([figures[side],
np.expand_dims(figures[side][0], 0)])
width = 4
side_cols = width // 2
if side_cols != 1:
headers = self.duplicate_headers(headers, side_cols)
header = np.concatenate([headers["a"], headers["b"]], axis=1)
figure = np.concatenate([figures["a"], figures["b"]], axis=0)
height = int(figure.shape[0] / width)
figure = figure.reshape((width, height) + figure.shape[1:])
figure = stack_images(figure)
figure = np.vstack((header, figure))
logger.debug("Compiled sample")
return np.clip(figure * 255, 0, 255).astype('uint8')
@staticmethod
def resize_sample(side, sample, target_size):
""" Resize samples where predictor expects different shape from processed image """
scale = target_size / sample.shape[1]
if scale == 1.0:
return sample
logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)",
side, sample.shape, target_size, scale)
interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA # pylint: disable=no-member
retval = np.array([cv2.resize(img, # pylint: disable=no-member
(target_size, target_size),
interpn)
for img in sample])
logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
return retval
def get_predictions(self, feed_a, feed_b):
""" Return the sample predictions from the model """
logger.debug("Getting Predictions")
preds = dict()
preds["a_a"] = self.model.predictors["a"].predict(feed_a)
preds["b_a"] = self.model.predictors["b"].predict(feed_a)
preds["a_b"] = self.model.predictors["a"].predict(feed_b)
preds["b_b"] = self.model.predictors["b"].predict(feed_b)
# Get the returned image from predictors that emit multiple items
if not isinstance(preds["a_a"], np.ndarray):
for key, val in preds.items():
preds[key] = val[0]
logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
return preds
def to_full_frame(self, side, samples, predictions):
""" Patch the images into the full frame """
logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s",
side, len(samples), [pred.shape for pred in predictions])
full, faces = samples[:2]
images = [faces] + predictions
full_size = full.shape[1]
target_size = int(full_size * self.coverage_ratio)
if target_size != full_size:
frame = self.frame_overlay(full, target_size, (0, 0, 255))
if self.use_mask:
images = self.compile_masked(images, samples[-1])
images = [self.resize_sample(side, image, target_size) for image in images]
if target_size != full_size:
images = [self.overlay_foreground(frame, image) for image in images]
if self.scaling != 1.0:
new_size = int(full_size * self.scaling)
images = [self.resize_sample(side, image, new_size) for image in images]
return images
@staticmethod
def frame_overlay(images, target_size, color):
""" Add roi frame to a background image """
logger.debug("full_size: %s, target_size: %s, color: %s",
images.shape[1], target_size, color)
new_images = list()
full_size = images.shape[1]
padding = (full_size - target_size) // 2
length = target_size // 4
t_l, b_r = (padding, full_size - padding)
for img in images:
cv2.rectangle(img, # pylint: disable=no-member
(t_l, t_l),
(t_l + length, t_l + length),
color,
3)
cv2.rectangle(img, # pylint: disable=no-member
(b_r, t_l),
(b_r - length, t_l + length),
color,
3)
cv2.rectangle(img, # pylint: disable=no-member
(b_r, b_r),
(b_r - length,
b_r - length),
color,
3)
cv2.rectangle(img, # pylint: disable=no-member
(t_l, b_r),
(t_l + length, b_r - length),
color,
3)
new_images.append(img)
retval = np.array(new_images)
logger.debug("Overlaid background. Shape: %s", retval.shape)
return retval
@staticmethod
def compile_masked(faces, masks):
""" Add the mask to the faces for masked preview """
retval = list()
masks3 = np.tile(1 - np.rint(masks), 3)
for mask in masks3:
mask[np.where((mask == [1., 1., 1.]).all(axis=2))] = [0., 0., 1.]
for previews in faces:
images = np.array([cv2.addWeighted(img, 1.0, # pylint: disable=no-member
masks3[idx], 0.3,
0)
for idx, img in enumerate(previews)])
retval.append(images)
logger.debug("masked shapes: %s", [faces.shape for faces in retval])
return retval
@staticmethod
def overlay_foreground(backgrounds, foregrounds):
""" Overlay the training images into the center of the background """
offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2
new_images = list()
for idx, img in enumerate(backgrounds):
img[offset:offset + foregrounds[idx].shape[0],
offset:offset + foregrounds[idx].shape[1]] = foregrounds[idx]
new_images.append(img)
retval = np.array(new_images)
logger.debug("Overlaid foreground. Shape: %s", retval.shape)
return retval
def get_headers(self, side, other_side, width):
""" Set headers for images """
logger.debug("side: '%s', other_side: '%s', width: %s",
side, other_side, width)
side = side.upper()
other_side = other_side.upper()
height = int(64 * self.scaling)
total_width = width * 3
logger.debug("height: %s, total_width: %s", height, total_width)
font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member
texts = ["Target {}".format(side),
"{} > {}".format(side, side),
"{} > {}".format(side, other_side)]
text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member
font,
self.scaling,
1)[0]
for idx in range(len(texts))]
text_y = int((height + text_sizes[0][1]) / 2)
text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx
for idx in range(len(texts))]
logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s",
texts, text_sizes, text_x, text_y)
header_box = np.ones((height, total_width, 3), np.float32)
for idx, text in enumerate(texts):
cv2.putText(header_box, # pylint: disable=no-member
text,
(text_x[idx], text_y),
font,
self.scaling,
(0, 0, 0),
1,
lineType=cv2.LINE_AA) # pylint: disable=no-member
logger.debug("header_box.shape: %s", header_box.shape)
return header_box
@staticmethod
def duplicate_headers(headers, columns):
""" Duplicate headers for the number of columns displayed """
for side, header in headers.items():
duped = tuple([header for _ in range(columns)])
headers[side] = np.concatenate(duped, axis=1)
logger.debug("side: %s header.shape: %s", side, header.shape)
return headers
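
The corner markers drawn by `Samples.frame_overlay` above follow from the full image size and the coverage ratio. The sketch below isolates that arithmetic without OpenCV; the helper name `corner_boxes` and the example sizes are hypothetical, while the formulas are the ones used in `to_full_frame` and `frame_overlay`.

```python
def corner_boxes(full_size, coverage_ratio):
    """Return the four (point_1, point_2) rectangles marking the training area."""
    target_size = int(full_size * coverage_ratio)
    padding = (full_size - target_size) // 2
    length = target_size // 4
    t_l, b_r = padding, full_size - padding
    return [((t_l, t_l), (t_l + length, t_l + length)),   # top left
            ((b_r, t_l), (b_r - length, t_l + length)),   # top right
            ((b_r, b_r), (b_r - length, b_r - length)),   # bottom right
            ((t_l, b_r), (t_l + length, b_r - length))]   # bottom left


# Hypothetical example: 256px preview frame with 62.5% coverage
boxes = corner_boxes(256, 0.625)
print(boxes[0])  # ((48, 48), (88, 88))
```

Each pair can be passed as the two corner points of `cv2.rectangle`, as the method above does.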
class Timelapse():
""" Create the timelapse """
def __init__(self, model, use_mask, coverage_ratio, batchers):
logger.debug("Initializing %s: (model: %s, use_mask: %s, coverage_ratio: %s, "
"batchers: '%s')", self.__class__.__name__, model, use_mask,
coverage_ratio, batchers)
self.samples = Samples(model, use_mask, coverage_ratio)
self.model = model
self.batchers = batchers
self.output_file = None
logger.debug("Initialized %s", self.__class__.__name__)
def get_sample(self, side, timelapse_kwargs):
""" Perform timelapse """
logger.debug("Getting timelapse samples: '%s'", side)
if not self.output_file:
self.setup(**timelapse_kwargs)
self.samples.images[side] = self.batchers[side].compile_timelapse_sample()
logger.debug("Got timelapse samples: '%s' - %s", side, len(self.samples.images[side]))
def setup(self, input_a=None, input_b=None, output=None):
""" Set the timelapse output folder """
logger.debug("Setting up timelapse")
if output is None:
output = str(get_folder(os.path.join(str(self.model.model_dir),
"{}_timelapse".format(self.model.name))))
self.output_file = str(output)
logger.debug("Timelapse output set to '%s'", self.output_file)
images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
batchsize = min(len(images["a"]),
len(images["b"]),
self.model.training_opts.get("preview_images", 14))
for side, image_files in images.items():
self.batchers[side].set_timelapse_feed(image_files, batchsize)
logger.debug("Set up timelapse")
def output_timelapse(self):
""" Write the current timelapse preview image to the output folder """
logger.debug("Outputting timelapse")
image = self.samples.show_sample()
filename = os.path.join(self.output_file, str(int(time.time())) + ".jpg")
cv2.imwrite(filename, image) # pylint: disable=no-member
logger.debug("Created timelapse: '%s'", filename)
class Landmarks():
""" Set Landmarks for training into the model's training options"""
def __init__(self, training_opts):
logger.debug("Initializing %s: (training_opts: '%s')",
self.__class__.__name__, training_opts)
self.size = training_opts.get("training_size", 256)
self.paths = training_opts["alignments"]
self.landmarks = self.get_alignments()
logger.debug("Initialized %s", self.__class__.__name__)
def get_alignments(self):
""" Obtain the landmarks for each faceset """
landmarks = dict()
for side, fullpath in self.paths.items():
path, filename = os.path.split(fullpath)
filename, extension = os.path.splitext(filename)
serializer = extension[1:]
alignments = Alignments(
path,
filename=filename,
serializer=serializer)
landmarks[side] = self.transform_landmarks(alignments)
return landmarks
def transform_landmarks(self, alignments):
""" For each face transform landmarks and return """
landmarks = dict()
for _, faces, _, _ in alignments.yield_faces():
for face in faces:
detected_face = DetectedFace()
detected_face.from_alignment(face)
detected_face.load_aligned(None, size=self.size, align_eyes=False)
landmarks[detected_face.hash] = detected_face.aligned_landmarks
return landmarks
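
`Landmarks.get_alignments` above derives the folder, base filename and serializer name from each alignments path before constructing `Alignments`. A minimal sketch of that split using only the standard library; the helper name `split_alignments_path` and the example path are hypothetical.

```python
import os


def split_alignments_path(fullpath):
    """Split an alignments file path into (folder, base filename, serializer)."""
    path, filename = os.path.split(fullpath)
    filename, extension = os.path.splitext(filename)
    # Drop the leading dot so '.json' becomes the serializer name 'json'
    return path, filename, extension[1:]


example = os.path.join("faces_a", "alignments.json")
print(split_alignments_path(example))  # ('faces_a', 'alignments', 'json')
```

A path without an extension yields an empty serializer string, so callers may want to validate the result before using it.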


@ -0,0 +1,4 @@
#!/usr/bin/env python3
""" Original Trainer """
from ._base import TrainerBase as Trainer


@ -15,7 +15,6 @@ from lib.faces_detect import DetectedFace
 from lib.multithreading import BackgroundGenerator, SpawnProcess
 from lib.queue_manager import queue_manager
 from lib.utils import get_folder, get_image_paths, hash_image_file
 from plugins.plugin_loader import PluginLoader
 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -43,7 +42,7 @@ class Convert():
 logger.debug("Initialized %s", self.__class__.__name__)
 def process(self):
-""" Original & LowMem models go with Adjust or Masked converter
+""" Original & LowMem models go with converter
 Note: GAN prediction outputs a mask + an image, while other
 predicts only an image. """
@ -103,37 +102,19 @@ class Convert():
 def load_model(self):
 """ Load the model requested for conversion """
-model_name = self.args.trainer
+logger.debug("Loading Model")
 model_dir = get_folder(self.args.model_dir)
-num_gpus = self.args.gpus
-model = PluginLoader.get_model(model_name)(model_dir, num_gpus)
-if not model.load(self.args.swap_model):
-logger.error("Model Not Found! A valid model "
-"must be provided to continue!")
-exit(1)
+model = PluginLoader.get_model(self.args.trainer)(model_dir, self.args.gpus, predict=True)
+logger.debug("Loaded Model")
 return model
 def load_converter(self, model):
 """ Load the requested converter for conversion """
-args = self.args
-conv = args.converter
+conv = self.args.converter
 converter = PluginLoader.get_converter(conv)(
-model.converter(False),
-trainer=args.trainer,
-blur_size=args.blur_size,
-seamless_clone=args.seamless_clone,
-sharpen_image=args.sharpen_image,
-mask_type=args.mask_type,
-erosion_kernel_size=args.erosion_kernel_size,
-match_histogram=args.match_histogram,
-smooth_mask=args.smooth_mask,
-avg_color_adjust=args.avg_color_adjust,
-draw_transparent=args.draw_transparent)
+model.converter(self.args.swap_model),
+model=model,
+arguments=self.args)
 return converter
 def prepare_images(self):
@ -205,25 +186,13 @@ class Convert():
 if not skip:
 for face in faces:
-image = self.convert_one_face(converter, image, face)
+image = converter.patch_image(image, face)
 filename = str(self.output_dir / Path(filename).name)
 cv2.imwrite(filename, image) # pylint: disable=no-member
 except Exception as err:
 logger.error("Failed to convert image: '%s'. Reason: %s", filename, err)
 raise
-def convert_one_face(self, converter, image, face):
-""" Perform the conversion on the given frame for a single face """
-# TODO: This switch between 64 and 128 is a hack for now.
-# We should have a separate cli option for size
-size = 128 if (self.args.trainer.strip().lower()
-in ('gan128', 'originalhighres')) else 64
-image = converter.patch_image(image,
-face,
-size)
-return image
 class OptionalActions():
 """ Process the optional actions for convert """
@ -305,10 +274,8 @@ class OptionalActions():
 class Legacy():
 """ Update legacy alignments:
-- Add frame dimensions
 - Rotate landmarks and bounding boxes on legacy alignments
 and remove the 'r' parameter
 - Add face hashes to alignments file
 """
 def __init__(self, alignments, frames, faces_dir):
@ -319,15 +286,10 @@ class Legacy():
 def process(self, faces_dir):
 """ Run the rotate alignments process """
-no_dims = self.alignments.get_legacy_no_dims()
 rotated = self.alignments.get_legacy_rotation()
 hashes = self.alignments.get_legacy_no_hashes()
-if not no_dims and not rotated and not hashes:
+if not rotated and not hashes:
 return
-if no_dims:
-logger.info("Legacy landmarks found. Adding frame dimensions...")
-self.add_dimensions(no_dims)
-self.alignments.save()
 if rotated:
 logger.info("Legacy rotated frames found. Converting...")
 self.rotate_landmarks(rotated)
@ -337,22 +299,14 @@ class Legacy():
 self.add_hashes(hashes, faces_dir)
 self.alignments.save()
-def add_dimensions(self, no_dims):
-""" Add width and height of original frame to alignments """
-for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"):
-if no_dim not in self.frames.keys():
-continue
-filename = self.frames[no_dim]
-dims = cv2.imread(filename).shape[:2] # pylint: disable=no-member
-self.alignments.add_dimensions(no_dim, dims)
 def rotate_landmarks(self, rotated):
 """ Rotate the landmarks """
 for rotate_item in tqdm(rotated, desc="Rotating Landmarks"):
-if rotate_item not in self.frames.keys():
+frame = self.frames.get(rotate_item, None)
+if frame is None:
 logger.debug("Skipping missing frame: '%s'", rotate_item)
 continue
-self.alignments.rotate_existing_landmarks(rotate_item)
+self.alignments.rotate_existing_landmarks(rotate_item, frame)
 def add_hashes(self, hashes, faces_dir):
 """ Add Face Hashes to the alignments file """


@ -54,7 +54,7 @@ class Extract():
self.verify_output) self.verify_output)
def threaded_io(self, task, io_args=None): def threaded_io(self, task, io_args=None):
""" Load images in a background thread """ """ Perform I/O task in a background thread """
logger.debug("Threading task: (Task: '%s')", task) logger.debug("Threading task: (Task: '%s')", task)
io_args = tuple() if io_args is None else (io_args, ) io_args = tuple() if io_args is None else (io_args, )
if task == "load": if task == "load":
@ -211,7 +211,7 @@ class Extract():
self.threaded_io("reload", detected_faces) self.threaded_io("reload", detected_faces)
def align_face(self, faces, align_eyes, size, filename, padding=48): def align_face(self, faces, align_eyes, size, filename):
""" Align the detected face and add the destination file path """ """ Align the detected face and add the destination file path """
final_faces = list() final_faces = list()
image = faces["image"] image = faces["image"]
@ -221,11 +221,7 @@ class Extract():
detected_face = DetectedFace() detected_face = DetectedFace()
detected_face.from_dlib_rect(face, image) detected_face.from_dlib_rect(face, image)
detected_face.landmarksXY = landmarks[idx] detected_face.landmarksXY = landmarks[idx]
detected_face.frame_dims = image.shape[:2] detected_face.load_aligned(image, size=size, align_eyes=align_eyes)
detected_face.load_aligned(image,
size=size,
padding=padding,
align_eyes=align_eyes)
final_faces.append({"file_location": self.output_dir / Path(filename).stem, final_faces.append({"file_location": self.output_dir / Path(filename).stem,
"face": detected_face}) "face": detected_face})
faces["detected_faces"] = final_faces faces["detected_faces"] = final_faces
@ -262,7 +258,7 @@ class Plugins():
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def set_parallel_processing(self): def set_parallel_processing(self):
""" Set whether to run detect and align together or seperately """ """ Set whether to run detect and align together or separately """
detector_vram = self.detector.vram detector_vram = self.detector.vram
aligner_vram = self.aligner.vram aligner_vram = self.aligner.vram
gpu_stats = GPUStats() gpu_stats = GPUStats()
@ -356,11 +352,6 @@ class Plugins():
kwargs = {"in_queue": queue_manager.get_queue("load"), kwargs = {"in_queue": queue_manager.get_queue("load"),
"out_queue": out_queue} "out_queue": out_queue}
if self.args.detector == "mtcnn":
mtcnn_kwargs = self.detector.validate_kwargs(
self.get_mtcnn_kwargs())
kwargs["mtcnn_kwargs"] = mtcnn_kwargs
mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess
self.process_detect = mp_func(self.detector.run, **kwargs) self.process_detect = mp_func(self.detector.run, **kwargs)
@ -384,14 +375,6 @@ class Plugins():
logger.debug("Launched Detector") logger.debug("Launched Detector")
def get_mtcnn_kwargs(self):
""" Add the mtcnn arguments into a kwargs dictionary """
mtcnn_threshold = [float(thr.strip())
for thr in self.args.mtcnn_threshold]
return {"minsize": self.args.mtcnn_minsize,
"threshold": mtcnn_threshold,
"factor": self.args.mtcnn_scalefactor}
def detect_faces(self, extract_pass="detect"): def detect_faces(self, extract_pass="detect"):
""" Detect faces in an image """ """ Detect faces in an image """
logger.debug("Running Detection. Pass: '%s'", extract_pass) logger.debug("Running Detection. Pass: '%s'", extract_pass)


@ -180,7 +180,7 @@ class Images():
def load_disk_frames(self): def load_disk_frames(self):
""" Load frames from disk """ """ Load frames from disk """
logger.debug("Input is Seperate Frames. Loading images") logger.debug("Input is separate Frames. Loading images")
for filename in self.input_images: for filename in self.input_images:
logger.trace("Loading image: '%s'", filename) logger.trace("Loading image: '%s'", filename)
try: try:
@ -314,9 +314,10 @@ class BlurryFaceFilter(PostProcessAction): # pylint: disable=too-few-public-met
aligned_landmarks = face.aligned_landmarks aligned_landmarks = face.aligned_landmarks
resized_face = face.aligned_face resized_face = face.aligned_face
size = face.aligned["size"] size = face.aligned["size"]
padding = int(size * 0.1875)
feature_mask = extractor.get_feature_mask( feature_mask = extractor.get_feature_mask(
aligned_landmarks / size, aligned_landmarks / size,
size, 48) size, padding)
feature_mask = cv2.blur( # pylint: disable=no-member feature_mask = cv2.blur( # pylint: disable=no-member
feature_mask, (10, 10)) feature_mask, (10, 10))
isolated_face = cv2.multiply( # pylint: disable=no-member isolated_face = cv2.multiply( # pylint: disable=no-member
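
The hunk above replaces the hard-coded mask padding of 48 with `int(size * 0.1875)`. Since 0.1875 × 256 = 48, the new expression reproduces the old constant for a 256px face and scales it for other sizes. A minimal sketch of that relationship follows; the helper name `proportional_padding` is hypothetical, as the commit computes the value inline.

```python
def proportional_padding(size, ratio=0.1875):
    """Return the padding in pixels for an aligned face of `size` pixels.

    Hypothetical helper for illustration only; the diff above performs the
    same arithmetic inline inside BlurryFaceFilter.
    """
    return int(size * ratio)


if __name__ == "__main__":
    # Show how the padding tracks the aligned face size
    for face_size in (128, 256, 512):
        print(face_size, proportional_padding(face_size))
```

Deriving the padding from the face size keeps the mask margin proportional when faces are extracted at sizes other than 256.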


@ -1,101 +1,77 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" The optional GUI for faceswap """ """ The optional GUI for faceswap """
# NB: The GUI can't currently log as it is a wrapper for the python scripts, so don't import logging
# implement logging unless you can handle the conflicts
import os import os
import sys import sys
import tkinter as tk import tkinter as tk
from tkinter import messagebox, ttk from tkinter import messagebox, ttk
from lib.gui import (CliOptions, CurrentSession, CommandNotebook, Config, from lib.gui import (CliOptions, CommandNotebook, ConsoleOut, Session, DisplayNotebook,
ConsoleOut, DisplayNotebook, Images, ProcessWrapper, get_config, get_images, initialize_images, initialize_config, MainMenuBar,
StatusBar) ProcessWrapper, StatusBar)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class FaceswapGui(tk.Tk): class FaceswapGui(tk.Tk):
""" The Graphical User Interface """ """ The Graphical User Interface """
def __init__(self, pathscript): def __init__(self, pathscript):
tk.Tk.__init__(self) logger.debug("Initializing %s", self.__class__.__name__)
self.scaling_factor = self.get_scaling() super().__init__()
self.initialize_globals(pathscript)
self.set_geometry() self.set_geometry()
self.wrapper = ProcessWrapper(pathscript)
pathcache = os.path.join(pathscript, "lib", "gui", ".cache") get_images().delete_preview()
self.images = Images(pathcache)
self.cliopts = CliOptions()
self.session = CurrentSession()
statusbar = StatusBar(self)
self.wrapper = ProcessWrapper(statusbar,
self.session,
pathscript,
self.cliopts)
self.images.delete_preview()
self.protocol("WM_DELETE_WINDOW", self.close_app) self.protocol("WM_DELETE_WINDOW", self.close_app)
logger.debug("Initialized %s", self.__class__.__name__)
def initialize_globals(self, pathscript):
""" Initialize config and images global constants """
cliopts = CliOptions()
scaling_factor = self.get_scaling()
pathcache = os.path.join(pathscript, "lib", "gui", ".cache")
statusbar = StatusBar(self)
session = Session()
initialize_config(cliopts, scaling_factor, pathcache, statusbar, session)
initialize_images()
def get_scaling(self): def get_scaling(self):
""" Get the display DPI """ """ Get the display DPI """
dpi = self.winfo_fpixels("1i") dpi = self.winfo_fpixels("1i")
return dpi / 72.0 scaling = dpi / 72.0
logger.debug("dpi: %s, scaling: %s", dpi, scaling)
return scaling
def set_geometry(self): def set_geometry(self):
""" Set GUI geometry """ """ Set GUI geometry """
self.tk.call("tk", "scaling", self.scaling_factor) scaling_factor = get_config().scaling_factor
width = int(1200 * self.scaling_factor) self.tk.call("tk", "scaling", scaling_factor)
height = int(640 * self.scaling_factor) width = int(1200 * scaling_factor)
height = int(640 * scaling_factor)
logger.debug("Geometry: %sx%s", width, height)
self.geometry("{}x{}+80+80".format(str(width), str(height))) self.geometry("{}x{}+80+80".format(str(width), str(height)))
def build_gui(self, debug_console): def build_gui(self, debug_console):
""" Build the GUI """ """ Build the GUI """
logger.debug("Building GUI")
self.title("Faceswap.py") self.title("Faceswap.py")
self.menu() self.configure(menu=MainMenuBar(self))
topcontainer, bottomcontainer = self.add_containers() topcontainer, bottomcontainer = self.add_containers()
CommandNotebook(topcontainer, CommandNotebook(topcontainer)
self.cliopts, DisplayNotebook(topcontainer)
self.wrapper.tk_vars, ConsoleOut(bottomcontainer, debug_console)
self.scaling_factor) logger.debug("Built GUI")
DisplayNotebook(topcontainer,
self.session,
self.wrapper.tk_vars,
self.scaling_factor)
ConsoleOut(bottomcontainer, debug_console, self.wrapper.tk_vars)
def menu(self):
""" Menu bar for loading and saving configs """
menubar = tk.Menu(self)
filemenu = tk.Menu(menubar, tearoff=0)
config = Config(self.cliopts, self.wrapper.tk_vars)
filemenu.add_command(label="Load full config...",
underline=0,
command=config.load)
filemenu.add_command(label="Save full config...",
underline=0,
command=config.save)
filemenu.add_separator()
filemenu.add_command(label="Reset all to default",
underline=0,
command=self.cliopts.reset)
filemenu.add_command(label="Clear all",
underline=0,
command=self.cliopts.clear)
filemenu.add_separator()
filemenu.add_command(label="Quit",
underline=0,
command=self.close_app)
menubar.add_cascade(label="File", menu=filemenu, underline=0)
self.config(menu=menubar)
def add_containers(self): def add_containers(self):
""" Add the paned window containers that """ Add the paned window containers that
hold each main area of the gui """ hold each main area of the gui """
logger.debug("Adding containers")
maincontainer = tk.PanedWindow(self, maincontainer = tk.PanedWindow(self,
sashrelief=tk.RAISED, sashrelief=tk.RAISED,
orient=tk.VERTICAL) orient=tk.VERTICAL)
@ -109,21 +85,26 @@ class FaceswapGui(tk.Tk):
bottomcontainer = ttk.Frame(maincontainer, height=150) bottomcontainer = ttk.Frame(maincontainer, height=150)
maincontainer.add(bottomcontainer) maincontainer.add(bottomcontainer)
logger.debug("Added containers")
return topcontainer, bottomcontainer return topcontainer, bottomcontainer
def close_app(self): def close_app(self):
""" Close Python. This is here because the graph """ Close Python. This is here because the graph
animation function continues to run even when animation function continues to run even when
tkinter has gone away """ tkinter has gone away """
logger.debug("Close Requested")
confirm = messagebox.askokcancel confirm = messagebox.askokcancel
confirmtxt = "Processes are still running. Are you sure...?" confirmtxt = "Processes are still running. Are you sure...?"
if (self.wrapper.tk_vars["runningtask"].get() tk_vars = get_config().tk_vars
if (tk_vars["runningtask"].get()
and not confirm("Close", confirmtxt)): and not confirm("Close", confirmtxt)):
logger.debug("Close Cancelled")
return return
if self.wrapper.tk_vars["runningtask"].get(): if tk_vars["runningtask"].get():
self.wrapper.task.terminate() self.wrapper.task.terminate()
self.images.delete_preview() get_images().delete_preview()
self.quit() self.quit()
logger.debug("Closed GUI")
exit() exit()
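
The GUI derives a scaling factor from the display DPI (`dpi / 72.0`) and multiplies the base 1200x640 window by it. The sketch below isolates that arithmetic so it can be checked without a Tk display. The function name `scaled_geometry` and its parameters are assumptions for illustration; in the diff the same steps are split between `get_scaling` and `set_geometry`.

```python
def scaled_geometry(dpi, base_width=1200, base_height=640, offset=80):
    """Return (scaling, geometry_string) for a display with the given DPI.

    Hypothetical helper: mirrors the arithmetic in get_scaling() and
    set_geometry() above, without calling into tkinter.
    """
    scaling = dpi / 72.0  # Tk measures points at 72 per inch
    width = int(base_width * scaling)
    height = int(base_height * scaling)
    geometry = "{}x{}+{}+{}".format(width, height, offset, offset)
    return scaling, geometry


if __name__ == "__main__":
    print(scaled_geometry(72.0))
    print(scaled_geometry(144.0))
```

In the real GUI the DPI comes from `self.winfo_fpixels("1i")`, which needs a live Tk root window.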


@ -4,14 +4,18 @@
import logging import logging
import os import os
import sys import sys
import threading
from threading import Lock
from time import sleep
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
from keras.backend.tensorflow_backend import set_session from keras.backend.tensorflow_backend import set_session
from lib.utils import (get_folder, get_image_paths, set_system_verbosity, from lib.keypress import KBHit
Timelapse) from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import (get_folder, get_image_paths, set_system_verbosity)
from plugins.plugin_loader import PluginLoader from plugins.plugin_loader import PluginLoader
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -20,38 +24,48 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Train(): class Train():
""" The training process. """ """ The training process. """
def __init__(self, arguments): def __init__(self, arguments):
logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
self.args = arguments self.args = arguments
self.timelapse = self.set_timelapse()
self.images = self.get_images() self.images = self.get_images()
self.stop = False self.stop = False
self.save_now = False self.save_now = False
self.preview_buffer = dict() self.preview_buffer = dict()
self.lock = threading.Lock() self.lock = Lock()
# this is so that you can enter case insensitive values for trainer self.trainer_name = self.args.trainer
trainer_name = self.args.trainer logger.debug("Initialized %s", self.__class__.__name__)
self.trainer_name = trainer_name
if trainer_name.lower() == "lowmem":
self.trainer_name = "LowMem"
self.timelapse = None
def process(self): def set_timelapse(self):
""" Call the training process object """ """ Set timelapse paths if requested """
logger.info("Training data directory: %s", self.args.model_dir) if (not self.args.timelapse_input_a and
set_system_verbosity(self.args.loglevel) not self.args.timelapse_input_b and
thread = self.start_thread() not self.args.timelapse_output):
return None
if not self.args.timelapse_input_a or not self.args.timelapse_input_b:
raise ValueError("To enable the timelapse, you have to supply "
"all the parameters (--timelapse-input-A and "
"--timelapse-input-B).")
if self.args.preview: for folder in (self.args.timelapse_input_a,
self.monitor_preview() self.args.timelapse_input_b,
else: self.args.timelapse_output):
self.monitor_console() if folder is not None and not os.path.isdir(folder):
raise ValueError("The Timelapse path '{}' does not exist".format(folder))
self.end_thread(thread) kwargs = {"input_a": self.args.timelapse_input_a,
"input_b": self.args.timelapse_input_b,
"output": self.args.timelapse_output}
logger.debug("Timelapse enabled: %s", kwargs)
return kwargs
def get_images(self): def get_images(self):
""" Check the image dirs exist, contain images and return the image """ Check the image dirs exist, contain images and return the image
objects """ objects """
images = [] logger.debug("Getting image paths")
for image_dir in [self.args.input_A, self.args.input_B]: images = dict()
for side in ("a", "b"):
image_dir = getattr(self.args, "input_{}".format(side))
if not os.path.isdir(image_dir): if not os.path.isdir(image_dir):
logger.error("Error: '%s' does not exist", image_dir) logger.error("Error: '%s' does not exist", image_dir)
exit(1) exit(1)
@ -60,30 +74,60 @@ class Train():
logger.error("Error: '%s' contains no images", image_dir) logger.error("Error: '%s' contains no images", image_dir)
exit(1) exit(1)
images.append(get_image_paths(image_dir)) images[side] = get_image_paths(image_dir)
logger.info("Model A Directory: %s", self.args.input_A) logger.info("Model A Directory: %s", self.args.input_a)
logger.info("Model B Directory: %s", self.args.input_B) logger.info("Model B Directory: %s", self.args.input_b)
logger.debug("Got image paths: %s", [(key, str(len(val)) + " images")
for key, val in images.items()])
return images return images
def process(self):
""" Call the training process object """
logger.debug("Starting Training Process")
logger.info("Training data directory: %s", self.args.model_dir)
set_system_verbosity(self.args.loglevel)
thread = self.start_thread()
# queue_manager.debug_monitor(1)
if self.args.preview:
err = self.monitor_preview(thread)
else:
err = self.monitor_console(thread)
self.end_thread(thread, err)
logger.debug("Completed Training Process")
def start_thread(self): def start_thread(self):
""" Put the training process in a thread so we can keep control """ """ Put the training process in a thread so we can keep control """
thread = threading.Thread(target=self.process_thread) logger.debug("Launching Trainer thread")
thread = MultiThread(target=self.training)
thread.start() thread.start()
logger.debug("Launched Trainer thread")
return thread return thread
def end_thread(self, thread): def end_thread(self, thread, err):
""" On termination output message and join thread back to main """ """ On termination output message and join thread back to main """
logger.info("Exit requested! The trainer will complete its current cycle, " logger.debug("Ending Training thread")
"save the models and quit (it can take up a couple of seconds " if err:
"depending on your training speed). If you want to kill it now, " msg = "Error caught! Exiting..."
"press Ctrl + c") log = logger.critical
else:
msg = ("Exit requested! The trainer will complete its current cycle, "
"save the models and quit (it can take up a couple of seconds "
"depending on your training speed). If you want to kill it now, "
"press Ctrl + c")
log = logger.info
log(msg)
self.stop = True self.stop = True
thread.join() thread.join()
sys.stdout.flush() sys.stdout.flush()
logger.debug("Ended Training thread")
def process_thread(self): def training(self):
""" The training process to be run inside a thread """ """ The training process to be run inside a thread """
try: try:
sleep(1) # Let preview instructions flush out to logger
logger.debug("Commencing Training")
logger.info("Loading data, this may take a while...") logger.info("Loading data, this may take a while...")
if self.args.allow_growth: if self.args.allow_growth:
@ -91,17 +135,12 @@ class Train():
model = self.load_model() model = self.load_model()
trainer = self.load_trainer(model) trainer = self.load_trainer(model)
self.timelapse = Timelapse.create_timelapse(
self.args.timelapse_input_A,
self.args.timelapse_input_B,
self.args.timelapse_output,
trainer)
self.run_training_cycle(model, trainer) self.run_training_cycle(model, trainer)
except KeyboardInterrupt: except KeyboardInterrupt:
try: try:
model.save_weights() logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting")
model.save_models()
trainer.clear_tensorboard()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Saving model weights has been cancelled!") logger.info("Saving model weights has been cancelled!")
exit(0) exit(0)
@ -110,105 +149,192 @@ class Train():
def load_model(self): def load_model(self):
""" Load the model requested for training """ """ Load the model requested for training """
logger.debug("Loading Model")
model_dir = get_folder(self.args.model_dir) model_dir = get_folder(self.args.model_dir)
model = PluginLoader.get_model(self.trainer_name)(model_dir, model = PluginLoader.get_model(self.trainer_name)(
self.args.gpus) model_dir,
self.args.gpus,
model.load(swapped=False) no_logs=self.args.no_logs,
warp_to_landmarks=self.args.warp_to_landmarks,
no_flip=self.args.no_flip,
training_image_size=self.image_size,
alignments_paths=self.alignments_paths,
preview_scale=self.args.preview_scale)
logger.debug("Loaded Model")
return model return model
@property
def image_size(self):
""" Get the training set image size for storing in model data """
image = cv2.imread(self.images["a"][0]) # pylint: disable=no-member
size = image.shape[0]
logger.debug("Training image size: %s", size)
return size
@property
def alignments_paths(self):
""" Set the alignments path to input dirs if not provided """
alignments_paths = dict()
for side in ("a", "b"):
alignments_path = getattr(self.args, "alignments_path_{}".format(side))
if not alignments_path:
image_path = getattr(self.args, "input_{}".format(side))
alignments_path = os.path.join(image_path, "alignments.json")
alignments_paths[side] = alignments_path
logger.debug("Alignments paths: %s", alignments_paths)
return alignments_paths
def load_trainer(self, model): def load_trainer(self, model):
""" Load the trainer requested for training """ """ Load the trainer requested for training """
images_a, images_b = self.images logger.debug("Loading Trainer")
trainer = PluginLoader.get_trainer(model.trainer)
trainer = PluginLoader.get_trainer(self.trainer_name)
trainer = trainer(model, trainer = trainer(model,
images_a, self.images,
images_b, self.args.batch_size)
self.args.batch_size, logger.debug("Loaded Trainer")
self.args.perceptual_loss)
return trainer return trainer
def run_training_cycle(self, model, trainer): def run_training_cycle(self, model, trainer):
""" Perform the training cycle """ """ Perform the training cycle """
logger.debug("Running Training Cycle")
if self.args.write_image or self.args.redirect_gui or self.args.preview:
display_func = self.show
else:
display_func = None
for iteration in range(0, self.args.iterations): for iteration in range(0, self.args.iterations):
logger.trace("Training iteration: %s", iteration)
save_iteration = iteration % self.args.save_interval == 0 save_iteration = iteration % self.args.save_interval == 0
viewer = self.show if save_iteration or self.save_now else None viewer = display_func if save_iteration or self.save_now else None
if save_iteration and self.timelapse is not None: timelapse = self.timelapse if save_iteration else None
self.timelapse.work() trainer.train_one_step(viewer, timelapse)
trainer.train_one_step(iteration, viewer)
if self.stop: if self.stop:
logger.debug("Stop received. Terminating")
break break
elif save_iteration: elif save_iteration:
model.save_weights() logger.trace("Save Iteration: (iteration: %s)", iteration)
model.save_models()
elif self.save_now: elif self.save_now:
model.save_weights() logger.trace("Save Requested: (iteration: %s)", iteration)
model.save_models()
self.save_now = False self.save_now = False
model.save_weights() logger.debug("Training cycle complete")
model.save_models()
trainer.clear_tensorboard()
self.stop = True self.stop = True
def monitor_preview(self): def monitor_preview(self, thread):
""" Generate the preview window and wait for keyboard input """ """ Generate the preview window and wait for keyboard input """
logger.info("Using live preview.\n" logger.debug("Launching Preview Monitor")
"Press 'ENTER' on the preview window to save and quit.\n" logger.info("R|=====================================================================")
"Press 'S' on the preview window to save model weights " logger.info("R|- Using live preview -")
"immediately") logger.info("R|- Press 'ENTER' on the preview window to save and quit -")
logger.info("R|- Press 'S' on the preview window to save model weights immediately -")
logger.info("R|=====================================================================")
err = False
while True: while True:
try: try:
with self.lock: with self.lock:
for name, image in self.preview_buffer.items(): for name, image in self.preview_buffer.items():
cv2.imshow(name, image) cv2.imshow(name, image) # pylint: disable=no-member
key = cv2.waitKey(1000) key = cv2.waitKey(1000) # pylint: disable=no-member
if self.stop:
logger.debug("Stop received")
break
if thread.has_error:
logger.debug("Thread error detected")
err = True
break
if key == ord("\n") or key == ord("\r"): if key == ord("\n") or key == ord("\r"):
logger.debug("Exit requested")
break break
if key == ord("s"): if key == ord("s"):
logger.info("Save requested")
self.save_now = True self.save_now = True
if self.stop:
break
except KeyboardInterrupt: except KeyboardInterrupt:
logger.debug("Keyboard Interrupt received")
break break
logger.debug("Closed Preview Monitor")
return err
def monitor_console(self, thread):
""" Monitor the console
NB: A custom function needs to be used for this because
input() blocks """
logger.debug("Launching Console Monitor")
logger.info("R|===============================================")
logger.info("R|- Starting -")
logger.info("R|- Press 'ENTER' to save and quit -")
logger.info("R|- Press 'S' to save model weights immediately -")
logger.info("R|===============================================")
keypress = KBHit(is_gui=self.args.redirect_gui)
err = False
while True:
try:
if thread.has_error:
logger.debug("Thread error detected")
err = True
break
if self.stop:
logger.debug("Stop received")
break
if keypress.kbhit():
key = keypress.getch()
if key in ("\n", "\r"):
logger.debug("Exit requested")
break
if key in ("s", "S"):
logger.info("Save requested")
self.save_now = True
except KeyboardInterrupt:
logger.debug("Keyboard Interrupt received")
break
keypress.set_normal_term()
logger.debug("Closed Console Monitor")
return err
@staticmethod @staticmethod
def monitor_console(): def keypress_monitor(keypress_queue):
""" Monitor the console for any input followed by enter or ctrl+c """ """ Monitor stdin for keypress """
# TODO: how to catch a specific key instead of Enter? while True:
# there isn't a good multiplatform solution: keypress_queue.put(sys.stdin.read(1))
# https://stackoverflow.com/questions/3523174
# TODO: Find a way to interrupt input() if the target iterations are
# reached. At the moment, setting a target iteration and using the -p
# flag is the only guaranteed way to exit the training loop on
# hitting target iterations.
logger.info("Starting. Press 'ENTER' to stop training and save model")
try:
input()
except KeyboardInterrupt:
pass
@staticmethod @staticmethod
def set_tf_allow_growth(): def set_tf_allow_growth():
""" Allow TensorFlow to manage VRAM growth """ """ Allow TensorFlow to manage VRAM growth """
# pylint: disable=no-member
logger.debug("Setting Tensorflow 'allow_growth' option")
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = "0" config.gpu_options.visible_device_list = "0"
set_session(tf.Session(config=config)) set_session(tf.Session(config=config))
logger.debug("Set Tensorflow 'allow_growth' option")
def show(self, image, name=""): def show(self, image, name=""):
""" Generate the preview and write preview file output """ """ Generate the preview and write preview file output """
logger.trace("Updating preview: (name: %s)", name)
try: try:
scriptpath = os.path.realpath(os.path.dirname(sys.argv[0])) scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
if self.args.write_image: if self.args.write_image:
img = "_sample_{}.jpg".format(name) logger.trace("Saving preview to disk")
img = "training_preview.jpg"
imgfile = os.path.join(scriptpath, img) imgfile = os.path.join(scriptpath, img)
cv2.imwrite(imgfile, image) cv2.imwrite(imgfile, image) # pylint: disable=no-member
logger.trace("Saved preview to: '%s'", img)
if self.args.redirect_gui: if self.args.redirect_gui:
img = ".gui_preview_{}.jpg".format(name) logger.trace("Generating preview for GUI")
img = ".gui_training_preview.jpg"
imgfile = os.path.join(scriptpath, "lib", "gui", imgfile = os.path.join(scriptpath, "lib", "gui",
".cache", "preview", img) ".cache", "preview", img)
cv2.imwrite(imgfile, image) cv2.imwrite(imgfile, image) # pylint: disable=no-member
logger.trace("Generated preview for GUI: '%s'", img)
if self.args.preview: if self.args.preview:
logger.trace("Generating preview for display: '%s'", name)
with self.lock: with self.lock:
self.preview_buffer[name] = image self.preview_buffer[name] = image
logger.trace("Generated preview for display: '%s'", name)
except Exception as err: except Exception as err:
logger.error("could not preview sample") logger.error("could not preview sample")
raise err raise err
logger.trace("Updated preview: (name: %s)", name)
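
Both monitor loops in this file poll `thread.has_error` so that a crash in the training thread ends the main loop instead of leaving it waiting for a keypress. The implementation of `MultiThread` is not shown in this diff, so the sketch below uses a hypothetical stand-in, `ErrorFlagThread`, that only records whether its target raised. The `monitor` function and its parameters are likewise assumptions made for illustration.

```python
import threading
import time


class ErrorFlagThread(threading.Thread):
    """Thread that records whether its target raised an exception.

    Hypothetical stand-in for lib.multithreading.MultiThread, whose real
    implementation is not part of this diff.
    """

    def __init__(self, target):
        super().__init__(daemon=True)
        self._target_func = target
        self.has_error = False

    def run(self):
        try:
            self._target_func()
        except Exception:  # pylint: disable=broad-except
            self.has_error = True


def monitor(thread, is_stopped, poll_interval=0.01):
    """Poll until the worker fails or signals a stop; return True on error."""
    while True:
        if thread.has_error:
            return True
        if is_stopped():
            return False
        time.sleep(poll_interval)


if __name__ == "__main__":
    def failing_task():
        raise RuntimeError("simulated training failure")

    worker = ErrorFlagThread(failing_task)
    worker.start()
    print("error detected:", monitor(worker, lambda: False))
    worker.join()
```

The returned flag plays the role of `err` in `process()` above, which passes it to `end_thread` to choose between the critical and informational exit messages.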


@ -2,7 +2,7 @@
""" Command Line Arguments for tools """ """ Command Line Arguments for tools """
from lib.cli import FaceSwapArgs from lib.cli import FaceSwapArgs
from lib.cli import (ContextFullPaths, DirFullPaths, from lib.cli import (ContextFullPaths, DirFullPaths,
FileFullPaths, SaveFileFullPaths) FileFullPaths, SaveFileFullPaths, Slider)
from lib.utils import _image_extensions from lib.utils import _image_extensions
@ -47,9 +47,9 @@ class AlignmentsArgs(FaceSwapArgs):
"\n\tfile." + output_opts + frames_dir + "\n\tfile." + output_opts + frames_dir +
"\n'missing-frames': Identify frames in the alignments file that do not " "\n'missing-frames': Identify frames in the alignments file that do not "
"\n\tappear within the frames folder/video." + output_opts + frames_dir + "\n\tappear within the frames folder/video." + output_opts + frames_dir +
"\n'legacy': This updates legacy alignments to the latest format by adding" "\n'legacy': This updates legacy alignments to the latest format by rotating"
"\n\tframe dimensions, rotating the landmarks and bounding boxes and adding" "\n\tthe landmarks and bounding boxes and adding face_hashes." +
"\n\tface_hashes" + frames_and_faces_dir + frames_and_faces_dir +
"\n'leftover-faces': Identify faces in the faces folder that do not exist in" "\n'leftover-faces': Identify faces in the faces folder that do not exist in"
"\n\tthe alignments file." + output_opts + faces_dir + "\n\tthe alignments file." + output_opts + faces_dir +
"\n'multi-faces': Identify where multiple faces exist within the alignments" "\n'multi-faces': Identify where multiple faces exist within the alignments"
@ -123,6 +123,13 @@ class AlignmentsArgs(FaceSwapArgs):
"\n\tdirectory)." "\n\tdirectory)."
"\n'move': Move the discovered items to a sub-folder within the source" "\n'move': Move the discovered items to a sub-folder within the source"
"\n\tdirectory."}) "\n\tdirectory."})
argument_list.append({"opts": ("-sz", "--size"),
"type": int,
"action": Slider,
"min_max": (128, 512),
"default": 256,
"rounding": 64,
"help": "The output size of extracted faces. (extract only)"})
argument_list.append({"opts": ("-ae", "--align-eyes"), argument_list.append({"opts": ("-ae", "--align-eyes"),
"action": "store_true", "action": "store_true",
"dest": "align_eyes", "dest": "align_eyes",
@ -409,6 +416,9 @@ class SortArgs(FaceSwapArgs):
"Default: hist"}) "Default: hist"})
argument_list.append({"opts": ('-t', '--ref_threshold'), argument_list.append({"opts": ('-t', '--ref_threshold'),
"action": Slider,
"min_max": (-1.0, 10.0),
"rounding": 2,
"type": float, "type": float,
"dest": 'min_threshold', "dest": 'min_threshold',
"default": -1.0, "default": -1.0,
@ -433,6 +443,9 @@ class SortArgs(FaceSwapArgs):
"hist 0.3"}) "hist 0.3"})
argument_list.append({"opts": ('-b', '--bins'), argument_list.append({"opts": ('-b', '--bins'),
"action": Slider,
"min_max": (1, 100),
"rounding": 1,
"type": int, "type": int,
"dest": 'num_bins', "dest": 'num_bins',
"default": 5, "default": 5,
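
The hunks above attach `"action": Slider` together with `min_max` and `rounding` keys to numeric arguments. The definition of `Slider` in `lib.cli` is not shown in this diff, so the following is only a sketch of how a custom `argparse.Action` can accept such extra keyword arguments and keep them as metadata, assuming the action otherwise just stores the parsed value.

```python
import argparse


class Slider(argparse.Action):
    """Store the parsed value and keep slider metadata for a GUI.

    Illustrative only: an assumption about how such an action could look,
    not the definition used by lib.cli.
    """

    def __init__(self, option_strings, dest, min_max=None, rounding=None, **kwargs):
        super().__init__(option_strings, dest, **kwargs)
        self.min_max = min_max    # (lowest, highest) value for the slider
        self.rounding = rounding  # step size or decimal places, per the caller

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)


def build_parser():
    """Build a parser with one slider-backed option, mirroring --size above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-sz", "--size", type=int, action=Slider,
                        min_max=(128, 512), rounding=64, default=256,
                        help="The output size of extracted faces.")
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args(["-sz", "384"]).size)
```

On the command line the option behaves like a plain typed argument; the extra attributes are only there for a front end to read when it builds a widget.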


@ -1,5 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Tools for manipulating the alignments seralized file """ """ Tools for manipulating the alignments serialized file """
import logging import logging
import os import os
@ -32,7 +32,7 @@ class Check():
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def get_source_dir(self, arguments): def get_source_dir(self, arguments):
""" Set the correct source dir """ """ Set the correct source folder """
if hasattr(arguments, "faces_dir") and arguments.faces_dir: if hasattr(arguments, "faces_dir") and arguments.faces_dir:
self.type = "faces" self.type = "faces"
source_dir = arguments.faces_dir source_dir = arguments.faces_dir
@ -195,7 +195,7 @@ class Check():
os.rename(src, dst) os.rename(src, dst)
def move_faces(self, output_folder, items_output): def move_faces(self, output_folder, items_output):
""" Make additional subdirs for each face that appears """ Make additional subfolders for each face that appears
Enables easier manual sorting """ Enables easier manual sorting """
logger.info("Moving %s face(s) to '%s'", len(items_output), output_folder) logger.info("Moving %s face(s) to '%s'", len(items_output), output_folder)
for frame, idx in items_output: for frame, idx in items_output:
@ -239,7 +239,7 @@ class Draw():
legacy.process() legacy.process()
logger.info("[DRAW LANDMARKS]") # Tidy up cli output logger.info("[DRAW LANDMARKS]") # Tidy up cli output
self.extracted_faces = ExtractedFaces(self.frames, self.alignments, self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256,
align_eyes=self.arguments.align_eyes) align_eyes=self.arguments.align_eyes)
frames_drawn = 0 frames_drawn = 0
for frame in tqdm(self.frames.file_list_sorted, desc="Drawing landmarks"): for frame in tqdm(self.frames.file_list_sorted, desc="Drawing landmarks"):
@ -281,7 +281,7 @@ class Extract():
self.type = arguments.job.replace("extract-", "") self.type = arguments.job.replace("extract-", "")
self.faces_dir = arguments.faces_dir self.faces_dir = arguments.faces_dir
self.frames = Frames(arguments.frames_dir) self.frames = Frames(arguments.frames_dir)
self.extracted_faces = ExtractedFaces(self.frames, self.alignments, self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=arguments.size,
align_eyes=arguments.align_eyes) align_eyes=arguments.align_eyes)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@ -352,8 +352,7 @@ class Extract():
valid_faces = faces valid_faces = faces
else: else:
sizes = self.extracted_faces.get_roi_size_for_frame(frame) sizes = self.extracted_faces.get_roi_size_for_frame(frame)
valid_faces = [faces[idx] valid_faces = [faces[idx] for idx, size in enumerate(sizes)
for idx, size in enumerate(sizes)
if size >= self.extracted_faces.size] if size >= self.extracted_faces.size]
logger.trace("frame: '%s', total_faces: %s, valid_faces: %s", logger.trace("frame: '%s', total_faces: %s, valid_faces: %s",
frame, len(faces), len(valid_faces)) frame, len(faces), len(valid_faces))
@ -362,8 +361,6 @@ class Extract():
class Legacy(): class Legacy():
""" Update legacy alignments: """ Update legacy alignments:
- Add frame dimensions
- Rotate landmarks and bounding boxes on legacy alignments - Rotate landmarks and bounding boxes on legacy alignments
and remove the 'r' parameter and remove the 'r' parameter
- Add face hashes to alignments file - Add face hashes to alignments file
@ -383,16 +380,11 @@ class Legacy():
def process(self): def process(self):
""" Run the rotate alignments process """ """ Run the rotate alignments process """
no_dims = self.alignments.get_legacy_no_dims()
rotated = self.alignments.get_legacy_rotation() rotated = self.alignments.get_legacy_rotation()
hashes = self.alignments.get_legacy_no_hashes() hashes = self.alignments.get_legacy_no_hashes()
if (not self.frames or (not no_dims and not rotated)) and (not self.faces or not hashes): if (not self.frames or not rotated) and (not self.faces or not hashes):
return return
logger.info("[UPDATE LEGACY LANDMARKS]") # Tidy up cli output logger.info("[UPDATE LEGACY LANDMARKS]") # Tidy up cli output
if no_dims and self.frames:
logger.info("Legacy landmarks found. Adding frame dimensions...")
self.add_dimensions(no_dims)
self.alignments.save()
if rotated and self.frames: if rotated and self.frames:
logger.info("Legacy rotated frames found. Converting...") logger.info("Legacy rotated frames found. Converting...")
self.rotate_landmarks(rotated) self.rotate_landmarks(rotated)
@ -402,20 +394,13 @@ class Legacy():
self.add_hashes(hashes) self.add_hashes(hashes)
self.alignments.save() self.alignments.save()
def add_dimensions(self, no_dims):
""" Add width and height of original frame to alignments """
for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"):
if no_dim not in self.frames.items.keys():
continue
dims = self.frames.load_image(no_dim).shape[:2]
self.alignments.add_dimensions(no_dim, dims)
def rotate_landmarks(self, rotated): def rotate_landmarks(self, rotated):
""" Rotate the landmarks """ """ Rotate the landmarks """
for rotate_item in tqdm(rotated, desc="Rotating Landmarks"): for rotate_item in tqdm(rotated, desc="Rotating Landmarks"):
if rotate_item not in self.frames.items.keys(): frame = self.frames.get(rotate_item, None)
if frame is None:
continue continue
self.alignments.rotate_existing_landmarks(rotate_item) self.alignments.rotate_existing_landmarks(rotate_item, frame)
def add_hashes(self, hashes): def add_hashes(self, hashes):
""" Add Face Hashes to the alignments file """ """ Add Face Hashes to the alignments file """
@ -838,19 +823,19 @@ class Spatial():
"alignments -j extract -a %s -fr <path_to_frames_dir> -fc " "alignments -j extract -a %s -fr <path_to_frames_dir> -fc "
"<output_folder>", self.arguments.alignments_file) "<output_folder>", self.arguments.alignments_file)
# define shape normalization utility functions # Define shape normalization utility functions
@staticmethod @staticmethod
def normalize_shapes(shapes_im_coords): def normalize_shapes(shapes_im_coords):
""" Normalize a 2D or 3D shape """ """ Normalize a 2D or 3D shape """
logger.debug("Normalize shapes") logger.debug("Normalize shapes")
(num_pts, num_dims, _) = shapes_im_coords.shape (num_pts, num_dims, _) = shapes_im_coords.shape
# calc mean coords and subtract from shapes # Calculate mean coordinates and subtract from shapes
mean_coords = shapes_im_coords.mean(axis=0) mean_coords = shapes_im_coords.mean(axis=0)
shapes_centered = np.zeros(shapes_im_coords.shape) shapes_centered = np.zeros(shapes_im_coords.shape)
shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1]) shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1])
# calc scale factors and divide shapes # Calculate scale factors and divide shapes
scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0) scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0)
shapes_normalized = np.zeros(shapes_centered.shape) shapes_normalized = np.zeros(shapes_centered.shape)
shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1]) shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1])
@ -889,12 +874,12 @@ class Spatial():
landmarks = np.array(val[0]["landmarksXY"]).reshape(68, 2, 1) landmarks = np.array(val[0]["landmarksXY"]).reshape(68, 2, 1)
start = end start = end
end = start + landmarks.shape[2] end = start + landmarks.shape[2]
# store in one big array # Store in one big array
landmarks_all[:, :, start:end] = landmarks landmarks_all[:, :, start:end] = landmarks
# make sure we keep track of the mapping to the original frame # Make sure we keep track of the mapping to the original frame
self.mappings[start] = key self.mappings[start] = key
# normalize shapes # Normalize shapes
normalized_shape = self.normalize_shapes(landmarks_all) normalized_shape = self.normalize_shapes(landmarks_all)
self.normalized["landmarks"] = normalized_shape[0] self.normalized["landmarks"] = normalized_shape[0]
self.normalized["scale_factors"] = normalized_shape[1] self.normalized["scale_factors"] = normalized_shape[1]
@ -920,15 +905,15 @@ class Spatial():
(project and reconstruct) """ (project and reconstruct) """
logger.debug("Spatially Filter") logger.debug("Spatially Filter")
landmarks_norm = self.normalized["landmarks"] landmarks_norm = self.normalized["landmarks"]
# convert to matrix form # Convert to matrix form
landmarks_norm_table = np.reshape(landmarks_norm, [68 * 2, landmarks_norm.shape[2]]).T landmarks_norm_table = np.reshape(landmarks_norm, [68 * 2, landmarks_norm.shape[2]]).T
# project onto shapes model and reconstruct # Project onto shapes model and reconstruct
landmarks_norm_table_rec = self.shapes_model.inverse_transform( landmarks_norm_table_rec = self.shapes_model.inverse_transform(
self.shapes_model.transform(landmarks_norm_table)) self.shapes_model.transform(landmarks_norm_table))
# convert back to shapes (numKeypoint, num_dims, numFrames) # Convert back to shapes (numKeypoint, num_dims, numFrames)
landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T, landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T,
[68, 2, landmarks_norm.shape[2]]) [68, 2, landmarks_norm.shape[2]])
# transform back to image coords # Transform back to image coords
retval = self.normalized_to_original(landmarks_norm_rec, retval = self.normalized_to_original(landmarks_norm_rec,
self.normalized["scale_factors"], self.normalized["scale_factors"],
self.normalized["mean_coords"]) self.normalized["mean_coords"])
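
Note on the normalization hunks above: `normalize_shapes` centers each set of landmarks on its mean and divides by the mean point distance, and `normalized_to_original` reverses that. The following is a minimal sketch of the same idea for a single frame, written in plain Python rather than the NumPy arrays the tool uses. The names `normalize_shape` and `restore_shape` are hypothetical and are not part of the commit.

```python
import math


def normalize_shape(points):
    """Center (x, y) landmarks on their mean and scale them so that the
    mean distance from the origin is 1.

    Returns (normalized_points, scale_factor, mean_coords).
    Hypothetical single-frame helper, not the function from the diff.
    """
    count = len(points)
    mean_x = sum(x for x, _ in points) / count
    mean_y = sum(y for _, y in points) / count
    # Subtract the mean coordinates from every point
    centered = [(x - mean_x, y - mean_y) for x, y in points]
    # Scale factor is the mean Euclidean distance of the centered points
    scale = sum(math.hypot(x, y) for x, y in centered) / count
    normalized = [(x / scale, y / scale) for x, y in centered]
    return normalized, scale, (mean_x, mean_y)


def restore_shape(normalized, scale, mean):
    """Undo normalize_shape: rescale, then add the mean back."""
    mean_x, mean_y = mean
    return [(x * scale + mean_x, y * scale + mean_y) for x, y in normalized]


square = [(10.0, 10.0), (30.0, 10.0), (30.0, 30.0), (10.0, 30.0)]
normalized, scale, mean = normalize_shape(square)
restored = restore_shape(normalized, scale, mean)
```

Keeping the scale factor and mean alongside the normalized points is what allows the filtered landmarks to be mapped back to image coordinates afterwards.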


@ -294,7 +294,7 @@ class Interface():
def get_state_color(self): def get_state_color(self):
""" Return a color based on current state """ Return a color based on current state
white - View Mode white - View Mode
yellow - Edit Mide yellow - Edit Mode
red - Unsaved alignments """ red - Unsaved alignments """
color = (255, 255, 255) color = (255, 255, 255)
if self.state["edit"]["updated"]: if self.state["edit"]["updated"]:
@ -446,7 +446,7 @@ class Manual():
legacy.process() legacy.process()
logger.info("[MANUAL PROCESSING]") # Tidy up cli output logger.info("[MANUAL PROCESSING]") # Tidy up cli output
self.extracted_faces = ExtractedFaces(self.frames, self.alignments, self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256,
align_eyes=self.align_eyes) align_eyes=self.align_eyes)
self.interface = Interface(self.alignments, self.frames) self.interface = Interface(self.alignments, self.frames)
self.help = Help(self.interface) self.help = Help(self.interface)
@ -510,8 +510,8 @@ class Manual():
MS Windows doesn't appear to read the window state property MS Windows doesn't appear to read the window state property
properly, so we check for a negative key press. properly, so we check for a negative key press.
Conda (tested on Windows) doesn't sppear to read the window Conda (tested on Windows) doesn't appear to read the window
state property or negative key press properly, so we arbitarily state property or negative key press properly, so we arbitrarily
use another property """ use another property """
# pylint: disable=no-member # pylint: disable=no-member
logger.trace("Commencing closed window check") logger.trace("Commencing closed window check")
@ -790,7 +790,7 @@ class MouseHandler():
a_event = align_process.event a_event = align_process.event
align_process.start() align_process.start()
# Wait for Aligner to take init # Wait for Aligner to initialize
# The first ever load of the model for FAN has reportedly taken # The first ever load of the model for FAN has reportedly taken
# up to 3-4 minutes, hence high timeout. # up to 3-4 minutes, hence high timeout.
a_event.wait(300) a_event.wait(300)
@ -977,7 +977,8 @@ class MouseHandler():
self.interface.state["edit"]["updated"] = True self.interface.state["edit"]["updated"] = True
self.interface.state["edit"]["update_faces"] = True self.interface.state["edit"]["update_faces"] = True
def extracted_to_alignment(self, extract_data): @staticmethod
def extracted_to_alignment(extract_data):
""" Convert Extracted Tuple to Alignments data """ """ Convert Extracted Tuple to Alignments data """
alignment = dict() alignment = dict()
d_rect, landmarks = extract_data d_rect, landmarks = extract_data
@ -985,6 +986,5 @@ class MouseHandler():
alignment["w"] = d_rect.right() - d_rect.left() alignment["w"] = d_rect.right() - d_rect.left()
alignment["y"] = d_rect.top() alignment["y"] = d_rect.top()
alignment["h"] = d_rect.bottom() - d_rect.top() alignment["h"] = d_rect.bottom() - d_rect.top()
alignment["frame_dims"] = self.media["image"].shape[:2]
alignment["landmarksXY"] = landmarks alignment["landmarksXY"] = landmarks
return alignment return alignment
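
Note on the `extracted_to_alignment` change above: once the `frame_dims` entry is removed, the method no longer reads anything from `self`, which is why it can become a `@staticmethod`. The following is a rough standalone sketch of the resulting conversion. `FakeRect` is a hypothetical stand-in for whatever rectangle object the aligner returns, and the `"x"` assignment is assumed, since that line falls outside the hunks shown.

```python
class FakeRect:
    """Hypothetical stand-in exposing left/top/right/bottom accessors."""

    def __init__(self, left, top, right, bottom):
        self._left, self._top = left, top
        self._right, self._bottom = right, bottom

    def left(self):
        return self._left

    def top(self):
        return self._top

    def right(self):
        return self._right

    def bottom(self):
        return self._bottom


def extracted_to_alignment(extract_data):
    """Convert an (rect, landmarks) tuple to an alignments dictionary."""
    alignment = dict()
    d_rect, landmarks = extract_data
    alignment["x"] = d_rect.left()  # assumed: not visible in the diff hunks
    alignment["w"] = d_rect.right() - d_rect.left()
    alignment["y"] = d_rect.top()
    alignment["h"] = d_rect.bottom() - d_rect.top()
    alignment["landmarksXY"] = landmarks
    return alignment


result = extracted_to_alignment((FakeRect(10, 20, 110, 140), [(1, 2), (3, 4)]))
```

The function depends only on its argument, so it can be called without an instance or any loaded frame.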


@ -53,7 +53,7 @@ class AlignmentData(Alignments):
self.set_destination_format(destination_format) self.set_destination_format(destination_format)
def set_destination_format(self, destination_format): def set_destination_format(self, destination_format):
""" Standardise the destination format to the correct extension """ """ Standardize the destination format to the correct extension """
extensions = {".json": "json", extensions = {".json": "json",
".p": "pickle", ".p": "pickle",
".yml": "yaml", ".yml": "yaml",
@ -274,12 +274,11 @@ class Frames(MediaLoader):
class ExtractedFaces(): class ExtractedFaces():
""" Holds the extracted faces and matrix for """ Holds the extracted faces and matrix for
alignments """ alignments """
def __init__(self, frames, alignments, size=256, def __init__(self, frames, alignments, size=256, align_eyes=False):
padding=48, align_eyes=False):
logger.trace("Initializing %s: (size: %s, padding: %s, align_eyes: %s)", logger.trace("Initializing %s: (size: %s, padding: %s, align_eyes: %s)",
self.__class__.__name__, size, padding, align_eyes) self.__class__.__name__, size, align_eyes)
self.size = size self.size = size
self.padding = padding self.padding = int(size * 0.1875)
self.align_eyes = align_eyes self.align_eyes = align_eyes
self.alignments = alignments self.alignments = alignments
self.frames = frames self.frames = frames
@ -309,10 +308,7 @@ class ExtractedFaces():
self.current_frame, alignment) self.current_frame, alignment)
face = DetectedFace() face = DetectedFace()
face.from_alignment(alignment, image=image) face.from_alignment(alignment, image=image)
face.load_aligned(image, face.load_aligned(image, size=self.size, align_eyes=self.align_eyes)
size=self.size,
padding=self.padding,
align_eyes=self.align_eyes)
return face return face
def get_faces_in_frame(self, frame, update=False): def get_faces_in_frame(self, frame, update=False):
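
Note on the `ExtractedFaces` change above: padding is no longer passed in but derived from the face size as `int(size * 0.1875)`. A small worked check, assuming nothing beyond that expression, shows that the old default of 48 is preserved at the default size of 256, and that other sizes now get proportional padding. The helper name `padding_for_size` is hypothetical.

```python
def padding_for_size(size):
    """Padding derived from the face size, as in the updated constructor."""
    return int(size * 0.1875)


# 0.1875 is 3/16, so sizes that are multiples of 16 give whole-number padding
paddings = {size: padding_for_size(size) for size in (64, 128, 256, 512)}
```

For sizes that are not multiples of 16, `int()` truncates the result toward zero.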


@ -47,7 +47,7 @@ class Sort():
# Assigning default threshold values based on grouping method # Assigning default threshold values based on grouping method
if (self.args.final_process == "folders" if (self.args.final_process == "folders"
and self.args.min_threshold == -1.0): and self.args.min_threshold < 0.0):
method = self.args.group_method.lower() method = self.args.group_method.lower()
if method == 'face': if method == 'face':
self.args.min_threshold = 0.6 self.args.min_threshold = 0.6
@ -767,9 +767,9 @@ class Sort():
Normalize by pixel number to offset the effect Normalize by pixel number to offset the effect
of image size on pixel gradients & variance of image size on pixel gradients & variance
""" """
image = cv2.imread(image_file,cv2.IMREAD_GRAYSCALE) image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
blur_map = cv2.Laplacian(image, cv2.CV_32F) blur_map = cv2.Laplacian(image, cv2.CV_32F)
score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1]) score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1])
return score return score
@staticmethod @staticmethod