mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* model_refactor (#571)
* original model to new structure
* IAE model to new structure
* OriginalHiRes to new structure
* Fix trainer for different resolutions
* Initial config implementation
* Configparser library added
* Improved training data loader
* dfaker model working
* Add logging to training functions
* Non-blocking input for cli training
* Add error handling to threads. Add non-mp queues to queue_handler
* Improved Model Building and NNMeta
* Refactor lib/models
* Training refactor. DFL H128 model implementation
* Dfaker - use hashes
* Move timelapse. Remove perceptual loss arg
* Update INSTALL.md. Add logger formatting. Update Dfaker training
* DFL H128 partially ported
* Add mask to dfaker (#573)
* Remove old models. Add mask to dfaker
* dfl mask. Make masks selectable in config (#575)
* DFL H128 mask. Mask type selectable in config
* Remove gan_v2_2
* Create input size config for models. Will be used downstream in converters. Also rename image_shape to input_shape to clarify (for future models with potentially different output_shapes)
* Add mask loss options to config
* MTCNN options to config.ini. Remove GAN config. Update USAGE.md
* Add sliders for numerical values in GUI
* Add config plugins menu to GUI. Validate config
* Only backup model if loss has dropped. Get training working again
* Bugfixes
* Standardise loss printing
* GUI idle CPU fixes. Graph loss fix
* Multi-GPU logging bugfix
* Merge branch 'staging' into train_refactor
* Backup state file
* Crash protection: only backup if both total losses have dropped
* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)
* Load and save model structure with weights
* Slight code update
* Improve config loader. Add subpixel option to all models. Config to state
* Show samples... wrong input
* Remove AE topology. Add input/output shapes to State
* Port original_villain (birb/VillainGuy) model to faceswap
* Add plugin info to GUI config pages
* Load input shape from state. IAE config options
* Fix transform_kwargs. Coverage to ratio. Bugfix mask detection
* Suppress Keras user warnings. Automate zoom. Coverage_ratio to model def
* Consolidation of converters & refactor (#574). Initial upload of alpha. Items:
  - consolidate convert_masked & convert_adjust into one converter
  - add average color adjust to convert_masked
  - allow mask transition blur size to be a fixed integer of pixels or a fraction of the facial mask size
  - allow erosion/dilation size to be a fixed integer of pixels or a fraction of the facial mask size
  - eliminate redundant type conversions to avoid multiple round-off errors
  - refactor loops for vectorization/speed
  - reorganize for clarity & style changes
  TODO:
  - bug/issues with warping the new face onto a transparent old image... use a cleanup mask for now
  - issues with mask border giving black ring at zero erosion... investigate
  - remove GAN??
  - test enlargement factors of umeyama standard face... match to coverage factor
  - make enlargement factor a model parameter
  - remove convert_adjusted and referencing code when finished
* Update Convert_Masked.py: default blur size of 2 to match original; description of enlargement tests; break out matrix scaling into a def
* Enlargement scale as a cli parameter
* Update cli.py
* Dynamic interpolation algorithm: compute x & y scale factors from the affine matrix on the fly by QR decomposition. Choose the interpolation algorithm for the affine warp based on an upsample or downsample for each image
* Input size from config
* Fix issues with <1.0 erosion
* Update convert.py
* Update Convert_Adjust.py: more work on the way to merging
* Clean up help note on sharpen
* Clean up seamless
* Delete Convert_Adjust.py
* Update umeyama.py
* Update training_data.py
* Swapping
* Segmentation stub
* Changes to convert.str
* Update masked.py
* Backwards compatibility fix for models. Get converter running
* Convert: move masks to class. Bugfix blur_size; some linting
* Mask fix
* Convert fixes:
  - missing facehull_rect re-added
  - coverage to %
  - corrected coverage logic
  - cleanup of GUI option ordering
* Update cli.py
* Default for blur
* Update masked.py
* Add preliminary low_mem version of OriginalHighRes model plugin
* Code cleanup, minor fixes
* Update masked.py
* Update masked.py
* Add dfl mask to convert
* Histogram fix & seamless location
* Update
* Revert
* Bugfix: load actual configuration in GUI
* Standardize nn_blocks
* Update cli.py
* Minor code amends
* Fix Original HiRes model
* Add masks to preview output for mask trainers. Refactor trainer.__base.py
* Masked trainers converter support
* Convert bugfix
* Bugfix: converter for masked (dfl/dfaker) trainers
* Additional losses (#592)
* Initial upload
* Delete blur.py
* Default initializer = He instead of Glorot (#588)
* Allow kernel_initializer to be overridable
* Add ICNR initializer option for upscale on all models
* Hopefully fixes RSoDs with original-highres model plugin
* Remove debug line
* Original-HighRes model plugin Red Screen of Death fix, take #2
* Move global options to _base. Rename Villain model
* Clipnorm and res block biases
* Scale the end of res block
* Res block
* Dfaker pre-activation res
* OHRES pre-activation
* Villain pre-activation
* Tabs/space in nn_blocks
* Fix for histogram with mask all set to zero
* Fix to prevent two networks with same name
* GUI: wider tooltips. Improve TQDM capture
* Fix regex bug
* Convert padding=48 to ratio of image size
* Add size option to alignments tool extract
* Pass through training image size to convert from model
* Convert: pull training coverage from model
* Convert: coverage, blur and erode to percent
* Simplify matrix scaling
* Ordering of sliders in train
* Add matrix scaling to utils. Use interpolation in lib.aligner transform
* masked.py: import get_matrix_scaling from utils
* Fix circular import
* Update masked.py
* Quick fix for matrix scaling
* Testing thus far for now
* TQDM regex capture bugfix
* Minor amends
* Blur size cleanup
* Remove coverage option from convert (now cascades from model)
* Implement convert for all model types
* Add mask option and coverage option to all existing models
* Bugfix for model loading on convert
* Debug print removal
* Bugfix for masks in dfl_h128 and iae
* Update preview display. Add preview scaling to cli
* Mask notes
* Delete training_data_v2.py errant file
* Training data variables
* Fix timelapse function
* Add new config items to state file for legacy purposes
* Slight GUI tweak
* Raise exception if problem with loaded model
* Add TensorBoard support (logs stored in model directory)
* ICNR fix
* Loss bugfix
* Convert bugfix
* Move ini files to config folder. Make TensorBoard optional
* Fix training data for unbalanced inputs/outputs
* Fix config "none" test
* Keep helptext in .ini files when saving config from GUI
* Remove frame_dims from alignments
* Add no-flip and warp-to-landmarks cli options
* Revert OHR to RC4_fix version
* Fix lowmem mode on OHR model
* Padding to variable
* Save models in parallel threads
* Speed-up of res_block stability
* Automated reflection padding
* Reflect padding as a training option. Includes auto-calculation of proper padding shapes, input_shapes, output_shapes. Flag included in config now
* Rest of reflect padding
* Move TB logging to cli. Session info to state file
* Add session iterations to state file
* Add recent files to menu. GUI code tidy up
* [GUI] Fix recent file list update issue
* Add correct loss names to TensorBoard logs
* Update live graph to use TensorBoard and remove animation
* Fix analysis tab. GUI optimizations
* Analysis graph popup to TensorBoard logs
* [GUI] Bug fix for graphing for models with hyphens in name
* [GUI] Correctly split loss to tabs during training
* [GUI] Add loss type selection to analysis graph
* Fix store command name in recent files. Switch to correct tab on open
* [GUI] Disable training graph when 'no-logs' is selected
* Fix graphing race condition
* Rename original_hires model to unbalanced
parent 584c41e005
commit cd00859c40
94 changed files with 7435 additions and 4251 deletions
.github/ISSUE_TEMPLATE.md | 2 (vendored)

@@ -1,5 +1,7 @@
 **Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need for setup help and other issues in using the faceswap tool, please refer to the [faceswap-playground](https://github.com/deepfakes/faceswap-playground/issues) instead. The faceswap-playground is also an excellent place to ask questions and submit feedback.**
 
+**Please always attach your generated crash_report.log to any bug report**
+
 ## Expected behavior
 
 *Describe, in some detail, what you are trying to do and what the output is that you expect from the program.*
.gitignore | 9 (vendored)

@@ -12,8 +12,9 @@
 !Dockerfile*
 !requirements*
 !.cache
-!lib
-!lib/face_alignment
+!config/
+!lib/
+!lib/*
 !lib/gui
 !lib/gui/.cache/preview
 !lib/gui/.cache/icons
@@ -21,9 +22,9 @@
 !plugins/
 !plugins/*
 !plugins/extract/*
-!plugins/model/*
+!plugins/train/*
 !tools
 !tools/lib*
+*.ini
 *.pyc
 __pycache__/
INSTALL.md | 56

@@ -1,37 +1,37 @@
 # Installing Faceswap
 - [Installing Faceswap](#installing-faceswap)
 - [Prerequisites](#prerequisites)
 - [Hardware Requirements](#hardware-requirements)
 - [Supported operating systems](#supported-operating-systems)
 - [Important before you proceed](#important-before-you-proceed)
 - [General Install Guide](#general-install-guide)
 - [Installing dependencies](#installing-dependencies)
 - [Getting the faceswap code](#getting-the-faceswap-code)
 - [Setup](#setup)
 - [About some of the options](#about-some-of-the-options)
 - [Run the project](#run-the-project)
 - [Notes](#notes)
 - [Windows Install Guide](#windows-install-guide)
 - [Prerequisites](#prerequisites-1)
 - [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015)
 - [Cuda](#cuda)
 - [cuDNN](#cudnn)
 - [CMake](#cmake)
 - [Anaconda](#anaconda)
 - [Git](#git)
 - [Setup](#setup-1)
 - [Anaconda](#anaconda-1)
 - [Set up a virtual environment](#set-up-a-virtual-environment)
 - [Entering your virtual environment](#entering-your-virtual-environment)
 - [Faceswap](#faceswap)
 - [Easy install](#easy-install)
 - [Manual install](#manual-install)
 - [Running Faceswap](#running-faceswap)
 - [Create a desktop shortcut](#create-a-desktop-shortcut)
 - [Updating faceswap](#updating-faceswap)
 - [Dlib](#dlib)
 - [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support)
 - [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support)
 
 # Prerequisites
 Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up.
USAGE.md | 7

@@ -34,6 +34,8 @@ You can see the full list of arguments for extracting via help flag. i.e.
 python faceswap.py extract -h
 ```
 
+Some of the plugins have configurable options. You can find the config options in: `<faceswap_folder>\plugins\extract\config.ini`. Extract needs to have been run at least once to generate the config file
+
 ## TRAIN
 The training process will take the longest, especially on CPU. We specify the folders where the two faces are, and where we will save our training model. It will start hammering the training data once you run the command. I personally really like to go by the preview and quit the processing once I'm happy with the results.
 
@@ -51,6 +53,9 @@ You can see the full list of arguments for training via help flag. i.e.
 python faceswap.py train -h
 ```
 
+Some of the plugins have configurable options. You can find the config options in: `<faceswap_folder>\plugins\train\config.ini`. Train needs to have been run at least once to generate the config file
+
 ## CONVERT
 Now that we're happy with our trained model, we can convert our video. How does it work? Similarly to the extraction script, actually! The conversion script basically detects a face in a picture using the same algorithm, quickly crops the image to the right size, runs our bot on this cropped image of the face it has found, and then (crudely) pastes the processed face back into the picture.
 
@@ -86,7 +91,7 @@ python tools.py effmpeg -h
 ```
 
 ## Extracting video frames with FFMPEG
-Alternatively you can split a video into seperate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to seperate frames.
+Alternatively you can split a video into separate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to separate frames.
 
 ```bash
 ffmpeg -i /path/to/my/video.mp4 /path/to/output/video-frame-%d.png
 ```
@@ -1,88 +0,0 @@
-# PixelShuffler layer for Keras
-# by t-ae
-# https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981
-
-from keras.utils import conv_utils
-from keras.engine.topology import Layer
-import keras.backend as K
-
-
-class PixelShuffler(Layer):
-    def __init__(self, size=(2, 2), data_format=None, **kwargs):
-        super(PixelShuffler, self).__init__(**kwargs)
-        self.data_format = K.normalize_data_format(data_format)
-        self.size = conv_utils.normalize_tuple(size, 2, 'size')
-
-    def call(self, inputs):
-
-        input_shape = K.int_shape(inputs)
-        if len(input_shape) != 4:
-            raise ValueError('Inputs should have rank ' +
-                             str(4) +
-                             '; Received input shape:', str(input_shape))
-
-        if self.data_format == 'channels_first':
-            batch_size, c, h, w = input_shape
-            if batch_size is None:
-                batch_size = -1
-            rh, rw = self.size
-            oh, ow = h * rh, w * rw
-            oc = c // (rh * rw)
-
-            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
-            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
-            out = K.reshape(out, (batch_size, oc, oh, ow))
-            return out
-
-        elif self.data_format == 'channels_last':
-            batch_size, h, w, c = input_shape
-            if batch_size is None:
-                batch_size = -1
-            rh, rw = self.size
-            oh, ow = h * rh, w * rw
-            oc = c // (rh * rw)
-
-            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
-            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
-            out = K.reshape(out, (batch_size, oh, ow, oc))
-            return out
-
-    def compute_output_shape(self, input_shape):
-
-        if len(input_shape) != 4:
-            raise ValueError('Inputs should have rank ' +
-                             str(4) +
-                             '; Received input shape:', str(input_shape))
-
-        if self.data_format == 'channels_first':
-            height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
-            width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
-            channels = input_shape[1] // self.size[0] // self.size[1]
-
-            if channels * self.size[0] * self.size[1] != input_shape[1]:
-                raise ValueError('channels of input and size are incompatible')
-
-            return (input_shape[0],
-                    channels,
-                    height,
-                    width)
-
-        elif self.data_format == 'channels_last':
-            height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
-            width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
-            channels = input_shape[3] // self.size[0] // self.size[1]
-
-            if channels * self.size[0] * self.size[1] != input_shape[3]:
-                raise ValueError('channels of input and size are incompatible')
-
-            return (input_shape[0],
-                    height,
-                    width,
-                    channels)
-
-    def get_config(self):
-        config = {'size': self.size,
-                  'data_format': self.data_format}
-        base_config = super(PixelShuffler, self).get_config()
-
-        return dict(list(base_config.items()) + list(config.items()))
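For reference, the layer deleted above implements the standard depth-to-space (sub-pixel) rearrangement. A minimal NumPy sketch of its channels_last branch, useful for sanity-checking the reshape/permute/reshape sequence (the input shape here is illustrative, not from the repo):

```python
import numpy as np

def pixel_shuffle(x, scale=2):
    """Depth-to-space for a channels_last tensor, mirroring the
    reshape/transpose/reshape sequence of the removed Keras layer."""
    batch, height, width, channels = x.shape
    out_c = channels // (scale * scale)
    # (N, H, W, rh, rw, C') -> (N, H, rh, W, rw, C') -> (N, H*rh, W*rw, C')
    out = x.reshape(batch, height, width, scale, scale, out_c)
    out = out.transpose(0, 1, 3, 2, 4, 5)
    return out.reshape(batch, height * scale, width * scale, out_c)

x = np.arange(1 * 2 * 2 * 8, dtype="float32").reshape(1, 2, 2, 8)
print(pixel_shuffle(x).shape)  # (1, 4, 4, 2): channels traded for resolution
```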
@@ -12,29 +12,6 @@ from lib.align_eyes import align_eyes as func_align_eyes, FACIAL_LANDMARKS_IDXS
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
-MEAN_FACE_X = np.array([
-    0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483,
-    0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127,
-    0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122,
-    0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132,
-    0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127,
-    0.551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532,
-    0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364,
-    0.490127, 0.42689])
-
-MEAN_FACE_Y = np.array([
-    0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891,
-    0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625,
-    0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758,
-    0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758,
-    0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578,
-    0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192,
-    0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182,
-    0.831803, 0.824182])
-
-LANDMARKS_2D = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1)
-
-
 class Extract():
     """ Based on the original https://www.reddit.com/r/deepfakes/
     code sample + contribs """
@@ -42,8 +19,9 @@ class Extract():
     def extract(self, image, face, size, align_eyes):
         """ Extract a face from an image """
         logger.trace("size: %s. align_eyes: %s", size, align_eyes)
+        padding = int(size * 0.1875)
         alignment = get_align_mat(face, size, align_eyes)
-        extracted = self.transform(image, alignment, size, 48)
+        extracted = self.transform(image, alignment, size, padding)
         logger.trace("Returning face and alignment matrix: (alignment_matrix: %s)", alignment)
         return extracted, alignment
 
@@ -60,8 +38,9 @@ class Extract():
         """ Transform Image """
         logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding)
         matrix = self.transform_matrix(mat, size, padding)
+        interpolators = get_matrix_scaling(matrix)
         return cv2.warpAffine(  # pylint: disable=no-member
-            image, matrix, (size, size))
+            image, matrix, (size, size), flags=interpolators[0])
 
     def transform_points(self, points, mat, size, padding=0):
         """ Transform points along matrix """
@@ -144,12 +123,23 @@ class Extract():
         return mask
 
 
+def get_matrix_scaling(mat):
+    """ Get the correct interpolator """
+    x_scale = np.sqrt(mat[0, 0] * mat[0, 0] + mat[0, 1] * mat[0, 1])
+    y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale
+    avg_scale = (x_scale + y_scale) * 0.5
+    if avg_scale >= 1.0:
+        interpolators = cv2.INTER_CUBIC, cv2.INTER_AREA  # pylint: disable=no-member
+    else:
+        interpolators = cv2.INTER_AREA, cv2.INTER_CUBIC  # pylint: disable=no-member
+    logger.trace("interpolator: %s, inverse interpolator: %s", interpolators[0], interpolators[1])
+    return interpolators
+
+
 def get_align_mat(face, size, should_align_eyes):
     """ Return the alignment Matrix """
     logger.trace("size: %s, should_align_eyes: %s", size, should_align_eyes)
-    mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]),
-                          LANDMARKS_2D,
-                          True)[0:2]
+    mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), True)[0:2]
 
     if should_align_eyes is False:
         return mat_umeyama
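The new get_matrix_scaling above is the "QR decomposition" mentioned in the commit message: the x scale is the norm of the affine matrix's first row, and the y scale is the 2x2 determinant divided by that norm, which are exactly the diagonal entries of the R factor when the transposed linear part is QR-decomposed. Upscales get INTER_CUBIC, downscales INTER_AREA. A small check of the formulas, using an arbitrary example matrix:

```python
import numpy as np

mat = np.array([[1.5, 0.2, 10.0],
                [-0.1, 1.4, 5.0]])  # illustrative 2x3 affine matrix

# The formulas used by get_matrix_scaling
x_scale = np.sqrt(mat[0, 0] ** 2 + mat[0, 1] ** 2)
y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale

# The same quantities via an explicit QR decomposition of the transposed
# 2x2 linear part: R's diagonal carries the scale factors (up to sign).
_, r = np.linalg.qr(mat[:2, :2].T)
print(np.allclose([x_scale, y_scale], np.abs(np.diag(r))))  # True
```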
@@ -270,33 +270,6 @@ class Alignments():
 
     # << LEGACY FUNCTIONS >> #
 
-    # < Original Frame Dimensions > #
-    # For dfaker and convert-adjust the original dimensions of a frame are
-    # required to calculate the transposed landmarks. As transposed landmarks
-    # will change on face size, we store original frame dimensions
-    # These were not previously required, so this adds the dimensions
-    # to the landmarks file
-
-    def get_legacy_no_dims(self):
-        """ Return a list of frames that do not contain the original frame
-            height and width attributes """
-        logger.debug("Getting alignments without frame_dims")
-        keys = list()
-        for key, val in self.data.items():
-            for alignment in val:
-                if "frame_dims" not in alignment.keys():
-                    keys.append(key)
-                    break
-        logger.debug("Got alignments without frame_dims: %s", len(keys))
-        return keys
-
-    def add_dimensions(self, frame_name, dimensions):
-        """ Backward compatability fix. Add frame dimensions
-            to alignments """
-        logger.trace("Adding dimensions: (frame: '%s', dimensions: %s)", frame_name, dimensions)
-        for face in self.get_faces_in_frame(frame_name):
-            face["frame_dims"] = dimensions
-
     # < Rotation > #
     # The old rotation method would rotate the image to find a face, then
     # store the rotated landmarks along with a rotation value to tell the
@@ -319,20 +292,20 @@ class Alignments():
         logger.debug("Got alignments containing legacy rotations: %s", len(keys))
         return keys
 
-    def rotate_existing_landmarks(self, frame_name):
+    def rotate_existing_landmarks(self, frame_name, frame):
         """ Backwards compatability fix. Rotates the landmarks to
             their correct position and deletes r
 
-            NB: The original frame dimensions must be passed in otherwise
+            NB: The original frame must be passed in otherwise
             the transformation cannot be performed """
         logger.trace("Rotating existing landmarks for frame: '%s'", frame_name)
+        dims = frame.shape[:2]
         for face in self.get_faces_in_frame(frame_name):
             angle = face.get("r", 0)
             if not angle:
                 logger.trace("Landmarks do not require rotation: '%s'", frame_name)
                 return
             logger.trace("Rotating landmarks: (frame: '%s', angle: %s)", frame_name, angle)
-            dims = face["frame_dims"]
             r_mat = self.get_original_rotation_matrix(dims, angle)
             rotate_landmarks(face, r_mat)
             del face["r"]
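The change above is why frame_dims can be dropped from the alignments file: the dimensions are now read directly from the frame itself. A hedged sketch of the underlying operation — get_original_rotation_matrix is assumed to wrap cv2.getRotationMatrix2D in a similar way; the frame size, angle, and landmark points are illustrative:

```python
import cv2
import numpy as np

# Dimensions now come from the frame itself rather than a stored value
frame = np.zeros((720, 1280, 3), dtype="uint8")
height, width = frame.shape[:2]
angle = 90

# A rotation matrix about the frame centre (assumed to match what
# get_original_rotation_matrix produces)
r_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)

# Applying the 2x3 affine matrix to landmark coordinates
landmarks = np.array([[100.0, 200.0], [300.0, 400.0]])
ones = np.ones((landmarks.shape[0], 1))
rotated = np.hstack([landmarks, ones]) @ r_mat.T
print(rotated)
```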
lib/cli.py | 403

@@ -103,9 +103,41 @@ class ScriptExecutor():
         safe_shutdown()
 
 
-class FullPaths(argparse.Action):
+class Slider(argparse.Action):  # pylint: disable=too-few-public-methods
+    """ Adds support for the GUI slider
+
+    An additional option 'min_max' must be provided containing tuple of min and max accepted
+    values.
+
+    'rounding' sets the decimal places for floats or the step interval for ints.
+    """
+    def __init__(self, option_strings, dest, nargs=None, min_max=None, rounding=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super().__init__(option_strings, dest, **kwargs)
+        self.min_max = min_max
+        self.rounding = rounding
+
+    def _get_kwargs(self):
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "min_max",   # Tuple containing min and max values of scale
+                 "rounding"]  # Decimal places to round floats to or step interval for ints
+        return [(name, getattr(self, name)) for name in names]
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, values)
+
+
+class FullPaths(argparse.Action):  # pylint: disable=too-few-public-methods
     """ Expand user- and relative-paths """
-    # pylint: disable=too-few-public-methods
     def __call__(self, parser, namespace, values, option_string=None):
         setattr(namespace, self.dest, os.path.abspath(
             os.path.expanduser(values)))
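On the command line the new Slider action stores its value like a plain argument; its only job is to carry min_max and rounding through _get_kwargs so the GUI can render a slider instead of a text box. A sketch of registering an option with it, mirroring the argument dictionaries used throughout the rest of this diff (assumes lib.cli is importable from the repo root):

```python
import argparse

from lib.cli import Slider  # the class defined in the hunk above

parser = argparse.ArgumentParser()
parser.add_argument("-bs", "--batch-size",
                    action=Slider,
                    type=int,
                    min_max=(2, 256),  # consumed by the GUI, ignored by argparse
                    rounding=2,        # step interval for the GUI slider
                    dest="batch_size",
                    default=64,
                    help="Batch size, as a power of 2")
print(parser.parse_args(["-bs", "128"]).batch_size)  # 128
```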
@@ -124,26 +156,23 @@ class FileFullPaths(FullPaths):
     see lib/gui/utils.py FileHandler for current GUI filetypes
     """
     # pylint: disable=too-few-public-methods
-    def __init__(self, option_strings, dest, nargs=None, filetypes=None,
-                 **kwargs):
+    def __init__(self, option_strings, dest, nargs=None, filetypes=None, **kwargs):
         super(FileFullPaths, self).__init__(option_strings, dest, **kwargs)
         if nargs is not None:
             raise ValueError("nargs not allowed")
         self.filetypes = filetypes
 
     def _get_kwargs(self):
-        names = [
-            "option_strings",
-            "dest",
-            "nargs",
-            "const",
-            "default",
-            "type",
-            "choices",
-            "help",
-            "metavar",
-            "filetypes"
-        ]
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "filetypes"]
         return [(name, getattr(self, name)) for name in names]
 
 
@@ -185,19 +214,17 @@ class ContextFullPaths(FileFullPaths):
         self.filetypes = filetypes
 
     def _get_kwargs(self):
-        names = [
-            "option_strings",
-            "dest",
-            "nargs",
-            "const",
-            "default",
-            "type",
-            "choices",
-            "help",
-            "metavar",
-            "filetypes",
-            "action_option"
-        ]
+        names = ["option_strings",
+                 "dest",
+                 "nargs",
+                 "const",
+                 "default",
+                 "type",
+                 "choices",
+                 "help",
+                 "metavar",
+                 "filetypes",
+                 "action_option"]
         return [(name, getattr(self, name)) for name in names]
 
 
@@ -282,6 +309,13 @@ class FaceSwapArgs():
                             "help": "Path to store the logfile. Leave blank to store in the "
                                     "faceswap folder",
                             "default": None})
+        # This is a hidden argument to indicate that the GUI is being used,
+        # so the preview window should be redirected Accordingly
+        global_args.append({"opts": ("-gui", "--gui"),
+                            "action": "store_true",
+                            "dest": "redirect_gui",
+                            "default": False,
+                            "help": argparse.SUPPRESS})
         return global_args
 
     @staticmethod
@@ -342,11 +376,14 @@ class ExtractConvertArgs(FaceSwapArgs):
                               "dest": "alignments_path",
                               "help": "Optional path to an alignments file."})
         argument_list.append({"opts": ("-l", "--ref_threshold"),
+                              "action": Slider,
+                              "min_max": (0.01, 0.99),
+                              "rounding": 2,
                               "type": float,
                               "dest": "ref_threshold",
                               "default": 0.6,
-                              "help": "Threshold for positive face "
-                                      "recognition"})
+                              "help": "Threshold for positive face recognition. For use with "
+                                      "nfilter or filter. Lower values are stricter."})
         argument_list.append({"opts": ("-n", "--nfilter"),
                               "type": str,
                               "dest": "nfilter",
@@ -389,7 +426,7 @@ class ExtractArgs(ExtractConvertArgs):
                    "fallback."})
         argument_list.append({
             "opts": ("-D", "--detector"),
-            "type": str,
+            "type": str.lower,
             "choices": PluginLoader.get_available_extractors(
                 "detect"),
             "default": "mtcnn",
@@ -404,7 +441,7 @@ class ExtractArgs(ExtractConvertArgs):
                    "\n\talignment to dlib"})
         argument_list.append({
             "opts": ("-A", "--aligner"),
-            "type": str,
+            "type": str.lower,
             "choices": PluginLoader.get_available_extractors(
                 "align"),
             "default": "fan",
@@ -413,38 +450,6 @@ class ExtractArgs(ExtractConvertArgs):
                    "\n\tresource intensive, but less accurate."
                    "\n'fan': Face Alignment Network. Best aligner."
                    "\n\tGPU heavy."})
-        argument_list.append({"opts": ("-mtms", "--mtcnn-minsize"),
-                              "type": int,
-                              "dest": "mtcnn_minsize",
-                              "default": 20,
-                              "help": "The minimum size of a face to be "
-                                      "accepted. Lower values use "
-                                      "significantly more VRAM. Minimum "
-                                      "value is 10. Default is 20 "
-                                      "(MTCNN detector only)"})
-        argument_list.append({"opts": ("-mtth", "--mtcnn-threshold"),
-                              "nargs": "+",
-                              "type": str,
-                              "dest": "mtcnn_threshold",
-                              "default": ["0.6", "0.7", "0.7"],
-                              "help": "R|Three step threshold for face "
-                                      "detection. Should be\nthree decimal "
-                                      "numbers each less than 1. Eg:\n"
-                                      "'--mtcnn-threshold 0.6 0.7 0.7'.\n"
-                                      "1st stage: obtains face candidates.\n"
-                                      "2nd stage: refinement of face "
-                                      "candidates.\n3rd stage: further "
-                                      "refinement of face candidates.\n"
-                                      "Default is 0.6 0.7 0.7 "
-                                      "(MTCNN detector only)"})
-        argument_list.append({"opts": ("-mtsc", "--mtcnn-scalefactor"),
-                              "type": float,
-                              "dest": "mtcnn_scalefactor",
-                              "default": 0.709,
-                              "help": "The scale factor for the image "
-                                      "pyramid. Should be a decimal number "
-                                      "less than one. Default is 0.709 "
-                                      "(MTCNN detector only)"})
         argument_list.append({"opts": ("-r", "--rotate-images"),
                               "type": str,
                               "dest": "rotate_images",
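The MTCNN arguments removed above now live in the extract plugin's config.ini (per the commit message, "MTCNN options to config.ini"). A hypothetical sketch of reading such values with the standard library — the section and key names here are assumptions derived from the old CLI flags, not shown anywhere in this diff:

```python
from configparser import ConfigParser

# Hypothetical keys: the diff only shows that the MTCNN CLI flags were
# removed and the commit message says they moved to config.ini.
config = ConfigParser()
config.read_string("""
[detect.mtcnn]
minsize = 20
threshold_1 = 0.6
threshold_2 = 0.7
threshold_3 = 0.7
scalefactor = 0.709
""")
print(config.getfloat("detect.mtcnn", "scalefactor"))  # 0.709
```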
@@ -458,13 +463,15 @@ class ExtractArgs(ExtractConvertArgs):
                    "exactly what angles to check"})
         argument_list.append({"opts": ("-bt", "--blur-threshold"),
                               "type": float,
+                              "action": Slider,
+                              "min_max": (0.0, 100.0),
+                              "rounding": 1,
                               "dest": "blur_thresh",
-                              "default": None,
-                              "help": "Automatically discard images blurrier "
-                                      "than the specified threshold. "
-                                      "Discarded images are moved into a "
-                                      "\"blurry\" sub-folder. Lower values "
-                                      "allow more blur"})
+                              "default": 0.0,
+                              "help": "Automatically discard images blurrier than the specified "
+                                      "threshold. Discarded images are moved into a \"blurry\" "
+                                      "sub-folder. Lower values allow more blur. Set to 0.0 to "
+                                      "turn off."})
         argument_list.append({"opts": ("-mp", "--multiprocess"),
                               "action": "store_true",
                               "default": False,
@@ -476,12 +483,13 @@ class ExtractArgs(ExtractConvertArgs):
                    "otherwise this is automatic."})
         argument_list.append({"opts": ("-sz", "--size"),
                               "type": int,
+                              "action": Slider,
+                              "min_max": (128, 512),
                               "default": 256,
-                              "help": "The output size of extracted faces. "
-                                      "Make sure that the model you intend "
-                                      "to train supports your required "
-                                      "size. This will only need to be "
-                                      "changed for hi-res models."})
+                              "rounding": 64,
+                              "help": "The output size of extracted faces. Make sure that the "
+                                      "model you intend to train supports your required size. "
+                                      "This will only need to be changed for hi-res models."})
         argument_list.append({"opts": ("-s", "--skip-existing"),
                               "action": "store_true",
                               "dest": "skip_existing",
@@ -512,13 +520,15 @@ class ExtractArgs(ExtractConvertArgs):
         argument_list.append({"opts": ("-si", "--save-interval"),
                               "dest": "save_interval",
                               "type": int,
-                              "default": None,
-                              "help": "Automatically save the alignments file "
-                                      "after a set amount of frames. Will "
-                                      "only save at the end of extracting by "
-                                      "default. WARNING: Don't interrupt the "
-                                      "script when writing the file because "
-                                      "it might get corrupted."})
+                              "action": Slider,
+                              "min_max": (0, 1000),
+                              "rounding": 10,
+                              "default": 0,
+                              "help": "Automatically save the alignments file after a set amount "
+                                      "of frames. Will only save at the end of extracting by "
+                                      "default. WARNING: Don't interrupt the script when writing "
+                                      "the file because it might get corrupted. Set to 0 to turn "
+                                      "off"})
         return argument_list
@@ -552,57 +562,73 @@ class ConvertArgs(ExtractConvertArgs):
                    "specified, all faces will be "
                    "converted"})
         argument_list.append({"opts": ("-t", "--trainer"),
-                              "type": str,
-                              # case sensitive because this is used to
-                              # load a plug-in.
+                              "type": str.lower,
                               "choices": PluginLoader.get_available_models(),
                               "default": PluginLoader.get_default_model(),
                               "help": "Select the trainer that was used to "
                                       "create the model"})
         argument_list.append({"opts": ("-c", "--converter"),
-                              "type": str,
-                              # case sensitive because this is used
-                              # to load a plugin.
-                              "choices": ("Masked", "Adjust"),
-                              "default": "Masked",
-                              "help": "Converter to use"})
-        argument_list.append({"opts": ("-b", "--blur-size"),
-                              "type": int,
-                              "default": 2,
-                              "help": "Blur size. (Masked converter only)"})
-        argument_list.append({"opts": ("-e", "--erosion-kernel-size"),
-                              "dest": "erosion_kernel_size",
-                              "type": int,
-                              "default": None,
-                              "help": "Erosion kernel size. Positive values "
-                                      "apply erosion which reduces the edge "
-                                      "of the swapped face. Negative values "
-                                      "apply dilation which allows the "
-                                      "swapped face to cover more space. "
-                                      "(Masked converter only)"})
-        argument_list.append({"opts": ("-M", "--mask-type"),
-                              # lowercase this, because it's just a
-                              # string later on.
-                              "type": str.lower,
-                              "dest": "mask_type",
-                              "choices": ["rect",
-                                          "facehull",
-                                          "facehullandrect"],
-                              "default": "facehullandrect",
-                              "help": "Mask to use to replace faces. "
-                                      "(Masked converter only)"})
+                              "type": str.lower,
+                              "choices": PluginLoader.get_available_converters(),
+                              "default": "masked",
+                              "help": "Converter to use"})
+        argument_list.append({
+            "opts": ("-M", "--mask-type"),
+            "type": str.lower,
+            "dest": "mask_type",
+            "choices": ["rect",
+                        "ellipse",
+                        "smoothed",
+                        "facehull",
+                        "facehull_rect",
+                        "dfl",
+                        "cnn"],
+            "default": "facehull_rect",
+            "help": "R|Mask to use to replace faces."
+                    "\nrect: Rectangle around face."
+                    "\nellipse: Oval around face."
+                    "\nsmoothed: Rectangle around face with smoothing."
+                    "\nfacehull: Face cutout based on landmarks."
+                    "\nfacehull_rect: Rectangle around faces with facehull"
+                    "\n\tbetween the edges of the face and the background."
+                    "\ndfl: A Face Hull mask from DeepFaceLabs."
+                    "\ncnn: Not yet implemented"})
+        argument_list.append({"opts": ("-b", "--blur-size"),
+                              "type": float,
+                              "action": Slider,
+                              "min_max": (0.0, 100.0),
+                              "rounding": 2,
+                              "default": 5.0,
+                              "help": "Blur kernel size as a percentage of the swap area. Smooths "
+                                      "the transition between the swapped face and the background "
+                                      "image."})
+        argument_list.append({"opts": ("-e", "--erosion-size"),
+                              "dest": "erosion_size",
+                              "type": float,
+                              "action": Slider,
+                              "min_max": (-100.0, 100.0),
+                              "rounding": 2,
+                              "default": 0.0,
+                              "help": "Erosion kernel size as a percentage of the mask radius "
+                                      "area. Positive values apply erosion which reduces the size "
+                                      "of the swapped area. Negative values apply dilation which "
+                                      "increases the swapped area"})
+        argument_list.append({"opts": ("-g", "--gpus"),
+                              "type": int,
+                              "action": Slider,
+                              "min_max": (1, 10),
+                              "rounding": 1,
+                              "default": 1,
+                              "help": "Number of GPUs to use for conversion"})
         argument_list.append({"opts": ("-sh", "--sharpen"),
                               "type": str.lower,
                               "dest": "sharpen_image",
-                              "choices": ["bsharpen", "gsharpen"],
+                              "choices": ["box_filter", "gaussian_filter"],
                               "default": None,
-                              "help": "Use Sharpen Image. bsharpen for Box "
-                                      "Blur, gsharpen for Gaussian Blur "
-                                      "(Masked converter only)"})
-        argument_list.append({"opts": ("-g", "--gpus"),
-                              "type": int,
-                              "default": 1,
-                              "help": "Number of GPUs to use for conversion"})
+                              "help": "Sharpen the masked facial region of "
+                                      "the converted images. Choice of filter "
+                                      "to use in sharpening process -- box "
+                                      "filter or gaussian filter."})
         argument_list.append({"opts": ("-fr", "--frame-ranges"),
                               "nargs": "+",
                               "type": str,
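Among the new --mask-type choices, "facehull" is described as a face cutout based on landmarks. A hedged sketch of what that computation typically looks like — a convex hull of the landmark points filled into a mask; the converter's actual implementation may differ, and the landmarks here are illustrative:

```python
import cv2
import numpy as np

def facehull_mask(landmarks, frame_shape):
    """Illustrative face-hull mask: fill the convex hull of the
    landmark points. Not the converter's exact code."""
    mask = np.zeros(frame_shape[:2], dtype="float32")
    hull = cv2.convexHull(np.array(landmarks, dtype=np.int32))
    cv2.fillConvexPoly(mask, hull, 1.0)
    return mask

landmarks = [(120, 140), (180, 130), (210, 200), (160, 260), (110, 210)]
mask = facehull_mask(landmarks, (360, 480, 3))
print(mask.sum() > 0)  # True: the hull region is filled
```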
@@ -628,25 +654,25 @@ class ConvertArgs(ExtractConvertArgs):
                               "action": "store_true",
                               "dest": "seamless_clone",
                               "default": False,
-                              "help": "Use cv2's seamless clone. "
-                                      "(Masked converter only)"})
+                              "help": "Use cv2's seamless clone function to "
+                                      "remove extreme gradients at the mask "
+                                      "seam by smoothing colors."})
         argument_list.append({"opts": ("-mh", "--match-histogram"),
                               "action": "store_true",
                               "dest": "match_histogram",
                               "default": False,
-                              "help": "Use histogram matching. "
-                                      "(Masked converter only)"})
-        argument_list.append({"opts": ("-sm", "--smooth-mask"),
-                              "action": "store_true",
-                              "dest": "smooth_mask",
-                              "default": False,
-                              "help": "Smooth mask (Adjust converter only)"})
+                              "help": "Adjust the histogram of each color "
+                                      "channel in the swapped reconstruction "
+                                      "to equal the histogram of the masked "
+                                      "area in the original image"})
         argument_list.append({"opts": ("-aca", "--avg-color-adjust"),
                               "action": "store_true",
                               "dest": "avg_color_adjust",
                               "default": False,
-                              "help": "Average color adjust. "
-                                      "(Adjust converter only)"})
+                              "help": "Adjust the mean of each color channel "
+                                      "in the swapped reconstruction to "
+                                      "equal the mean of the masked area in "
+                                      "the original image"})
         argument_list.append({"opts": ("-dt", "--draw-transparent"),
                               "action": "store_true",
                               "dest": "draw_transparent",
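The reworded --match-histogram help describes classic per-channel histogram matching: remap the swapped region's intensity distribution onto the original's via their cumulative distribution functions. A minimal single-channel sketch of that technique (not the converter's exact code; the random arrays are illustrative):

```python
import numpy as np

def match_histogram(source, template):
    """Map `source` values so their distribution matches `template`,
    via the classic CDF-interpolation method. Single channel."""
    s_values, s_counts = np.unique(source.ravel(), return_counts=True)
    t_values, t_counts = np.unique(template.ravel(), return_counts=True)
    s_cdf = np.cumsum(s_counts).astype("float64") / source.size
    t_cdf = np.cumsum(t_counts).astype("float64") / template.size
    # For each source quantile, find the template value at that quantile
    mapped = np.interp(s_cdf, t_cdf, t_values)
    return np.interp(source.ravel(), s_values, mapped).reshape(source.shape)

rng = np.random.default_rng(0)
swapped = rng.integers(0, 128, (64, 64)).astype("uint8")
original = rng.integers(64, 256, (64, 64)).astype("uint8")
print(match_histogram(swapped, original).mean() > swapped.mean())  # True
```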
@ -667,18 +693,38 @@ class TrainArgs(FaceSwapArgs):
|
||||||
argument_list = list()
|
argument_list = list()
|
||||||
argument_list.append({"opts": ("-A", "--input-A"),
|
argument_list.append({"opts": ("-A", "--input-A"),
|
||||||
"action": DirFullPaths,
|
"action": DirFullPaths,
|
||||||
"dest": "input_A",
|
"dest": "input_a",
|
||||||
"default": "input_A",
|
"default": "input_a",
|
||||||
"help": "Input directory. A directory "
|
"help": "Input directory. A directory "
|
||||||
"containing training images for face A. "
|
"containing training images for face A. "
|
||||||
"Defaults to 'input'"})
|
"Defaults to 'input'"})
|
||||||
argument_list.append({"opts": ("-B", "--input-B"),
|
argument_list.append({"opts": ("-B", "--input-B"),
|
||||||
"action": DirFullPaths,
|
"action": DirFullPaths,
|
||||||
"dest": "input_B",
|
"dest": "input_b",
|
||||||
"default": "input_B",
|
"default": "input_b",
|
||||||
"help": "Input directory. A directory "
|
"help": "Input directory. A directory "
|
||||||
"containing training images for face B. "
|
"containing training images for face B. "
|
||||||
"Defaults to 'input'"})
|
"Defaults to 'input'"})
|
||||||
|
argument_list.append({"opts": ("-ala", "--alignments-A"),
|
||||||
|
"action": FileFullPaths,
|
||||||
|
"filetypes": 'alignments',
|
||||||
|
"type": str,
|
||||||
|
"dest": "alignments_path_a",
|
||||||
|
"default": None,
|
||||||
|
"help": "Path to alignments file for training set A. Only required "
|
||||||
|
"if you are using a masked model or warp-to-landmarks is "
|
||||||
|
"enabled. Defaults to <input-A>/alignments.json if not "
|
||||||
|
"provided."})
|
||||||
|
argument_list.append({"opts": ("-alb", "--alignments-B"),
|
||||||
|
"action": FileFullPaths,
|
||||||
|
"filetypes": 'alignments',
|
||||||
|
"type": str,
|
||||||
|
"dest": "alignments_path_b",
|
||||||
|
"default": None,
|
||||||
|
"help": "Path to alignments file for training set B. Only required "
|
||||||
|
"if you are using a masked model or warp-to-landmarks is "
|
||||||
|
"enabled. Defaults to <input-B>/alignments.json if not "
|
||||||
|
"provided."})
|
||||||
argument_list.append({"opts": ("-m", "--model-dir"),
|
argument_list.append({"opts": ("-m", "--model-dir"),
|
||||||
"action": DirFullPaths,
|
"action": DirFullPaths,
|
||||||
"dest": "model_dir",
|
"dest": "model_dir",
|
||||||
|
@ -686,32 +732,51 @@ class TrainArgs(FaceSwapArgs):
|
||||||
"help": "Model directory. This is where the "
|
"help": "Model directory. This is where the "
|
||||||
"training data will be stored. "
|
"training data will be stored. "
|
||||||
"Defaults to 'model'"})
|
"Defaults to 'model'"})
|
||||||
argument_list.append({"opts": ("-s", "--save-interval"),
|
|
||||||
"type": int,
|
|
||||||
"dest": "save_interval",
|
|
||||||
"default": 100,
|
|
||||||
"help": "Sets the number of iterations before "
|
|
||||||
"saving the model"})
|
|
||||||
argument_list.append({"opts": ("-t", "--trainer"),
|
argument_list.append({"opts": ("-t", "--trainer"),
|
||||||
"type": str,
|
"type": str.lower,
|
||||||
"choices": PluginLoader.get_available_models(),
|
"choices": PluginLoader.get_available_models(),
|
||||||
"default": PluginLoader.get_default_model(),
|
"default": PluginLoader.get_default_model(),
|
||||||
"help": "Select which trainer to use, Use "
|
"help": "Select which trainer to use, Use "
|
||||||
"LowMem for cards with less than 2GB of "
|
"LowMem for cards with less than 2GB of "
|
||||||
"VRAM"})
|
"VRAM"})
|
||||||
|
argument_list.append({"opts": ("-s", "--save-interval"),
|
||||||
|
"type": int,
|
||||||
|
"action": Slider,
|
||||||
|
"min_max": (10, 1000),
|
||||||
|
"rounding": 10,
|
||||||
|
"dest": "save_interval",
|
||||||
|
"default": 100,
|
||||||
|
"help": "Sets the number of iterations before saving the model"})
|
||||||
argument_list.append({"opts": ("-bs", "--batch-size"),
|
argument_list.append({"opts": ("-bs", "--batch-size"),
|
||||||
"type": int,
|
"type": int,
|
||||||
|
"action": Slider,
|
||||||
|
"min_max": (2, 256),
|
+                              "rounding": 2,
+                              "dest": "batch_size",
                               "default": 64,
-                              "help": "Batch size, as a power of 2 "
-                                      "(64, 128, 256, etc)"})
+                              "help": "Batch size, as a power of 2 (64, 128, 256, etc)"})
        argument_list.append({"opts": ("-it", "--iterations"),
                              "type": int,
+                             "action": Slider,
+                             "min_max": (0, 5000000),
+                             "rounding": 20000,
                              "default": 1000000,
-                             "help": "Length of training in iterations"})
+                             "help": "Length of training in iterations."})
        argument_list.append({"opts": ("-g", "--gpus"),
                              "type": int,
+                             "action": Slider,
+                             "min_max": (1, 10),
+                             "rounding": 1,
                              "default": 1,
                              "help": "Number of GPUs to use for training"})
+       argument_list.append({"opts": ("-ps", "--preview-scale"),
+                             "type": int,
+                             "action": Slider,
+                             "dest": "preview_scale",
+                             "min_max": (25, 200),
+                             "rounding": 25,
+                             "default": 100,
+                             "help": "Percentage amount to scale the preview by."})
        argument_list.append({"opts": ("-p", "--preview"),
                              "action": "store_true",
                              "dest": "preview",
@@ -724,20 +789,39 @@ class TrainArgs(FaceSwapArgs):
                              "default": False,
                              "help": "Writes the training result to a file "
                                      "even on preview mode"})
-       argument_list.append({"opts": ("-pl", "--use-perceptual-loss"),
-                             "action": "store_true",
-                             "dest": "perceptual_loss",
-                             "default": False,
-                             "help": "Use perceptual loss while training"})
        argument_list.append({"opts": ("-ag", "--allow-growth"),
                              "action": "store_true",
                              "dest": "allow_growth",
                              "default": False,
                              "help": "Sets allow_growth option of Tensorflow "
                                      "to spare memory on some configs"})
+       argument_list.append({"opts": ("-nl", "--no-logs"),
+                             "action": "store_true",
+                             "dest": "no_logs",
+                             "default": False,
+                             "help": "Disables TensorBoard logging. NB: Disabling logs means "
+                                     "that you will not be able to use the graph or analysis "
+                                     "for this session in the GUI."})
+       argument_list.append({"opts": ("-wl", "--warp-to-landmarks"),
+                             "action": "store_true",
+                             "dest": "warp_to_landmarks",
+                             "default": False,
+                             "help": "Warps training faces to closely matched Landmarks from the "
+                                     "opposite face-set rather than randomly warping the face. "
+                                     "This is the 'dfaker' way of doing warping. Alignments "
+                                     "files for both sets of faces must be provided if using "
+                                     "this option."})
+       argument_list.append({"opts": ("-nf", "--no-flip"),
+                             "action": "store_true",
+                             "dest": "no_flip",
+                             "default": False,
+                             "help": "To effectively learn, a random set of images are flipped "
+                                     "horizontally. Sometimes it is desirable for this not to "
+                                     "occur. Generally this should be left off except for "
+                                     "during 'fit training'."})
        argument_list.append({"opts": ("-tia", "--timelapse-input-A"),
                              "action": DirFullPaths,
-                             "dest": "timelapse_input_A",
+                             "dest": "timelapse_input_a",
                              "default": None,
                              "help": "For if you want a timelapse: "
                                      "The input folder for the timelapse. "
@@ -748,7 +832,7 @@ class TrainArgs(FaceSwapArgs):
                                      "--timelapse-input-B parameter."})
        argument_list.append({"opts": ("-tib", "--timelapse-input-B"),
                              "action": DirFullPaths,
-                             "dest": "timelapse_input_B",
+                             "dest": "timelapse_input_b",
                              "default": None,
                              "help": "For if you want a timelapse: "
                                      "The input folder for the timelapse. "
@@ -765,13 +849,6 @@ class TrainArgs(FaceSwapArgs):
                                      "If the input folders are supplied but "
                                      "no output folder, it will default to "
                                      "your model folder /timelapse/"})
-       # This is a hidden argument to indicate that the GUI is being used,
-       # so the preview window should be redirected Accordingly
-       argument_list.append({"opts": ("-gui", "--gui"),
-                             "action": "store_true",
-                             "dest": "redirect_gui",
-                             "default": False,
-                             "help": argparse.SUPPRESS})
        return argument_list
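A note on the Slider action used above: the numerical training options now carry "min_max" and "rounding" metadata alongside "action": Slider. The Slider class itself is defined earlier in lib/cli.py and does not appear in this diff; a minimal sketch of the shape such an action can take (an illustrative assumption, not the repository source):

    import argparse

    class Slider(argparse.Action):
        """ Sketch: a pass-through argparse action whose real job is to mark
            an option as slider-backed so the GUI can render it using the
            min_max and rounding values held in the option dict. On the
            command line it behaves like the default "store" action. """
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, values)

At the prompt the new options therefore parse as plain values, e.g. faceswap.py train ... -it 1000000 -ps 100 -nl.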
301  lib/config.py  Normal file
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+""" Default configurations for faceswap
+
+    Extends out configparser functionality
+    by checking for default config updates
+    and returning data in its correct format """
+
+import logging
+import os
+import sys
+from collections import OrderedDict
+from configparser import ConfigParser
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+class FaceswapConfig():
+    """ Config Items """
+    def __init__(self, section):
+        """ Init Configuration """
+        logger.debug("Initializing: %s", self.__class__.__name__)
+        self.configfile = self.get_config_file()
+        self.config = ConfigParser(allow_no_value=True)
+        self.defaults = OrderedDict()
+        self.config.optionxform = str
+        self.section = section
+
+        self.set_defaults()
+        self.handle_config()
+        logger.debug("Initialized: %s", self.__class__.__name__)
+
+    def set_defaults(self):
+        """ Override for plugin specific config defaults
+
+            Should be a series of self.add_section() and self.add_item() calls
+
+            e.g:
+
+            section = "sect_1"
+            self.add_section(title=section,
+                             info="Section 1 Information")
+
+            self.add_item(section=section,
+                          title="option_1",
+                          datatype=bool,
+                          default=False,
+                          info="sect_1 option_1 information")
+        """
+        raise NotImplementedError
+
+    @property
+    def config_dict(self):
+        """ Collate global options and requested section into a dictionary
+            with the correct datatypes """
+        conf = dict()
+        for sect in ("global", self.section):
+            if sect not in self.config.sections():
+                continue
+            for key in self.config[sect]:
+                if key.startswith(("#", "\n")):  # Skip comments
+                    continue
+                conf[key] = self.get(sect, key)
+        return conf
+
+    def get(self, section, option):
+        """ Return a config item in its correct format """
+        logger.debug("Getting config item: (section: '%s', option: '%s')", section, option)
+        datatype = self.defaults[section][option]["type"]
+        if datatype == bool:
+            func = self.config.getboolean
+        elif datatype == int:
+            func = self.config.getint
+        elif datatype == float:
+            func = self.config.getfloat
+        else:
+            func = self.config.get
+        retval = func(section, option)
+        if isinstance(retval, str) and retval.lower() == "none":
+            retval = None
+        logger.debug("Returning item: (type: %s, value: %s)", datatype, retval)
+        return retval
+
+    def get_config_file(self):
+        """ Return the config file from the calling folder """
+        dirname = os.path.dirname(sys.modules[self.__module__].__file__)
+        folder, fname = os.path.split(dirname)
+        retval = os.path.join(os.path.dirname(folder), "config", "{}.ini".format(fname))
+        logger.debug("Config File location: '%s'", retval)
+        return retval
+
+    def add_section(self, title=None, info=None):
+        """ Add a default section to config file """
+        logger.debug("Add section: (title: '%s', info: '%s')", title, info)
+        if None in (title, info):
+            raise ValueError("Default config sections must have a title and "
+                             "information text")
+        self.defaults[title] = OrderedDict()
+        self.defaults[title]["helptext"] = info
+
+    def add_item(self, section=None, title=None, datatype=str,
+                 default=None, info=None, rounding=None, min_max=None, choices=None):
+        """ Add a default item to a config section
+
+            For int or float values, rounding and min_max must be set.
+            This is for the slider in the GUI. The min/max values are not enforced:
+            rounding: sets the decimal places for floats or the step interval for ints.
+            min_max: tuple of min and max accepted values
+
+            For str values choices can be set to validate input and create a combo box
+            in the GUI
+        """
+        logger.debug("Add item: (section: '%s', title: '%s', datatype: '%s', default: '%s', "
+                     "info: '%s', rounding: '%s', min_max: %s, choices: %s)",
+                     section, title, datatype, default, info, rounding, min_max, choices)
+
+        choices = list() if not choices else choices
+
+        if None in (section, title, default, info):
+            raise ValueError("Default config items must have a section, "
+                             "title, default and "
+                             "information text")
+        if not self.defaults.get(section, None):
+            raise ValueError("Section does not exist: {}".format(section))
+        if datatype not in (str, bool, float, int):
+            raise ValueError("'datatype' must be one of str, bool, float or "
+                             "int: {} - {}".format(section, title))
+        if datatype in (float, int) and (rounding is None or min_max is None):
+            raise ValueError("'rounding' and 'min_max' must be set for numerical options")
+        if not isinstance(choices, (list, tuple)):
+            raise ValueError("'choices' must be a list or tuple")
+        self.defaults[section][title] = {"default": default,
+                                         "helptext": info,
+                                         "type": datatype,
+                                         "rounding": rounding,
+                                         "min_max": min_max,
+                                         "choices": choices}
+
+    def check_exists(self):
+        """ Check that a config file exists """
+        if not os.path.isfile(self.configfile):
+            logger.debug("Config file does not exist: '%s'", self.configfile)
+            return False
+        logger.debug("Config file exists: '%s'", self.configfile)
+        return True
+
+    def create_default(self):
+        """ Generate a default config if it does not exist """
+        logger.debug("Creating default Config")
+        for section, items in self.defaults.items():
+            logger.debug("Adding section: '%s'", section)
+            self.insert_config_section(section, items["helptext"])
+            for item, opt in items.items():
+                logger.debug("Adding option: (item: '%s', opt: '%s'", item, opt)
+                if item == "helptext":
+                    continue
+                self.insert_config_item(section,
+                                        item,
+                                        opt["default"],
+                                        opt)
+        self.save_config()
+
+    def insert_config_section(self, section, helptext, config=None):
+        """ Insert a section into the config """
+        logger.debug("Inserting section: (section: '%s', helptext: '%s', config: '%s')",
+                     section, helptext, config)
+        config = self.config if config is None else config
+        helptext = self.format_help(helptext, is_section=True)
+        config.add_section(section)
+        config.set(section, helptext)
+        logger.debug("Inserted section: '%s'", section)
+
+    def insert_config_item(self, section, item, default, option,
+                           config=None):
+        """ Insert an item into a config section """
+        logger.debug("Inserting item: (section: '%s', item: '%s', default: '%s', helptext: '%s', "
+                     "config: '%s')", section, item, default, option["helptext"], config)
+        config = self.config if config is None else config
+        helptext = option["helptext"]
+        helptext += self.set_helptext_choices(option)
+        helptext += "\n[Default: {}]".format(default)
+        helptext = self.format_help(helptext, is_section=False)
+        config.set(section, helptext)
+        config.set(section, item, str(default))
+        logger.debug("Inserted item: '%s'", item)
+
+    @staticmethod
+    def set_helptext_choices(option):
+        """ Set the helptext choices """
+        choices = ""
+        if option["choices"]:
+            choices = "\nChoose from: {}".format(option["choices"])
+        elif option["type"] == bool:
+            choices = "\nChoose from: True, False"
+        elif option["type"] == int:
+            cmin, cmax = option["min_max"]
+            choices = "\nSelect an integer between {} and {}".format(cmin, cmax)
+        elif option["type"] == float:
+            cmin, cmax = option["min_max"]
+            choices = "\nSelect a decimal number between {} and {}".format(cmin, cmax)
+        return choices
+
+    @staticmethod
+    def format_help(helptext, is_section=False):
+        """ Format comments for default ini file """
+        logger.debug("Formatting help: (helptext: '%s', is_section: '%s')", helptext, is_section)
+        helptext = '# {}'.format(helptext.replace("\n", "\n# "))
+        if is_section:
+            helptext = helptext.upper()
+        else:
+            helptext = "\n{}".format(helptext)
+        logger.debug("formatted help: '%s'", helptext)
+        return helptext
+
+    def load_config(self):
+        """ Load values from config """
+        logger.info("Loading config: '%s'", self.configfile)
+        self.config.read(self.configfile)
+
+    def save_config(self):
+        """ Save a config file """
+        logger.info("Updating config at: '%s'", self.configfile)
+        f_cfgfile = open(self.configfile, "w")
+        self.config.write(f_cfgfile)
+        f_cfgfile.close()
+
+    def validate_config(self):
+        """ Check for options in default config against saved config
+            and add/remove as appropriate """
+        logger.debug("Validating config")
+        if self.check_config_change():
+            self.add_new_config_items()
+        self.check_config_choices()
+        logger.debug("Validated config")
+
+    def add_new_config_items(self):
+        """ Add new items to the config file """
+        logger.debug("Updating config")
+        new_config = ConfigParser(allow_no_value=True)
+        for section, items in self.defaults.items():
+            self.insert_config_section(section, items["helptext"], new_config)
+            for item, opt in items.items():
+                if item == "helptext":
+                    continue
+                if section not in self.config.sections():
+                    logger.debug("Adding new config section: '%s'", section)
+                    opt_value = opt["default"]
+                else:
+                    opt_value = self.config[section].get(item, opt["default"])
+                self.insert_config_item(section,
+                                        item,
+                                        opt_value,
+                                        opt,
+                                        new_config)
+        self.config = new_config
+        self.config.optionxform = str
+        self.save_config()
+        logger.debug("Updated config")
+
+    def check_config_choices(self):
+        """ Check that config items are valid choices """
+        logger.debug("Checking config choices")
+        for section, items in self.defaults.items():
+            for item, opt in items.items():
+                if item == "helptext" or not opt["choices"]:
+                    continue
+                opt_value = self.config.get(section, item)
+                if opt_value.lower() == "none" and any(choice.lower() == "none"
+                                                       for choice in opt["choices"]):
+                    continue
+                if opt_value not in opt["choices"]:
+                    default = str(opt["default"])
+                    logger.warning("'%s' is not a valid config choice for '%s': '%s'. Defaulting "
+                                   "to: '%s'", opt_value, section, item, default)
+                    self.config.set(section, item, default)
+        logger.debug("Checked config choices")
+
+    def check_config_change(self):
+        """ Check whether new default items have been added or removed
+            from the config file compared to saved version """
+        if set(self.config.sections()) != set(self.defaults.keys()):
+            logger.debug("Default config has new section(s)")
+            return True
+
+        for section, items in self.defaults.items():
+            opts = [opt for opt in items.keys() if opt != "helptext"]
+            exists = [opt for opt in self.config[section].keys()
+                      if not opt.startswith(("# ", "\n# "))]
+            if set(exists) != set(opts):
+                logger.debug("Default config has new item(s)")
+                return True
+        logger.debug("Default config has not changed")
+        return False
+
+    def handle_config(self):
+        """ Handle the config """
+        logger.debug("Handling config")
+        if not self.check_exists():
+            self.create_default()
+        self.load_config()
+        self.validate_config()
+        logger.debug("Handled config")
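The set_defaults docstring above sketches the extension pattern; assembled into a complete plugin config it looks like the following (section and option names are illustrative only, following the docstring's example):

    #!/usr/bin/env python3
    """ Example plugin configuration built on FaceswapConfig """
    from lib.config import FaceswapConfig

    class Config(FaceswapConfig):
        """ Sketch of a plugin's default configuration """
        def set_defaults(self):
            section = "sect_1"
            self.add_section(title=section,
                             info="Section 1 Information")
            # int/float items must supply rounding and min_max for the GUI slider
            self.add_item(section=section, title="option_1", datatype=int,
                          default=4, rounding=1, min_max=(1, 16),
                          info="sect_1 option_1 information")
            # str items may supply choices to get a combo box plus validation
            self.add_item(section=section, title="option_2", datatype=str,
                          default="none", choices=["none", "low", "high"],
                          info="sect_1 option_2 information")

On first run this writes a commented ini file into the config folder (see get_config_file above), and config_dict then returns the merged "global" and requested sections with values already cast by get().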
lib/faces_detect.py
@@ -3,7 +3,7 @@
 import logging

 from dlib import rectangle as d_rectangle  # pylint: disable=no-name-in-module
-from lib.aligner import Extract as AlignerExtract, get_align_mat
+from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -12,14 +12,13 @@ class DetectedFace():
     """ Detected face and landmark information """
     def __init__(  # pylint: disable=invalid-name
             self, image=None, x=None, w=None, y=None, h=None,
-            frame_dims=None, landmarksXY=None):
+            landmarksXY=None):
         logger.trace("Initializing %s", self.__class__.__name__)
         self.image = image
         self.x = x
         self.w = w
         self.y = y
         self.h = h
-        self.frame_dims = frame_dims
         self.landmarksXY = landmarksXY
         self.hash = None  # Hash must be set when the file is saved due to image compression

@@ -63,17 +62,12 @@ class DetectedFace():
                           self.x: self.x + self.w]

     def to_alignment(self):
-        """ Convert a detected face to alignment dict
-
-        NB: frame_dims should be the height and width
-        of the original frame. """
-
+        """ Convert a detected face to alignment dict """
         alignment = dict()
         alignment["x"] = self.x
         alignment["w"] = self.w
         alignment["y"] = self.y
         alignment["h"] = self.h
-        alignment["frame_dims"] = self.frame_dims
         alignment["landmarksXY"] = self.landmarksXY
         alignment["hash"] = self.hash
         logger.trace("Returning: %s", alignment)

@@ -87,23 +81,22 @@ class DetectedFace():
         self.w = alignment["w"]
         self.y = alignment["y"]
         self.h = alignment["h"]
-        self.frame_dims = alignment["frame_dims"]
         self.landmarksXY = alignment["landmarksXY"]
         # Manual tool does not know the final hash so default to None
         self.hash = alignment.get("hash", None)
         if image is not None and image.any():
             self.image_to_face(image)
         logger.trace("Created from alignment: (x: %s, w: %s, y: %s. h: %s, "
-                     "frame_dims: %s, landmarks: %s)",
-                     self.x, self.w, self.y, self.h, self.frame_dims, self.landmarksXY)
+                     "landmarks: %s)",
+                     self.x, self.w, self.y, self.h, self.landmarksXY)

     # <<< Aligned Face methods and properties >>> #
-    def load_aligned(self, image, size=256, padding=48, align_eyes=False):
+    def load_aligned(self, image, size=256, align_eyes=False):
         """ No need to load aligned information for all uses of this
             class, so only call this to load the information for easy
             reference to aligned properties for this face """
-        logger.trace("Loading aligned face: (size: %s, padding: %s, align_eyes: %s)",
-                     size, padding, align_eyes)
+        logger.trace("Loading aligned face: (size: %s, align_eyes: %s)", size, align_eyes)
+        padding = int(size * 0.1875)
         self.aligned["size"] = size
         self.aligned["padding"] = padding
         self.aligned["align_eyes"] = align_eyes

@@ -153,3 +146,8 @@ class DetectedFace():
                                 self.aligned["padding"])
         logger.trace("Returning: %s", mat)
         return mat
+
+    @property
+    def adjusted_interpolators(self):
+        """ Return the interpolator and reverse interpolator for the adjusted matrix """
+        return get_matrix_scaling(self.adjusted_matrix)
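Two details in this change are easy to miss. First, the fixed padding=48 default is gone: padding = int(size * 0.1875) yields exactly 48 at the old default size of 256, but now scales with the face size. Second, the new adjusted_interpolators property delegates to get_matrix_scaling from lib.aligner, whose body is not part of this diff; a sketch of what such a helper can look like, assuming OpenCV and the usual rule of cubic interpolation for upsampling and area averaging for downsampling (illustrative, not the lib.aligner source):

    import cv2
    import numpy as np

    def get_matrix_scaling(mat):
        """ Sketch: derive x/y scale factors from a 2x3 affine matrix and
            return a (forward, reverse) interpolator pair accordingly. """
        x_scale = np.sqrt(mat[0, 0] ** 2 + mat[0, 1] ** 2)
        y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale
        avg_scale = (x_scale + y_scale) * 0.5
        if avg_scale >= 1.0:
            return cv2.INTER_CUBIC, cv2.INTER_AREA   # upsampling warp
        return cv2.INTER_AREA, cv2.INTER_CUBIC       # downsampling warp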
lib/gui/__init__.py
@@ -1,7 +1,9 @@
 from lib.gui.command import CommandNotebook
 from lib.gui.display import DisplayNotebook
-from lib.gui.options import CliOptions, Config
-from lib.gui.stats import CurrentSession
+from lib.gui.options import CliOptions
+from lib.gui.menu import MainMenuBar
+from lib.gui.popup_configure import popup_config
+from lib.gui.stats import Session
 from lib.gui.statusbar import StatusBar
-from lib.gui.utils import ConsoleOut, Images
+from lib.gui.utils import ConsoleOut, get_config, get_images, initialize_config, initialize_images
 from lib.gui.wrapper import ProcessWrapper
lib/gui/command.py
@@ -5,44 +5,42 @@ import logging
 import tkinter as tk
 from tkinter import ttk

-from .options import Config
 from .tooltip import Tooltip
-from .utils import ContextMenu, Images, FileHandler
+from .utils import ContextMenu, FileHandler, get_images, get_config, set_slider_rounding

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


-class CommandNotebook(ttk.Notebook):
+class CommandNotebook(ttk.Notebook):  # pylint: disable=too-many-ancestors
     """ Frame to hold each individual tab of the command notebook """

-    def __init__(self, parent, cli_options, tk_vars, scaling_factor):
-        logger.debug("Initializing %s: (parent: %s, cli_options: %s, tk_vars: %s, "
-                     "scaling_factor: %s", self.__class__.__name__, parent, cli_options,
-                     tk_vars, scaling_factor)
+    def __init__(self, parent):
+        logger.debug("Initializing %s: (parent: %s)", self.__class__.__name__, parent)
+        scaling_factor = get_config().scaling_factor
         width = int(420 * scaling_factor)
         height = int(500 * scaling_factor)
         ttk.Notebook.__init__(self, parent, width=width, height=height)
         parent.add(self)

-        self.cli_opts = cli_options
-        self.tk_vars = tk_vars
         self.actionbtns = dict()

         self.set_running_task_trace()
         self.build_tabs()
+        get_config().command_notebook = self
         logger.debug("Initialized %s", self.__class__.__name__)

     def set_running_task_trace(self):
         """ Set trigger action for the running task
             to change the action buttons text and command """
         logger.debug("Set running trace")
-        self.tk_vars["runningtask"].trace("w", self.change_action_button)
+        tk_vars = get_config().tk_vars
+        tk_vars["runningtask"].trace("w", self.change_action_button)

     def build_tabs(self):
         """ Build the tabs for the relevant command """
         logger.debug("Build Tabs")
-        for category in self.cli_opts.categories:
-            cmdlist = self.cli_opts.commands[category]
+        cli_opts = get_config().cli_opts
+        for category in cli_opts.categories:
+            cmdlist = cli_opts.commands[category]
             for command in cmdlist:
                 title = command.title()
                 commandtab = CommandTab(self, category, command)
@@ -52,9 +50,11 @@ class CommandNotebook(ttk.Notebook):
     def change_action_button(self, *args):
         """ Change the action button to relevant control """
         logger.debug("Update Action Buttons: (args: %s", args)
+        tk_vars = get_config().tk_vars

         for cmd in self.actionbtns.keys():
             btnact = self.actionbtns[cmd]
-            if self.tk_vars["runningtask"].get():
+            if tk_vars["runningtask"].get():
                 ttl = "Terminate"
                 hlp = "Exit the running process"
             else:
@@ -65,7 +65,7 @@ class CommandNotebook(ttk.Notebook):
         Tooltip(btnact, text=hlp, wraplength=200)


-class CommandTab(ttk.Frame):
+class CommandTab(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ Frame to hold each individual tab of the command notebook """

     def __init__(self, parent, category, command):
@@ -74,9 +74,7 @@ class CommandTab(ttk.Frame):
         ttk.Frame.__init__(self, parent)

         self.category = category
-        self.cli_opts = parent.cli_opts
         self.actionbtns = parent.actionbtns
-        self.tk_vars = parent.tk_vars
         self.command = command

         self.build_tab()
@@ -100,7 +98,7 @@ class CommandTab(ttk.Frame):
         logger.debug("Added frame separator")


-class OptionsFrame(ttk.Frame):
+class OptionsFrame(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ Options Frame - Holds the Options for each command """

     def __init__(self, parent):
@@ -108,7 +106,6 @@ class OptionsFrame(ttk.Frame):
         ttk.Frame.__init__(self, parent)
         self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

-        self.opts = parent.cli_opts
         self.command = parent.command

         self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
@@ -121,7 +118,8 @@ class OptionsFrame(ttk.Frame):
         self.chkbtns = self.checkbuttons_frame()

         self.build_frame()
-        self.opts.set_context_option(self.command)
+        cli_opts = get_config().cli_opts
+        cli_opts.set_context_option(self.command)
         logger.debug("Initialized %s", self.__class__.__name__)

     def checkbuttons_frame(self):
@@ -150,7 +148,8 @@ class OptionsFrame(ttk.Frame):
         self.add_scrollbar()
         self.canvas.bind("<Configure>", self.resize_frame)

-        for option in self.opts.gen_command_options(self.command):
+        cli_opts = get_config().cli_opts
+        for option in cli_opts.gen_command_options(self.command):
             optioncontrol = OptionControl(self.command,
                                           option,
                                           self.optsframe,
@@ -170,7 +169,7 @@ class OptionsFrame(ttk.Frame):
         self.optsframe.bind("<Configure>", self.update_scrollbar)
         logger.debug("Added Options Scrollbar")

-    def update_scrollbar(self, event):
+    def update_scrollbar(self, event):  # pylint: disable=unused-argument
         """ Update the options frame scrollbar """
         self.canvas.configure(scrollregion=self.canvas.bbox("all"))

@@ -207,6 +206,7 @@ class OptionControl():
         if ctl == ttk.Checkbutton:
             dflt = self.option.get("default", False)
         choices = self.option["choices"] if ctl == ttk.Combobox else None
+        min_max = self.option["min_max"] if ctl == ttk.Scale else None

         ctlframe = self.build_one_control_frame()

@@ -217,6 +217,7 @@ class OptionControl():
         self.option["value"] = self.build_one_control(ctlframe,
                                                       ctlvars,
                                                       choices,
+                                                      min_max,
                                                       sysbrowser)
         logger.debug("Built option control")

@@ -228,6 +229,7 @@ class OptionControl():
             ctlhelp = ctlhelp[2:].replace("\n\t", " ").replace("\n'", "\n\n'")
         else:
             ctlhelp = " ".join(ctlhelp.split())
+        ctlhelp = ctlhelp.replace("%%", "%")
         ctlhelp = ". ".join(i.capitalize() for i in ctlhelp.split(". "))
         ctlhelp = ctltitle + " - " + ctlhelp
         logger.debug("Formatted control help: (title: '%s', help: '%s'", ctltitle, ctlhelp)
@@ -249,15 +251,14 @@ class OptionControl():
         lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
         logger.debug("Built control label: '%s'", control_title)

-    def build_one_control(self, frame, controlvars, choices, sysbrowser):
+    def build_one_control(self, frame, controlvars, choices, min_max, sysbrowser):
         """ Build and place the option controls """
-        logger.debug("Build control: (controlvars: %s, choices: %s, sysbrowser: %s",
-                     controlvars, choices, sysbrowser)
+        logger.debug("Build control: (controlvars: %s, choices: %s, min_max: %s, sysbrowser: %s",
+                     controlvars, choices, min_max, sysbrowser)
         control, control_title, default, helptext = controlvars
         default = default if default is not None else ""

-        var = tk.BooleanVar(
-            frame) if control == ttk.Checkbutton else tk.StringVar(frame)
+        var = tk.BooleanVar(frame) if control == ttk.Checkbutton else tk.StringVar(frame)
         var.set(default)

         if sysbrowser:
@@ -268,6 +269,12 @@ class OptionControl():
                                          control_title,
                                          var,
                                          helptext)
+        elif control == ttk.Scale:
+            self.slider_control(control,
+                                frame,
+                                var,
+                                min_max,
+                                helptext)
         else:
             self.control_to_optionsframe(control,
                                          frame,
@@ -292,6 +299,29 @@ class OptionControl():
         Tooltip(ctl, text=helptext, wraplength=200)
         logger.debug("Added control checkframe: '%s'", control_title)

+    def slider_control(self, control, frame, tk_var, min_max, helptext):
+        """ A slider control with corresponding Entry box """
+        logger.debug("Add slider control to Options Frame: %s", control)
+        d_type = self.option.get("type", float)
+        rnd = self.option.get("rounding", 2) if d_type == float else self.option.get("rounding", 1)
+
+        tbox = ttk.Entry(frame, width=8, textvariable=tk_var, justify=tk.RIGHT)
+        tbox.pack(padx=(0, 5), side=tk.RIGHT)
+        ctl = control(
+            frame,
+            variable=tk_var,
+            command=lambda val, var=tk_var, dt=d_type, rn=rnd, mm=min_max:
+            set_slider_rounding(val, var, dt, rn, mm))
+        ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
+        rc_menu = ContextMenu(ctl)
+        rc_menu.cm_bind()
+        ctl["from_"] = min_max[0]
+        ctl["to"] = min_max[1]
+
+        Tooltip(ctl, text=helptext, wraplength=720)
+        Tooltip(tbox, text=helptext, wraplength=720)
+        logger.debug("Added slider control to Options Frame: %s", control)
+
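slider_control leaves the snapping of raw ttk.Scale values to set_slider_rounding, imported from .utils but not shown in this diff. From the call site it receives the raw value, the backing tk variable, the datatype, the rounding figure and the min/max bounds. A plausible sketch, consistent with the config docstring ("decimal places for floats, step interval for ints") but assumed rather than taken from lib/gui/utils.py:

    def set_slider_rounding(value, var, d_type, rounding, min_max):
        """ Sketch: snap a raw ttk.Scale callback value before writing it
            back to the linked tk variable. """
        if d_type == float:
            var.set(round(float(value), rounding))  # rounding = decimal places
        else:
            stepped = int(round(float(value) / rounding) * rounding)
            var.set(min(min_max[1], max(min_max[0], stepped)))  # step + clamp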
@staticmethod
|
@staticmethod
|
||||||
def control_to_optionsframe(control, frame, var, choices, helptext):
|
def control_to_optionsframe(control, frame, var, choices, helptext):
|
||||||
""" Standard non-check buttons sit in the main options frame """
|
""" Standard non-check buttons sit in the main options frame """
|
||||||
|
@ -303,8 +333,7 @@ class OptionControl():
|
||||||
if control == ttk.Combobox:
|
if control == ttk.Combobox:
|
||||||
logger.debug("Adding combo choices: %s", choices)
|
logger.debug("Adding combo choices: %s", choices)
|
||||||
ctl["values"] = [choice for choice in choices]
|
ctl["values"] = [choice for choice in choices]
|
||||||
|
Tooltip(ctl, text=helptext, wraplength=920)
|
||||||
Tooltip(ctl, text=helptext, wraplength=720)
|
|
||||||
logger.debug("Added control to Options Frame: %s", control)
|
logger.debug("Added control to Options Frame: %s", control)
|
||||||
|
|
||||||
def add_browser_buttons(self, frame, sysbrowser, filepath):
|
def add_browser_buttons(self, frame, sysbrowser, filepath):
|
||||||
|
@ -312,7 +341,7 @@ class OptionControl():
|
||||||
logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'",
|
logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'",
|
||||||
sysbrowser, filepath)
|
sysbrowser, filepath)
|
||||||
for browser in sysbrowser:
|
for browser in sysbrowser:
|
||||||
img = Images().icons[browser]
|
img = get_images().icons[browser]
|
||||||
action = getattr(self, "ask_" + browser)
|
action = getattr(self, "ask_" + browser)
|
||||||
filetypes = self.option.get("filetypes", "default")
|
filetypes = self.option.get("filetypes", "default")
|
||||||
fileopn = ttk.Button(frame,
|
fileopn = ttk.Button(frame,
|
||||||
|
@ -351,7 +380,7 @@ class OptionControl():
|
||||||
filepath.set(filename)
|
filepath.set(filename)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ask_nothing(filepath, filetypes=None):
|
def ask_nothing(filepath, filetypes=None): # pylint: disable=unused-argument
|
||||||
""" Method that does nothing, used for disabling open/save pop up """
|
""" Method that does nothing, used for disabling open/save pop up """
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -370,7 +399,7 @@ class OptionControl():
|
||||||
filepath.set(filename)
|
filepath.set(filename)
|
||||||
|
|
||||||
|
|
||||||
class ActionFrame(ttk.Frame):
|
class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors
|
||||||
"""Action Frame - Displays action controls for the command tab """
|
"""Action Frame - Displays action controls for the command tab """
|
||||||
|
|
||||||
def __init__(self, parent):
|
def __init__(self, parent):
|
||||||
|
@ -382,16 +411,16 @@ class ActionFrame(ttk.Frame):
|
||||||
self.title = self.command.title()
|
self.title = self.command.title()
|
||||||
|
|
||||||
self.add_action_button(parent.category,
|
self.add_action_button(parent.category,
|
||||||
parent.actionbtns,
|
parent.actionbtns)
|
||||||
parent.tk_vars)
|
self.add_util_buttons()
|
||||||
self.add_util_buttons(parent.cli_opts, parent.tk_vars)
|
|
||||||
logger.debug("Initialized %s", self.__class__.__name__)
|
logger.debug("Initialized %s", self.__class__.__name__)
|
||||||
|
|
||||||
def add_action_button(self, category, actionbtns, tk_vars):
|
def add_action_button(self, category, actionbtns):
|
||||||
""" Add the action buttons for page """
|
""" Add the action buttons for page """
|
||||||
logger.debug("Add action buttons: '%s'", self.title)
|
logger.debug("Add action buttons: '%s'", self.title)
|
||||||
actframe = ttk.Frame(self)
|
actframe = ttk.Frame(self)
|
||||||
actframe.pack(fill=tk.X, side=tk.LEFT)
|
actframe.pack(fill=tk.X, side=tk.LEFT)
|
||||||
|
tk_vars = get_config().tk_vars
|
||||||
|
|
||||||
var_value = "{},{}".format(category, self.command)
|
var_value = "{},{}".format(category, self.command)
|
||||||
|
|
||||||
|
@ -415,17 +444,17 @@ class ActionFrame(ttk.Frame):
|
||||||
wraplength=200)
|
wraplength=200)
|
||||||
logger.debug("Added action buttons: '%s'", self.title)
|
logger.debug("Added action buttons: '%s'", self.title)
|
||||||
|
|
||||||
def add_util_buttons(self, cli_options, tk_vars):
|
def add_util_buttons(self):
|
||||||
""" Add the section utility buttons """
|
""" Add the section utility buttons """
|
||||||
logger.debug("Add util buttons")
|
logger.debug("Add util buttons")
|
||||||
utlframe = ttk.Frame(self)
|
utlframe = ttk.Frame(self)
|
||||||
utlframe.pack(side=tk.RIGHT)
|
utlframe.pack(side=tk.RIGHT)
|
||||||
|
|
||||||
config = Config(cli_options, tk_vars)
|
config = get_config()
|
||||||
for utl in ("load", "save", "clear", "reset"):
|
for utl in ("load", "save", "clear", "reset"):
|
||||||
logger.debug("Adding button: '%s'", utl)
|
logger.debug("Adding button: '%s'", utl)
|
||||||
img = Images().icons[utl]
|
img = get_images().icons[utl]
|
||||||
action_cls = config if utl in (("save", "load")) else cli_options
|
action_cls = config if utl in (("save", "load")) else config.cli_opts
|
||||||
action = getattr(action_cls, utl)
|
action = getattr(action_cls, utl)
|
||||||
btnutl = ttk.Button(utlframe,
|
btnutl = ttk.Button(utlframe,
|
||||||
image=img,
|
image=img,
|
||||||
|
|
|
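The common thread in this file is the removal of cli_opts/tk_vars/scaling_factor parameter plumbing in favour of one process-wide object fetched with get_config(). The initialize/get pair lives in lib/gui/utils.py and is not shown in the diff; the attribute surface implied by the call sites (.scaling_factor, .tk_vars, .cli_opts, .session and the writable .command_notebook) suggests a module-level singleton along these lines (an assumed sketch, not the actual utils source):

    _CONFIG = None  # module-level singleton

    class Config():
        """ Sketch of the global GUI state object """
        def __init__(self, cli_opts, tk_vars, scaling_factor):
            self.cli_opts = cli_opts
            self.tk_vars = tk_vars
            self.scaling_factor = scaling_factor
            self.session = None           # populated once training stats exist
            self.command_notebook = None  # set by CommandNotebook after build

    def initialize_config(cli_opts, tk_vars, scaling_factor):
        """ Create the singleton exactly once, at GUI start up """
        global _CONFIG  # pylint: disable=global-statement
        if _CONFIG is None:
            _CONFIG = Config(cli_opts, tk_vars, scaling_factor)
        return _CONFIG

    def get_config():
        """ Return the global config from anywhere in the GUI """
        return _CONFIG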
lib/gui/display.py
@@ -4,48 +4,56 @@
     What is displayed in the Display Frame varies
     depending on what task is being run """

+import logging
 import tkinter as tk
 from tkinter import ttk

 from .display_analysis import Analysis
 from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
+from .utils import get_config
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


-class DisplayNotebook(ttk.Notebook):
+class DisplayNotebook(ttk.Notebook):  # pylint: disable=too-many-ancestors
     """ The display tabs """

-    def __init__(self, parent, session, tk_vars, scaling_factor):
+    def __init__(self, parent):
+        logger.debug("Initializing %s", self.__class__.__name__)
         ttk.Notebook.__init__(self, parent, width=780)
         parent.add(self)

+        tk_vars = get_config().tk_vars
         self.wrapper_var = tk_vars["display"]
         self.runningtask = tk_vars["runningtask"]
-        self.session = session

         self.set_wrapper_var_trace()
-        self.add_static_tabs(scaling_factor)
+        self.add_static_tabs()
         self.static_tabs = [child for child in self.tabs()]
+        logger.debug("Initialized %s", self.__class__.__name__)

     def set_wrapper_var_trace(self):
         """ Set the trigger actions for the display vars
             when they have been triggered in the Process Wrapper """
+        logger.debug("Setting wrapper var trace")
         self.wrapper_var.trace("w", self.update_displaybook)

-    def add_static_tabs(self, scaling_factor):
+    def add_static_tabs(self):
         """ Add tabs that are permanently available """
+        logger.debug("Adding static tabs")
         for tab in ("job queue", "analysis"):
             if tab == "job queue":
                 continue  # Not yet implemented
             if tab == "analysis":
                 helptext = {"stats":
                             "Summary statistics for each training session"}
-                frame = Analysis(self, tab, helptext, scaling_factor)
+                frame = Analysis(self, tab, helptext)
             else:
                 frame = self.add_frame()
             self.add(frame, text=tab.title())

     def add_frame(self):
         """ Add a single frame for holding tab's contents """
+        logger.debug("Adding frame")
         frame = ttk.Frame(self)
         frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
         return frame
@@ -58,32 +66,53 @@ class DisplayNotebook(ttk.Notebook):

     def extract_tabs(self):
         """ Build the extract tabs """
+        logger.debug("Build extract tabs")
         helptext = ("Updates preview from output every 5 "
                     "seconds to limit disk contention")
         PreviewExtract(self, "preview", helptext, 5000)
+        logger.debug("Built extract tabs")

     def train_tabs(self):
         """ Build the train tabs """
+        logger.debug("Build train tabs")
         for tab in ("graph", "preview"):
             if tab == "graph":
                 helptext = "Graph showing Loss vs Iterations"
                 GraphDisplay(self, "graph", helptext, 5000)
             elif tab == "preview":
                 helptext = "Training preview. Updated on every save iteration"
-                PreviewTrain(self, "preview", helptext, 5000)
+                PreviewTrain(self, "preview", helptext, 1000)
+        logger.debug("Built train tabs")

     def convert_tabs(self):
         """ Build the convert tabs
             Currently identical to Extract, so just call that """
+        logger.debug("Build convert tabs")
         self.extract_tabs()
+        logger.debug("Built convert tabs")

     def remove_tabs(self):
         """ Remove all command specific tabs """
         for child in self.tabs():
-            if child not in self.static_tabs:
-                self.forget(child)
+            if child in self.static_tabs:
+                continue
+            logger.debug("removing child: %s", child)
+            child_name = child.split(".")[-1]
+            child_object = self.children[child_name]
+            self.destroy_tabs_children(child_object)
+            self.forget(child)

-    def update_displaybook(self, *args):
+    @staticmethod
+    def destroy_tabs_children(tab):
+        """ Destroy all tabs children
+            Children must be destroyed as forget only hides display
+        """
+        logger.debug("Destroying children for tab: %s", tab)
+        for child in tab.winfo_children():
+            logger.debug("Destroying child: %s", child)
+            child.destroy()
+
+    def update_displaybook(self, *args):  # pylint: disable=unused-argument
         """ Set the display tabs based on executing task """
         command = self.wrapper_var.get()
         self.remove_tabs()
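The rewritten remove_tabs states its own rationale: forget only hides a tab, so its widgets must be destroyed explicitly or they pile up every time the display book is rebuilt. The underlying Tkinter behaviour can be demonstrated standalone:

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()
    notebook = ttk.Notebook(root)
    frame = ttk.Frame(notebook)
    ttk.Label(frame, text="preview").pack()
    notebook.add(frame, text="Preview")

    notebook.forget(frame)       # the tab disappears from the notebook...
    print(frame.winfo_exists())  # 1: ...but the widget tree is still alive
    frame.destroy()              # children are destroyed with the frame
    print(frame.winfo_exists())  # 0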
lib/gui/display_analysis.py
@@ -2,104 +2,138 @@
 """ Analysis tab of Display Frame of the Faceswap GUI """

 import csv
+import logging
+import os
 import tkinter as tk
 from tkinter import ttk

 from .display_graph import SessionGraph
 from .display_page import DisplayPage
-from .stats import Calculations, SavedSessions, SessionsSummary, SessionsTotals
+from .stats import Calculations, Session
 from .tooltip import Tooltip
-from .utils import Images, FileHandler
+from .utils import FileHandler, get_config, get_images
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 class Analysis(DisplayPage):  # pylint: disable=too-many-ancestors
     """ Session analysis tab """
-    def __init__(self, parent, tabname, helptext, scaling_factor):
-        DisplayPage.__init__(self, parent, tabname, helptext)
+    def __init__(self, parent, tabname, helptext):
+        logger.debug("Initializing: %s: (parent, %s, tabname: '%s', helptext: '%s')",
+                     self.__class__.__name__, parent, tabname, helptext)
+        super().__init__(parent, tabname, helptext)

         self.summary = None
+        self.session = None
         self.add_options()
-        self.add_main_frame(scaling_factor)
+        self.add_main_frame()
+        logger.debug("Initialized: %s", self.__class__.__name__)

     def set_vars(self):
         """ Analysis specific vars """
         selected_id = tk.StringVar()
-        filename = tk.StringVar()
-        return {"selected_id": selected_id,
-                "filename": filename}
+        return {"selected_id": selected_id}

-    def add_main_frame(self, scaling_factor):
+    def add_main_frame(self):
         """ Add the main frame to the subnotebook
             to hold stats and session data """
+        logger.debug("Adding main frame")
         mainframe = self.subnotebook_add_page("stats")
         self.stats = StatsData(mainframe,
-                               self.vars["filename"],
                                self.vars["selected_id"],
-                               self.helptext["stats"],
-                               scaling_factor)
+                               self.helptext["stats"])
+        logger.debug("Added main frame")

     def add_options(self):
         """ Add the options bar """
+        logger.debug("Adding options")
         self.reset_session_info()
         options = Options(self)
         options.add_options()
+        logger.debug("Added options")

     def reset_session_info(self):
         """ Reset the session info status to default """
-        self.vars["filename"].set(None)
+        logger.debug("Resetting session info")
         self.set_info("No session data loaded")

     def load_session(self):
         """ Load previously saved sessions """
+        logger.debug("Loading session")
         self.clear_session()
-        filename = FileHandler("open", "session").retfile
-        if not filename:
+        fullpath = FileHandler("filename", "state").retfile
+        if not fullpath:
             return
-        filename = filename.name
-        loaded_data = SavedSessions(filename).sessions
-        msg = filename
-        if len(filename) > 70:
-            msg = "...{}".format(filename[-70:])
-        self.set_session_summary(loaded_data, msg)
-        self.vars["filename"].set(filename)
+        logger.debug("state_file: '%s'", fullpath)
+        model_dir, state_file = os.path.split(fullpath)
+        logger.debug("model_dir: '%s'", model_dir)
+        model_name = self.get_model_name(model_dir, state_file)
+        if not model_name:
+            return
+        self.session = Session(model_dir=model_dir, model_name=model_name)
+        self.session.initialize_session(is_training=False)
+        msg = os.path.split(state_file)[0]
+        if len(msg) > 70:
+            msg = "...{}".format(msg[-70:])
+        self.set_session_summary(msg)

+    @staticmethod
+    def get_model_name(model_dir, state_file):
+        """ Get the state file from the model directory """
+        logger.debug("Getting model name")
+        model_name = state_file.replace("_state.json", "")
+        logger.debug("model_name: %s", model_name)
+        logs_dir = os.path.join(model_dir, "{}_logs".format(model_name))
+        if not os.path.isdir(logs_dir):
+            logger.warning("No logs folder found in folder: '%s'", logs_dir)
+            return None
+        return model_name
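load_session now keys everything off the model's state file rather than a separately saved session file: the user picks <model_name>_state.json and get_model_name insists that the matching TensorBoard logs folder sits beside it. For a model named "original" (the name is illustrative; it comes from stripping "_state.json" from the selected file), the expected layout is:

    model_dir/
        original_state.json    <- file selected in the GUI file browser
        original_logs/         <- per-session TensorBoard event logs (required)

If the logs folder is missing, the warning above is logged and loading aborts.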
||||||
def reset_session(self):
|
def reset_session(self):
|
||||||
""" Load previously saved sessions """
|
""" Reset currently training sessions """
|
||||||
|
logger.debug("Reset current training session")
|
||||||
self.clear_session()
|
self.clear_session()
|
||||||
if self.session.stats["iterations"] == 0:
|
session = get_config().session
|
||||||
|
if not session.initialized:
|
||||||
|
logger.debug("Training not running")
|
||||||
print("Training not running")
|
print("Training not running")
|
||||||
return
|
return
|
||||||
loaded_data = self.session.historical.sessions
|
|
||||||
msg = "Currently running training session"
|
msg = "Currently running training session"
|
||||||
self.set_session_summary(loaded_data, msg)
|
self.session = session
|
||||||
self.vars["filename"].set("Currently running training session")
|
self.set_session_summary(msg)
|
||||||
|
|
||||||
def set_session_summary(self, data, message):
|
def set_session_summary(self, message):
|
||||||
""" Set the summary data and info message """
|
""" Set the summary data and info message """
|
||||||
self.summary = SessionsSummary(data).summary
|
logger.debug("Setting session summary. (message: '%s')", message)
|
||||||
|
self.summary = self.session.full_summary
|
||||||
self.set_info("Session: {}".format(message))
|
self.set_info("Session: {}".format(message))
|
||||||
self.stats.loaded_data = data
|
self.stats.session = self.session
|
||||||
self.stats.tree_insert_data(self.summary)
|
self.stats.tree_insert_data(self.summary)
|
||||||
|
|
||||||
def clear_session(self):
|
def clear_session(self):
|
||||||
""" Clear sessions stats """
|
""" Clear sessions stats """
|
||||||
|
logger.debug("Clearing session")
|
||||||
self.summary = None
|
self.summary = None
|
||||||
self.stats.loaded_data = None
|
self.stats.session = None
|
||||||
self.stats.tree_clear()
|
self.stats.tree_clear()
|
||||||
self.reset_session_info()
|
self.reset_session_info()
|
||||||
|
|
||||||
def save_session(self):
|
def save_session(self):
|
||||||
""" Save sessions stats to csv """
|
""" Save sessions stats to csv """
|
||||||
|
logger.debug("Saving session")
|
||||||
if not self.summary:
|
if not self.summary:
|
||||||
|
logger.debug("No summary data loaded. Nothing to save")
|
||||||
print("No summary data loaded. Nothing to save")
|
print("No summary data loaded. Nothing to save")
|
||||||
return
|
return
|
||||||
savefile = FileHandler("save", "csv").retfile
|
savefile = FileHandler("save", "csv").retfile
|
||||||
if not savefile:
|
if not savefile:
|
||||||
|
logger.debug("No save file. Returning")
|
||||||
return
|
return
|
||||||
|
|
||||||
write_dicts = [val for val in self.summary.values()]
|
write_dicts = [val for val in self.summary.values()]
|
||||||
fieldnames = sorted(key for key in write_dicts[0].keys())
|
fieldnames = sorted(key for key in write_dicts[0].keys())
|
||||||
|
|
||||||
|
logger.debug("Saving to: '%s'", savefile)
|
||||||
with savefile as outfile:
|
with savefile as outfile:
|
||||||
csvout = csv.DictWriter(outfile, fieldnames)
|
csvout = csv.DictWriter(outfile, fieldnames)
|
||||||
csvout.writeheader()
|
csvout.writeheader()
|
||||||
@@ -110,8 +144,10 @@ class Analysis(DisplayPage):  # pylint: disable=too-many-ancestors
 class Options():
     """ Options bar of Analysis tab """
     def __init__(self, parent):
+        logger.debug("Initializing: %s", self.__class__.__name__)
         self.optsframe = parent.optsframe
         self.parent = parent
+        logger.debug("Initialized: %s", self.__class__.__name__)

     def add_options(self):
         """ Add the display tab options """
@@ -120,9 +156,10 @@ class Options():
     def add_buttons(self):
         """ Add the option buttons """
         for btntype in ("reset", "clear", "save", "load"):
+            logger.debug("Adding button: '%s'", btntype)
             cmd = getattr(self.parent, "{}_session".format(btntype))
             btn = ttk.Button(self.optsframe,
-                             image=Images().icons[btntype],
+                             image=get_images().icons[btntype],
                              command=cmd)
             btn.pack(padx=2, side=tk.RIGHT)
             hlp = self.set_help(btntype)
@@ -131,6 +168,7 @@ class Options():
     @staticmethod
     def set_help(btntype):
         """ Set the helptext for option buttons """
+        logger.debug("Setting help")
         hlp = ""
         if btntype == "reset":
             hlp = "Load/Refresh stats for the currently training session"
@@ -145,25 +183,15 @@ class Options():

 class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ Stats frame of analysis tab """
-    def __init__(self,
-                 parent,
-                 filename,
-                 selected_id,
-                 helptext,
-                 scaling_factor):
-        ttk.Frame.__init__(self, parent)
-        self.pack(side=tk.TOP,
-                  padx=5,
-                  pady=5,
-                  expand=True,
-                  fill=tk.X,
-                  anchor=tk.N)
+    def __init__(self, parent, selected_id, helptext):
+        logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')",
+                     self.__class__.__name__, parent, selected_id, helptext)
+        super().__init__(parent)
+        self.pack(side=tk.TOP, padx=5, pady=5, expand=True, fill=tk.X, anchor=tk.N)

-        self.filename = filename
-        self.loaded_data = None
+        self.session = None  # set when loading or clearing from parent
         self.selected_id = selected_id
         self.popup_positions = list()
-        self.scaling_factor = scaling_factor

         self.add_label()
         self.tree = ttk.Treeview(self, height=1, selectmode=tk.BROWSE)
@@ -171,14 +199,17 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors
                                        orient="vertical",
                                        command=self.tree.yview)
         self.columns = self.tree_configure(helptext)
+        logger.debug("Initialized: %s", self.__class__.__name__)

     def add_label(self):
         """ Add Treeview Title """
+        logger.debug("Adding Treeview title")
         lbl = ttk.Label(self, text="Session Stats", anchor=tk.CENTER)
         lbl.pack(side=tk.TOP, expand=True, fill=tk.X, padx=5, pady=5)

     def tree_configure(self, helptext):
         """ Build a treeview widget to hold the sessions stats """
+        logger.debug("Configuring Treeview")
         self.tree.configure(yscrollcommand=self.scrollbar.set)
         self.tree.tag_configure("total",
                                 background="black",
@@ -191,6 +222,7 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors

     def tree_columns(self):
         """ Add the columns to the totals treeview """
+        logger.debug("Adding Treeview columns")
         columns = (("session", 40, "#"),
                    ("start", 130, None),
                    ("end", 130, None),
@@ -202,6 +234,7 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors

         for column in columns:
             text = column[2] if column[2] else column[0].title()
+            logger.debug("Adding heading: '%s'", text)
             self.tree.heading(column[0], text=text)
             self.tree.column(column[0],
                              width=column[1],
@@ -212,19 +245,21 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors

         return [column[0] for column in columns]

-    def tree_insert_data(self, sessions):
+    def tree_insert_data(self, sessions_summary):
         """ Insert the data into the totals treeview """
-        self.tree.configure(height=len(sessions))
+        logger.debug("Inserting treeview data")
+        self.tree.configure(height=len(sessions_summary))

-        for item in sessions:
+        for item in sessions_summary:
             values = [item[column] for column in self.columns]
-            kwargs = {"values": values, "image": Images().icons["graph"]}
+            kwargs = {"values": values, "image": get_images().icons["graph"]}
             if values[0] == "Total":
                 kwargs["tags"] = "total"
             self.tree.insert("", "end", **kwargs)

     def tree_clear(self):
         """ Clear the totals tree """
+        logger.debug("Clearing treeview data")
         self.tree.delete(* self.tree.get_children())
         self.tree.configure(height=1)

@@ -235,17 +270,22 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors
         selection = self.tree.focus()
         values = self.tree.item(selection, "values")
         if values:
+            logger.debug("Selected values: %s", values)
             self.selected_id.set(values[0])
             if region == "tree":
                 self.data_popup()

     def data_popup(self):
         """ Pop up a window and control its position """
-        toplevel = SessionPopUp(self.loaded_data, self.selected_id.get())
+        logger.debug("Popping up data window")
+        scaling_factor = get_config().scaling_factor
+        toplevel = SessionPopUp(self.session.modeldir,
+                                self.session.modelname,
+                                self.selected_id.get())
         toplevel.title(self.data_popup_title())
         position = self.data_popup_get_position()
-        height = int(720 * self.scaling_factor)
-        width = int(400 * self.scaling_factor)
+        height = int(720 * scaling_factor)
+        width = int(400 * scaling_factor)
         toplevel.geometry("{}x{}+{}+{}".format(str(height),
                                                str(width),
                                                str(position[0]),
@@ -254,14 +294,17 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors

     def data_popup_title(self):
         """ Set the data popup title """
+        logger.debug("Setting popup title")
         selected_id = self.selected_id.get()
         title = "All Sessions"
         if selected_id != "Total":
-            title = "Session #{}".format(selected_id)
-        return "{} - {}".format(title, self.filename.get())
+            title = "{} Model: Session #{}".format(self.session.modelname.title(), selected_id)
+        logger.debug("Title: '%s'", title)
+        return "{} - {}".format(title, self.session.modeldir)

     def data_popup_get_position(self):
         """ Get the position of the next window """
+        logger.debug("getting popup position")
         init_pos = [120, 120]
         pos = init_pos
         while True:
@@ -270,25 +313,33 @@ class StatsData(ttk.Frame):  # pylint: disable=too-many-ancestors
                 break
             pos = [item + 200 for item in pos]
             init_pos, pos = self.data_popup_check_boundaries(init_pos, pos)
+        logger.debug("Position: %s", pos)
         return pos

     def data_popup_check_boundaries(self, initial_position, position):
         """ Check that the popup remains within the screen boundaries """
+        logger.debug("Checking popup boundaries: (initial_position: %s, position: %s)",
+                     initial_position, position)
         boundary_x = self.winfo_screenwidth() - 120
         boundary_y = self.winfo_screenheight() - 120
         if position[0] >= boundary_x or position[1] >= boundary_y:
             initial_position = [initial_position[0] + 50, initial_position[1]]
             position = initial_position
+        logger.debug("Returning popup boundaries: (initial_position: %s, position: %s)",
+                     initial_position, position)
         return initial_position, position

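Editor's note: data_popup_get_position and data_popup_check_boundaries above implement a window cascade: each popup is offset 200px from the last, and when the cascade would run off screen it restarts 50px further right. A standalone sketch of the same logic (screen size hard-coded for illustration; the real code queries winfo_screenwidth/height):

    def next_popup_position(taken, screen=(1920, 1080), start=(120, 120)):
        """ Return the next free popup origin, cascading by 200px and
            restarting 50px further right at the screen boundary. """
        init_pos, pos = list(start), list(start)
        while pos in taken:
            pos = [item + 200 for item in pos]
            if pos[0] >= screen[0] - 120 or pos[1] >= screen[1] - 120:
                init_pos = [init_pos[0] + 50, init_pos[1]]
                pos = list(init_pos)
        return pos

    taken = [[120, 120], [320, 320]]
    print(next_popup_position(taken))  # -> [520, 520]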
 class SessionPopUp(tk.Toplevel):
-    """ Pop up for detailed grap/stats for selected session """
-    def __init__(self, data, session_id):
-        tk.Toplevel.__init__(self)
+    """ Pop up for detailed graph/stats for selected session """
+    def __init__(self, model_dir, model_name, session_id):
+        logger.debug("Initializing: %s: (model_dir: %s, model_name: %s, session_id: %s)",
+                     self.__class__.__name__, model_dir, model_name, session_id)
+        super().__init__()

-        self.is_totals = session_id == "Total"
-        self.data = self.set_session_data(data, session_id)
+        self.session_id = session_id
+        self.session = Session(model_dir=model_dir, model_name=model_name)
+        self.initialize_session()

         self.graph = None
         self.display_data = None
@@ -296,25 +347,35 @@ class SessionPopUp(tk.Toplevel):
         self.vars = dict()
         self.graph_initialised = False
         self.build()
+        logger.debug("Initialized: %s", self.__class__.__name__)

-    def set_session_data(self, sessions, session_id):
-        """ Set the correct list index based on the passed in session is """
-        if self.is_totals:
-            data = SessionsTotals(sessions).stats
-        else:
-            data = sessions[int(session_id) - 1]
-        return data
+    @property
+    def is_totals(self):
+        """ Return True if these are totals else False """
+        return bool(self.session_id == "Total")
+
+    def initialize_session(self):
+        """ Initialize the session """
+        logger.debug("Initializing session")
+        kwargs = dict(is_training=False)
+        if not self.is_totals:
+            kwargs["session_id"] = int(self.session_id)
+        logger.debug("Session kwargs: %s", kwargs)
+        self.session.initialize_session(**kwargs)

     def build(self):
         """ Build the popup window """
+        logger.debug("Building popup")
         optsframe, graphframe = self.layout_frames()

         self.opts_build(optsframe)
         self.compile_display_data()
         self.graph_build(graphframe)
+        logger.debug("Built popup")

     def layout_frames(self):
         """ Top level container frames """
+        logger.debug("Layout frames")
         leftframe = ttk.Frame(self)
         leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5)

@@ -323,20 +384,25 @@ class SessionPopUp(tk.Toplevel):

         rightframe = ttk.Frame(self)
         rightframe.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True)
+        logger.debug("Laid out frames")

         return leftframe, rightframe

     def opts_build(self, frame):
-        """ Options in options to the optsframe """
+        """ Build Options into the options frame """
+        logger.debug("Building Options")
         self.opts_combobox(frame)
         self.opts_checkbuttons(frame)
+        self.opts_loss_keys(frame)
         self.opts_entry(frame)
         self.opts_buttons(frame)
         sep = ttk.Frame(frame, height=2, relief=tk.RIDGE)
         sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
+        logger.debug("Built Options")

     def opts_combobox(self, frame):
         """ Add the options combo boxes """
+        logger.debug("Building Combo boxes")
         choices = {"Display": ("Loss", "Rate"),
                    "Scale": ("Linear", "Log")}

@@ -362,9 +428,11 @@ class SessionPopUp(tk.Toplevel):

             hlp = self.set_help(item)
             Tooltip(cmbframe, text=hlp, wraplength=200)
+        logger.debug("Built Combo boxes")

     def opts_checkbuttons(self, frame):
         """ Add the options check buttons """
+        logger.debug("Building Check Buttons")
         for item in ("raw", "trend", "avg", "outliers"):
             if item == "avg":
                 text = "Show Rolling Average"
@@ -384,9 +452,35 @@ class SessionPopUp(tk.Toplevel):

             hlp = self.set_help(item)
             Tooltip(ctl, text=hlp, wraplength=200)
+        logger.debug("Built Check Buttons")
+
+    def opts_loss_keys(self, frame):
+        """ Add loss key selections """
+        logger.debug("Building Loss Key Check Buttons")
+        loss_keys = self.session.loss_keys
+        lk_vars = dict()
+        for loss_key in sorted(loss_keys):
+            text = loss_key.replace("_", " ").title()
+            helptext = "Display {}".format(text)
+            var = tk.BooleanVar()
+            var.set(True)
+            var.trace("w", self.optbtn_reset)
+            lk_vars[loss_key] = var
+
+            if len(loss_keys) == 1:
+                # Don't display if there's only one item
+                break
+
+            ctl = ttk.Checkbutton(frame, variable=var, text=text)
+            ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W)
+            Tooltip(ctl, text=helptext, wraplength=200)
+
+        self.vars["loss_keys"] = lk_vars
+        logger.debug("Built Loss Key Check Buttons")

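Editor's note: each loss-key checkbutton added above is backed by a tk.BooleanVar whose write-trace fires optbtn_reset, so toggling a key refreshes the graph with no explicit callback plumbing. A minimal runnable sketch of that wiring (widget and variable names are illustrative):

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()

    var = tk.BooleanVar(value=True)

    def on_toggle(*args):  # trace callbacks receive (name, index, mode)
        print("selection changed:", var.get())

    var.trace("w", on_toggle)  # fire on every write to the variable

    ttk.Checkbutton(root, variable=var, text="Loss A").pack()
    root.mainloop()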
     def opts_entry(self, frame):
         """ Add the options entry boxes """
+        logger.debug("Building Entry Boxes")
         for item in ("avgiterations", ):
             if item == "avgiterations":
                 text = "Iterations to Average:"
@@ -405,27 +499,32 @@ class SessionPopUp(tk.Toplevel):
             Tooltip(entframe, text=hlp, wraplength=200)

             self.vars[item] = ctl
+        logger.debug("Built Entry Boxes")

     def opts_buttons(self, frame):
         """ Add the option buttons """
+        logger.debug("Building Buttons")
         btnframe = ttk.Frame(frame)
         btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM)

         for btntype in ("reset", "save"):
             cmd = getattr(self, "optbtn_{}".format(btntype))
             btn = ttk.Button(btnframe,
-                             image=Images().icons[btntype],
+                             image=get_images().icons[btntype],
                              command=cmd)
             btn.pack(padx=2, side=tk.RIGHT)
             hlp = self.set_help(btntype)
             Tooltip(btn, text=hlp, wraplength=200)
+        logger.debug("Built Buttons")

     def optbtn_save(self):
         """ Action for save button press """
+        logger.debug("Saving File")
         savefile = FileHandler("save", "csv").retfile
         if not savefile:
+            logger.debug("Save Cancelled")
             return
+        logger.debug("Saving to: %s", savefile)
         save_data = self.display_data.stats
         fieldnames = sorted(key for key in save_data.keys())

@@ -434,16 +533,21 @@ class SessionPopUp(tk.Toplevel):
         csvout.writerow(fieldnames)
         csvout.writerows(zip(*[save_data[key] for key in fieldnames]))

-    def optbtn_reset(self, *args):
+    def optbtn_reset(self, *args):  # pylint: disable=unused-argument
         """ Action for reset button press and checkbox changes """
+        logger.debug("Refreshing Graph")
         if not self.graph_initialised:
             return
-        self.compile_display_data()
+        valid = self.compile_display_data()
+        if not valid:
+            logger.debug("Invalid data")
+            return
         self.graph.refresh(self.display_data,
                            self.vars["display"].get(),
                            self.vars["scale"].get())
+        logger.debug("Refreshed Graph")

-    def graph_scale(self, *args):
+    def graph_scale(self, *args):  # pylint: disable=unused-argument
         """ Action for changing graph scale """
         if not self.graph_initialised:
             return
@@ -477,25 +581,53 @@ class SessionPopUp(tk.Toplevel):

     def compile_display_data(self):
         """ Compile the data to be displayed """
-        self.display_data = Calculations(self.data,
-                                         self.vars["display"].get(),
-                                         self.selections_to_list(),
-                                         self.vars["avgiterations"].get(),
-                                         self.vars["outliers"].get(),
-                                         self.is_totals)
+        logger.debug("Compiling Display Data")
+        loss_keys = [key for key, val in self.vars["loss_keys"].items()
+                     if val.get()]
+        logger.debug("Selected loss_keys: %s", loss_keys)
+        selections = self.selections_to_list()
+
+        if not self.check_valid_selection(loss_keys, selections):
+            return False
+        self.display_data = Calculations(session=self.session,
+                                         display=self.vars["display"].get(),
+                                         loss_keys=loss_keys,
+                                         selections=selections,
+                                         avg_samples=self.vars["avgiterations"].get(),
+                                         flatten_outliers=self.vars["outliers"].get(),
+                                         is_totals=self.is_totals)
+        logger.debug("Compiled Display Data")
+        return True
+
+    def check_valid_selection(self, loss_keys, selections):
+        """ Check that there will be data to display """
+        display = self.vars["display"].get().lower()
+        logger.debug("Validating selection. (loss_keys: %s, selections: %s, display: %s)",
+                     loss_keys, selections, display)
+        if not selections or (display == "loss" and not loss_keys):
+            msg = "No data to display. Not refreshing"
+            logger.debug(msg)
+            print(msg)
+            return False
+        return True

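Editor's note: the guard above means a refresh is skipped outright when nothing would be plotted. The predicate is simple enough to lift out and test in isolation (a sketch mirroring the diff's logic, not code from the repo):

    def check_valid_selection(loss_keys, selections, display):
        """ Nothing to plot means no refresh. """
        return bool(selections) and not (display == "loss" and not loss_keys)

    assert check_valid_selection(["loss_a"], ["raw"], "loss")
    assert not check_valid_selection([], ["raw"], "loss")     # loss view needs keys
    assert not check_valid_selection(["loss_a"], [], "loss")  # nothing selected
    assert check_valid_selection([], ["raw"], "rate")         # rate ignores loss keys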
     def selections_to_list(self):
         """ Compile checkbox selections to list """
+        logger.debug("Compiling selections to list")
         selections = list()
         for key, val in self.vars.items():
             if (isinstance(val, tk.BooleanVar)
                     and key != "outliers"
                     and val.get()):
                 selections.append(key)
+        logger.debug("Compiling selections to list: %s", selections)
         return selections

     def graph_build(self, frame):
         """ Build the graph in the top right paned window """
+        logger.debug("Building Graph")
         self.graph = SessionGraph(frame,
                                   self.display_data,
                                   self.vars["display"].get(),
@@ -503,3 +635,4 @@ class SessionPopUp(tk.Toplevel):
         self.graph.pack(expand=True, fill=tk.BOTH)
         self.graph.build()
         self.graph_initialised = True
+        logger.debug("Built Graph")
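Editor's note: the files that follow all swap the Images() singleton constructor for module-level get_images()/get_config() accessors. The accessor pattern is roughly the following (names match the diff; the internals are an assumption about the utils module, not a quote from it):

    _CONFIG = None  # populated once at GUI start-up
    _IMAGES = None

    def initialize_config(config):
        """ Store the master config object for later retrieval """
        global _CONFIG  # pylint: disable=global-statement
        _CONFIG = config

    def get_config():
        """ Return the master config shared by every GUI module """
        return _CONFIG

    def get_images():
        """ Return the shared image cache """
        return _IMAGES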
@@ -1,6 +1,7 @@
 #!/usr/bin python3
 """ Command specific tabs of Display Frame of the Faceswap GUI """
 import datetime
+import logging
 import os
 import tkinter as tk

@@ -11,19 +12,23 @@ from .display_graph import TrainingGraph
 from .display_page import DisplayOptionalPage
 from .tooltip import Tooltip
 from .stats import Calculations
-from .utils import Images, FileHandler
+from .utils import FileHandler, get_config, get_images
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


-class PreviewExtract(DisplayOptionalPage):
+class PreviewExtract(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
     """ Tab to display output preview images for extract and convert """

     def display_item_set(self):
         """ Load the latest preview if available """
-        Images().load_latest_preview()
-        self.display_item = Images().previewoutput
+        logger.trace("Loading latest preview")
+        get_images().load_latest_preview()
+        self.display_item = get_images().previewoutput

     def display_item_process(self):
         """ Display the preview """
+        logger.trace("Displaying preview")
         if not self.subnotebook.children:
             self.add_child()
         else:
@@ -31,15 +36,17 @@ class PreviewExtract(DisplayOptionalPage):

     def add_child(self):
         """ Add the preview label child """
+        logger.debug("Adding child")
         preview = self.subnotebook_add_page(self.tabname, widget=None)
-        lblpreview = ttk.Label(preview, image=Images().previewoutput[1])
+        lblpreview = ttk.Label(preview, image=get_images().previewoutput[1])
         lblpreview.pack(side=tk.TOP, anchor=tk.NW)
         Tooltip(lblpreview, text=self.helptext, wraplength=200)

     def update_child(self):
         """ Update the preview image on the label """
+        logger.trace("Updating preview")
         for widget in self.subnotebook_get_widgets():
-            widget.configure(image=Images().previewoutput[1])
+            widget.configure(image=get_images().previewoutput[1])

     def save_items(self):
         """ Open save dialogue and save preview """
@@ -52,41 +59,56 @@ class PreviewExtract(DisplayOptionalPage):
                                 "{}_{}.{}".format(filename,
                                                   now,
                                                   "png"))
-        Images().previewoutput[0].save(filename)
+        get_images().previewoutput[0].save(filename)
+        logger.debug("Saved preview to %s", filename)
         print("Saved preview to {}".format(filename))


 class PreviewTrain(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
     """ Training preview image(s) """
+    def __init__(self, *args, **kwargs):
+        self.update_preview = get_config().tk_vars["updatepreview"]
+        super().__init__(*args, **kwargs)
+
     def display_item_set(self):
         """ Load the latest preview if available """
-        Images().load_training_preview()
-        self.display_item = Images().previewtrain
+        logger.trace("Loading latest preview")
+        if not self.update_preview.get():
+            logger.trace("Preview not updated")
+            return
+        get_images().load_training_preview()
+        self.display_item = get_images().previewtrain

     def display_item_process(self):
         """ Display the preview(s) resized as appropriate """
-        sortednames = sorted([name for name in Images().previewtrain.keys()])
+        logger.trace("Displaying preview")
+        sortednames = sorted(list(get_images().previewtrain.keys()))
         existing = self.subnotebook_get_titles_ids()
+        should_update = self.update_preview.get()

         for name in sortednames:
             if name not in existing.keys():
                 self.add_child(name)
-            else:
+            elif should_update:
                 tab_id = existing[name]
                 self.update_child(tab_id, name)

+        if should_update:
+            self.update_preview.set(False)

     def add_child(self, name):
         """ Add the preview canvas child """
+        logger.debug("Adding child")
         preview = PreviewTrainCanvas(self.subnotebook, name)
         preview = self.subnotebook_add_page(name, widget=preview)
         Tooltip(preview, text=self.helptext, wraplength=200)
-        self.vars["modified"].set(Images().previewtrain[name][2])
+        self.vars["modified"].set(get_images().previewtrain[name][2])

     def update_child(self, tab_id, name):
         """ Update the preview canvas """
-        if self.vars["modified"].get() != Images().previewtrain[name][2]:
-            self.vars["modified"].set(Images().previewtrain[name][2])
+        logger.debug("Updating preview")
+        if self.vars["modified"].get() != get_images().previewtrain[name][2]:
+            self.vars["modified"].set(get_images().previewtrain[name][2])
             widget = self.subnotebook_page_from_id(tab_id)
             widget.reload()

@@ -102,11 +124,12 @@ class PreviewTrain(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
 class PreviewTrainCanvas(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ Canvas to hold a training preview image """
     def __init__(self, parent, previewname):
+        logger.debug("Initializing %s: (previewname: '%s')", self.__class__.__name__, previewname)
         ttk.Frame.__init__(self, parent)

         self.name = previewname
-        Images().resize_image(self.name, None)
-        self.previewimage = Images().previewtrain[self.name][1]
+        get_images().resize_image(self.name, None)
+        self.previewimage = get_images().previewtrain[self.name][1]

         self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
         self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
@@ -115,18 +138,21 @@ class PreviewTrainCanvas(ttk.Frame):  # pylint: disable=too-many-ancestors
                                                    image=self.previewimage,
                                                    anchor=tk.NW)
         self.bind("<Configure>", self.resize)
+        logger.debug("Initialized %s:", self.__class__.__name__)

     def resize(self, event):
         """ Resize the image to fit the frame, maintaining aspect ratio """
+        logger.trace("Resizing preview image")
         framesize = (event.width, event.height)
         # Sometimes image is resized before frame is drawn
         framesize = None if framesize == (1, 1) else framesize
-        Images().resize_image(self.name, framesize)
+        get_images().resize_image(self.name, framesize)
         self.reload()

     def reload(self):
         """ Reload the preview image """
-        self.previewimage = Images().previewtrain[self.name][1]
+        logger.trace("Reloading preview image")
+        self.previewimage = get_images().previewtrain[self.name][1]
         self.canvas.itemconfig(self.imgcanvas, image=self.previewimage)

     def save_preview(self, location):
@@ -137,40 +163,63 @@ class PreviewTrainCanvas(ttk.Frame):  # pylint: disable=too-many-ancestors
                                 "{}_{}.{}".format(filename,
                                                   now,
                                                   "png"))
-        Images().previewtrain[self.name][0].save(filename)
+        get_images().previewtrain[self.name][0].save(filename)
+        logger.debug("Saved preview to %s", filename)
         print("Saved preview to {}".format(filename))


 class GraphDisplay(DisplayOptionalPage):  # pylint: disable=too-many-ancestors
     """ The Graph Tab of the Display section """

+    def add_options(self):
+        """ Add the additional options """
+        self.add_option_refresh()
+        super().add_options()
+
+    def add_option_refresh(self):
+        """ Add refresh button to refresh graph immediately """
+        logger.debug("Adding refresh option")
+        tk_var = get_config().tk_vars["refreshgraph"]
+        btnrefresh = ttk.Button(self.optsframe,
+                                image=get_images().icons["reset"],
+                                command=lambda: tk_var.set(True))
+        btnrefresh.pack(padx=2, side=tk.RIGHT)
+        Tooltip(btnrefresh,
+                text="Graph updates every 100 iterations. Click to refresh now.",
+                wraplength=200)

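Editor's note: the refresh button above does no drawing itself; it only flips the shared refreshgraph variable, and whichever graph registered a trace on that variable redraws and clears the flag. A runnable sketch of that producer/consumer wiring (names are illustrative):

    import tkinter as tk
    from tkinter import ttk

    root = tk.Tk()
    refreshgraph = tk.BooleanVar(value=False)  # stand-in for tk_vars["refreshgraph"]

    def redraw(*args):  # consumer side, registered via trace
        if refreshgraph.get():
            print("redrawing graph...")
            refreshgraph.set(False)  # reset once handled

    refreshgraph.trace("w", redraw)
    ttk.Button(root, text="Refresh", command=lambda: refreshgraph.set(True)).pack()
    root.mainloop()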
     def display_item_set(self):
         """ Load the graph(s) if available """
-        if self.session.stats["iterations"] == 0:
+        session = get_config().session
+        if session.initialized and session.logging_disabled:
+            logger.trace("Logs disabled. Hiding graph")
+            self.set_info("Graph is disabled as 'no-logs' has been selected")
             self.display_item = None
+        elif session.initialized:
+            logger.trace("Loading graph")
+            self.display_item = session
         else:
-            self.display_item = self.session.stats
+            self.display_item = None

     def display_item_process(self):
         """ Add a single graph to the graph window """
-        losskeys = self.display_item["losskeys"]
-        loss = self.display_item["loss"]
-        tabcount = int(len(losskeys) / 2)
-        existing = self.subnotebook_get_titles_ids()
-        for i in range(tabcount):
-            selectedkeys = losskeys[i * 2:(i + 1) * 2]
-            name = " - ".join(selectedkeys).title().replace("_", " ")
-            if name not in existing.keys():
-                selectedloss = loss[i * 2:(i + 1) * 2]
-                selection = {"loss": selectedloss,
-                             "losskeys": selectedkeys}
-                data = Calculations(session=selection,
-                                    display="loss",
-                                    selections=["raw", "trend"])
-                self.add_child(name, data)
+        logger.trace("Adding graph")
+        existing = list(self.subnotebook_get_titles_ids().keys())
+        for loss_key in self.display_item.loss_keys:
+            tabname = loss_key.replace("_", " ").title()
+            if tabname in existing:
+                continue
+            data = Calculations(session=get_config().session,
+                                display="loss",
+                                loss_keys=[loss_key],
+                                selections=["raw", "trend"])
+            self.add_child(tabname, data)

     def add_child(self, name, data):
         """ Add the graph for the selected keys """
+        logger.debug("Adding child: %s", name)
         graph = TrainingGraph(self.subnotebook, data, "Loss")
         graph.build()
         graph = self.subnotebook_add_page(name, widget=graph)

@@ -1,6 +1,7 @@
 #!/usr/bin python3
 """ Graph functions for Display Frame of the Faceswap GUI """
 import datetime
+import logging
 import os
 import tkinter as tk

@@ -8,14 +9,17 @@ from tkinter import ttk
 from math import ceil, floor

 import matplotlib
+# pylint: disable=wrong-import-position
 matplotlib.use("TkAgg")
-import matplotlib.animation as animation
-from matplotlib import pyplot as plt, style
-from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
-                                               NavigationToolbar2Tk)
-
-from .tooltip import Tooltip
-from .utils import Images
+from matplotlib import pyplot as plt, style  # noqa
+from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
+                                               NavigationToolbar2Tk)  # noqa
+
+from .tooltip import Tooltip  # noqa
+from .utils import get_config, get_images  # noqa
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

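Editor's note: the import shuffle above exists because a matplotlib backend must be selected before pyplot is first imported; that forces imports below matplotlib.use() and hence the wrong-import-position and noqa suppressions. The minimal pattern:

    import matplotlib

    # The backend must be chosen before pyplot is imported, so this call
    # has to sit above the pyplot import, upsetting the linters.
    matplotlib.use("TkAgg")

    from matplotlib import pyplot as plt  # noqa: E402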
 class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ancestors
@@ -26,28 +30,24 @@ class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ances
                  t[0] in ("Home", "Pan", "Zoom", "Save")]

     @staticmethod
-    def _Button(frame, text, file, command, extension=".gif"):
+    def _Button(frame, text, file, command, extension=".gif"):  # pylint: disable=arguments-differ
         """ Map Buttons to their own frame.
-            Use custom button icons,
-            Use ttk buttons
-            pack to the right """
+            Use custom button icons, Use ttk buttons pack to the right """
         iconmapping = {"home": "reset",
                        "filesave": "save",
                        "zoom_to_rect": "zoom"}
         icon = iconmapping[file] if iconmapping.get(file, None) else file
-        img = Images().icons[icon]
+        img = get_images().icons[icon]
         btn = ttk.Button(frame, text=text, image=img, command=command)
         btn.pack(side=tk.RIGHT, padx=2)
         return btn

     def _init_toolbar(self):
-        """ Same as original but ttk widgets and standard
-            tooltips used. Separator added and message label
-            packed to the left """
+        """ Same as original but ttk widgets and standard tooltips used. Separator added and
+            message label packed to the left """
         xmin, xmax = self.canvas.figure.bbox.intervalx
         height, width = 50, xmax-xmin
-        ttk.Frame.__init__(self, master=self.window,
-                           width=int(width), height=int(height))
+        ttk.Frame.__init__(self, master=self.window, width=int(width), height=int(height))

         sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
         sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP)
@@ -76,14 +76,14 @@ class NavigationToolbar(NavigationToolbar2Tk):  # pylint: disable=too-many-ances
 class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ Base class for matplotlib line graphs """
     def __init__(self, parent, data, ylabel):
-        ttk.Frame.__init__(self, parent)
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__(parent)
         style.use("ggplot")

         self.calcs = data
         self.ylabel = ylabel
-        self.colourmaps = ["Reds", "Blues", "Greens",
-                           "Purples", "Oranges", "Greys",
-                           "copper", "summer", "bone"]
+        self.colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges",
+                           "Greys", "copper", "summer", "bone"]
         self.lines = list()
         self.toolbar = None
         self.fig = plt.figure(figsize=(4, 4), dpi=75)
@@ -92,26 +92,29 @@ class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors

         self.initiate_graph()
         self.update_plot(initiate=True)
+        logger.debug("Initialized %s", self.__class__.__name__)

     def initiate_graph(self):
         """ Place the graph canvas """
-        self.plotcanvas.get_tk_widget().pack(side=tk.TOP,
-                                             padx=5,
-                                             fill=tk.BOTH,
-                                             expand=True)
+        logger.debug("Setting plotcanvas")
+        self.plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True)
         plt.subplots_adjust(left=0.100,
                             bottom=0.100,
                             right=0.95,
                             top=0.95,
                             wspace=0.2,
                             hspace=0.2)
+        logger.debug("Set plotcanvas")

     def update_plot(self, initiate=True):
         """ Update the plot with incoming data """
+        logger.trace("Updating plot")
         if initiate:
+            logger.debug("Initializing plot")
             self.lines = list()
             self.ax1.clear()
             self.axes_labels_set()
+            logger.debug("Initialized plot")

         fulldata = [item for item in self.calcs.stats.values()]
         self.axes_limits_set(fulldata)
@@ -120,37 +123,37 @@ class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
         keys = list(self.calcs.stats.keys())
         for idx, item in enumerate(self.lines_sort(keys)):
             if initiate:
-                self.lines.extend(self.ax1.plot(xrng,
-                                                self.calcs.stats[item[0]],
-                                                label=item[1],
-                                                linewidth=item[2],
-                                                color=item[3]))
+                self.lines.extend(self.ax1.plot(xrng, self.calcs.stats[item[0]],
+                                                label=item[1], linewidth=item[2], color=item[3]))
             else:
                 self.lines[idx].set_data(xrng, self.calcs.stats[item[0]])

         if initiate:
             self.legend_place()
+        logger.trace("Updated plot")

     def axes_labels_set(self):
         """ Set the axes label and range """
+        logger.debug("Setting axes labels. y-label: '%s'", self.ylabel)
         self.ax1.set_xlabel("Iterations")
         self.ax1.set_ylabel(self.ylabel)

     def axes_limits_set_default(self):
         """ Set default axes limits """
+        logger.debug("Setting default axes ranges")
         self.ax1.set_ylim(0.00, 100.0)
         self.ax1.set_xlim(0, 1)

     def axes_limits_set(self, data):
         """ Set the axes limits """
         xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1

         if data:
             ymin, ymax = self.axes_data_get_min_max(data)
             self.ax1.set_ylim(ymin, ymax)
             self.ax1.set_xlim(0, xmax)
         else:
             self.axes_limits_set_default()
+        logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax)

     @staticmethod
     def axes_data_get_min_max(data):
@@ -164,15 +167,18 @@ class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
             ymax.append(max(dataset) * 1000)
         ymin = floor(min(ymin)) / 1000
         ymax = ceil(max(ymax)) / 1000
+        logger.trace("ymin: %s, ymax: %s", ymin, ymax)
         return ymin, ymax

     def axes_set_yscale(self, scale):
         """ Set the Y-Scale to log or linear """
+        logger.debug("yscale: '%s'", scale)
         self.ax1.set_yscale(scale)

     def lines_sort(self, keys):
         """ Sort the data keys into consistent order
-            and set line colourmap and line width """
+            and set line color map and line width """
+        logger.trace("Sorting lines")
         raw_lines = list()
         sorted_lines = list()
         for key in sorted(keys):
@@ -184,29 +190,28 @@ class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors

         groupsize = self.lines_groupsize(raw_lines, sorted_lines)
         sorted_lines = raw_lines + sorted_lines

         lines = self.lines_style(sorted_lines, groupsize)
         return lines

     @staticmethod
     def lines_groupsize(raw_lines, sorted_lines):
         """ Get the number of items in each group.
-            If raw data isn't selected, then check
-            the length of remaining groups until
-            something is found """
+            If raw data isn't selected, then check the length of
+            remaining groups until something is found """
         groupsize = 1
         if raw_lines:
             groupsize = len(raw_lines)
         else:
             for check in ("avg", "trend"):
                 if any(item[0].startswith(check) for item in sorted_lines):
-                    groupsize = len([item for item in sorted_lines
-                                     if item[0].startswith(check)])
+                    groupsize = len([item for item in sorted_lines if item[0].startswith(check)])
                     break
+        logger.trace(groupsize)
         return groupsize

     def lines_style(self, lines, groupsize):
-        """ Set the colourmap and linewidth for each group """
+        """ Set the color map and line width for each group """
+        logger.trace("Setting lines style")
         groups = int(len(lines) / groupsize)
         colours = self.lines_create_colors(groupsize, groups)
         for idx, item in enumerate(lines):
@@ -215,21 +220,24 @@ class GraphBase(ttk.Frame):  # pylint: disable=too-many-ancestors
         return lines

     def lines_create_colors(self, groupsize, groups):
-        """ Create the colours """
+        """ Create the colors """
         colours = list()
         for i in range(1, groups + 1):
             for colour in self.colourmaps[0:groupsize]:
                 cmap = matplotlib.cm.get_cmap(colour)
                 cpoint = 1 - (i / 5)
                 colours.append(cmap(cpoint))
+        logger.trace(colours)
         return colours

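Editor's note: lines_create_colors samples each named colormap at a point that steps lighter as the group index grows, so related lines (e.g. raw and trend for the same loss) share a hue family. A runnable sketch of the sampling:

    from matplotlib import cm

    colourmaps = ["Reds", "Blues", "Greens"]  # one hue family per line group
    colours = []
    for i in range(1, 3):  # two groups, e.g. raw + trend
        for colour in colourmaps:
            cmap = cm.get_cmap(colour)
            colours.append(cmap(1 - (i / 5)))  # 0.8 for group 1, 0.6 for group 2
    print(colours)  # RGBA tuples, darker shades first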
     def legend_place(self):
         """ Place and format legend """
+        logger.debug("Placing legend")
         self.ax1.legend(loc="upper right", ncol=2)

     def toolbar_place(self, parent):
         """ Add Graph Navigation toolbar """
+        logger.debug("Placing toolbar")
         self.toolbar = NavigationToolbar(self.plotcanvas, parent)
         self.toolbar.pack(side=tk.BOTTOM)
         self.toolbar.update()
@@ -240,72 +248,48 @@ class TrainingGraph(GraphBase):  # pylint: disable=too-many-ancestors

     def __init__(self, parent, data, ylabel):
         GraphBase.__init__(self, parent, data, ylabel)
+        self.add_callback()

-        self.anim = None
+    def add_callback(self):
+        """ Add the variable trace to update graph on refresh button press or save iteration """
+        get_config().tk_vars["refreshgraph"].trace("w", self.refresh)

     def build(self):
-        """ Update the plot area with loss values and cycle through to
-            animate """
-        self.anim = animation.FuncAnimation(self.fig,
-                                            self.animate,
-                                            interval=200,
-                                            blit=False)
+        """ Update the plot area with loss values """
+        logger.debug("Building training graph")
         self.plotcanvas.draw()
+        logger.debug("Built training graph")

-    def animate(self, i):
+    def refresh(self, *args):  # pylint: disable=unused-argument
         """ Read loss data and apply to graph """
+        logger.debug("Updating plot")
         self.calcs.refresh()
         self.update_plot(initiate=False)
-
-    def set_animation_rate(self, iterations):
-        """ Change the animation update interval based on how
-            many iterations have been
-            There's no point calculating a graph over thousands of
-            points of data when the change will be miniscule """
-        if iterations > 30000:
-            speed = 60000  # 1 min updates
-        elif iterations > 20000:
-            speed = 30000  # 30 sec updates
-        elif iterations > 10000:
-            speed = 10000  # 10 sec updates
-        elif iterations > 5000:
-            speed = 5000  # 5 sec updates
-        elif iterations > 1000:
-            speed = 2000  # 2 sec updates
-        elif iterations > 500:
-            speed = 1000  # 1 sec updates
-        elif iterations > 100:
-            speed = 500  # 0.5 sec updates
-        else:
-            speed = 200  # 200ms updates
-        if not self.anim.event_source.interval == speed:
-            self.anim.event_source.interval = speed
+        self.plotcanvas.draw()
+        get_config().tk_vars["refreshgraph"].set(False)

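Editor's note: the deleted block above was a polling design: FuncAnimation woke on a timer (with set_animation_rate slowing the timer as iterations grew) whether or not new data existed. For contrast, a minimal sketch of that removed approach (not the deleted code verbatim):

    import matplotlib.animation as animation
    from matplotlib import pyplot as plt

    fig, ax = plt.subplots()

    def animate(i):  # called every `interval` ms, data or no data
        ax.clear()
        ax.plot(range(i + 1))

    # Timer-driven redraw: wakes five times a second even when idle,
    # which is exactly what the trace-driven refresh() replaces.
    anim = animation.FuncAnimation(fig, animate, interval=200, blit=False)
    plt.show()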
     def save_fig(self, location):
         """ Save the figure to file """
-        keys = sorted([key.replace("raw_", "")
-                       for key in self.calcs.stats.keys()
+        logger.debug("Saving graph: '%s'", location)
+        keys = sorted([key.replace("raw_", "") for key in self.calcs.stats.keys()
                        if key.startswith("raw_")])
         filename = " - ".join(keys)
         now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        filename = os.path.join(location,
-                                "{}_{}.{}".format(filename,
-                                                  now,
-                                                  "png"))
+        filename = os.path.join(location, "{}_{}.{}".format(filename, now, "png"))
         self.fig.set_size_inches(16, 9)
         self.fig.savefig(filename, bbox_inches="tight", dpi=120)
         print("Saved graph to {}".format(filename))
+        logger.debug("Saved graph: '%s'", filename)
         self.resize_fig()

     def resize_fig(self):
         """ Resize the figure back to the canvas """
-        class Event():
-            """ Event class that needs to be passed to
-                plotcanvas.resize """
+        class Event():  # pylint: disable=too-few-public-methods
+            """ Event class that needs to be passed to plotcanvas.resize """
             pass
         Event.width = self.winfo_width()
         Event.height = self.winfo_height()
-        self.plotcanvas.resize(Event)
+        self.plotcanvas.resize(Event)  # pylint: disable=no-value-for-parameter


 class SessionGraph(GraphBase):  # pylint: disable=too-many-ancestors
@@ -316,18 +300,24 @@ class SessionGraph(GraphBase):  # pylint: disable=too-many-ancestors

     def build(self):
         """ Build the session graph """
+        logger.debug("Building session graph")
         self.toolbar_place(self)
         self.plotcanvas.draw()
+        logger.debug("Built session graph")

     def refresh(self, data, ylabel, scale):
         """ Refresh graph data """
+        logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale)
         self.calcs = data
         self.ylabel = ylabel
         self.set_yscale_type(scale)
+        logger.debug("Refreshed session graph")

     def set_yscale_type(self, scale):
         """ switch the y-scale and redraw """
+        logger.debug("Updating scale type: '%s'", scale)
         self.scale = scale
         self.update_plot(initiate=True)
         self.axes_set_yscale(self.scale)
         self.plotcanvas.draw()
+        logger.debug("Updated scale type")
|
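For reference, the update-interval ladder that the removed block implemented reduces to a single lookup. This is an illustrative sketch, not code from the commit; the topmost threshold's condition is cut off above this excerpt, so the 30000 boundary here is an assumption.

def update_interval(iterations):
    """ Milliseconds between graph refreshes: refresh less often as data grows.
        The 30000 boundary is assumed; its matching condition is truncated above. """
    ladder = [(30000, 60000), (20000, 30000), (10000, 10000),
              (5000, 5000), (1000, 2000), (500, 1000), (100, 500)]
    for threshold, interval in ladder:
        if iterations > threshold:
            return interval
    return 200  # 200ms updates for brand-new sessions

assert update_interval(750) == 1000  # 1 sec updates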
@@ -1,21 +1,25 @@
 #!/usr/bin python3
 """ Display Page parent classes for display section of the Faceswap GUI """

+import logging
 import tkinter as tk
 from tkinter import ttk

 from .tooltip import Tooltip
-from .utils import Images
+from .utils import get_images

+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 class DisplayPage(ttk.Frame):
     """ Parent frame holder for each tab.
         Defines uniform structure for each tab to inherit from """
     def __init__(self, parent, tabname, helptext):
+        logger.debug("Initializing %s: (tabname: '%s', helptext: %s",
+                     self.__class__.__name__, tabname, helptext)
         ttk.Frame.__init__(self, parent)
         self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW)

-        self.session = parent.session
         self.runningtask = parent.runningtask
         self.helptext = helptext
         self.tabname = tabname

@@ -30,32 +34,37 @@ class DisplayPage(ttk.Frame):
         self.add_frame_separator()
         self.set_mainframe_single_tab_style()
         parent.add(self, text=self.tabname.title())
+        logger.debug("Initialized %s", self.__class__.__name__)

     def add_optional_vars(self, varsdict):
         """ Add page specific variables """
         if isinstance(varsdict, dict):
             for key, val in varsdict.items():
+                logger.debug("Adding: (%s: %s)", key, val)
                 self.vars[key] = val

     @staticmethod
     def set_vars():
-        """ Overide to return a dict of page specific variables """
+        """ Override to return a dict of page specific variables """
         return dict()

     def add_subnotebook(self):
         """ Add the main frame notebook """
+        logger.debug("Adding subnotebook")
         notebook = ttk.Notebook(self)
         notebook.pack(side=tk.TOP, anchor=tk.NW, fill=tk.BOTH, expand=True)
         return notebook

     def add_options_frame(self):
         """ Add the display tab options """
+        logger.debug("Adding options frame")
         optsframe = ttk.Frame(self)
         optsframe.pack(side=tk.BOTTOM, padx=5, pady=5, fill=tk.X)
         return optsframe

     def add_options_info(self):
         """ Add the info bar """
+        logger.debug("Adding options info")
         lblinfo = ttk.Label(self.optsframe,
                             textvariable=self.vars["info"],
                             anchor=tk.W,

@@ -64,22 +73,26 @@ class DisplayPage(ttk.Frame):

     def set_info(self, msg):
         """ Set the info message """
+        logger.debug("Setting info: %s", msg)
         self.vars["info"].set(msg)

     def add_frame_separator(self):
         """ Add a separator between top and bottom frames """
+        logger.debug("Adding frame separator")
         sep = ttk.Frame(self, height=2, relief=tk.RIDGE)
         sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)

     @staticmethod
     def set_mainframe_single_tab_style():
         """ Configure ttk notebook style to represent a single frame """
+        logger.debug("Setting main frame single tab style")
         nbstyle = ttk.Style()
         nbstyle.configure("single.TNotebook", borderwidth=0)
         nbstyle.layout("single.TNotebook.Tab", [])

     def subnotebook_add_page(self, tabtitle, widget=None):
         """ Add a page to the sub notebook """
+        logger.debug("Adding subnotebook page: %s", tabtitle)
         frame = widget if widget else ttk.Frame(self.subnotebook)
         frame.pack(padx=5, pady=5, fill=tk.BOTH, expand=True)
         self.subnotebook.add(frame, text=tabtitle)

@@ -89,28 +102,32 @@ class DisplayPage(ttk.Frame):
     def subnotebook_configure(self):
         """ Configure notebook to display or hide tabs """
         if len(self.subnotebook.children) == 1:
+            logger.debug("Setting single page style")
             self.subnotebook.configure(style="single.TNotebook")
         else:
+            logger.debug("Setting multi page style")
             self.subnotebook.configure(style="TNotebook")

     def subnotebook_hide(self):
         """ Hide the subnotebook. Used for hiding
            Optional displays """
-        if self.subnotebook.winfo_ismapped():
+        if self.subnotebook and self.subnotebook.winfo_ismapped():
+            logger.debug("Hiding subnotebook")
             self.subnotebook.pack_forget()
+            self.subnotebook.destroy()
+            self.subnotebook = None

     def subnotebook_show(self):
         """ Show subnotebook. Used for displaying
            Optional displays """
-        if not self.subnotebook.winfo_ismapped():
-            self.subnotebook.pack(side=tk.TOP,
-                                  anchor=tk.NW,
-                                  fill=tk.BOTH,
-                                  expand=True)
+        if not self.subnotebook:
+            logger.debug("Showing subnotebook")
+            self.subnotebook = self.add_subnotebook()

     def subnotebook_get_widgets(self):
         """ Return each widget that sits within each
            subnotebook frame """
+        logger.debug("Getting subnotebook widgets")
         for child in self.subnotebook.winfo_children():
             for widget in child.winfo_children():
                 yield widget

@@ -120,11 +137,13 @@ class DisplayPage(ttk.Frame):
         tabs = dict()
         for tab_id in range(0, self.subnotebook.index("end")):
             tabs[self.subnotebook.tab(tab_id, "text")] = tab_id
+        logger.debug(tabs)
         return tabs

     def subnotebook_page_from_id(self, tab_id):
         """ Return subnotebook tab widget from its ID """
         tab_name = self.subnotebook.tabs()[tab_id].split(".")[-1]
+        logger.debug(tab_name)
         return self.subnotebook.children[tab_name]


@@ -155,19 +174,23 @@ class DisplayOptionalPage(DisplayPage):
         modified = tk.DoubleVar()
         modified.set(None)

-        return {"enabled": enabled,
-                "ready": ready,
-                "modified": modified}
+        tk_vars = {"enabled": enabled,
+                   "ready": ready,
+                   "modified": modified}
+        logger.debug(tk_vars)
+        return tk_vars

     # INFO LABEL
     def set_info_text(self):
         """ Set waiting for display text """
         if not self.vars["enabled"].get():
-            self.set_info("{} disabled".format(self.tabname.title()))
+            msg = "{} disabled".format(self.tabname.title())
         elif self.vars["enabled"].get() and not self.vars["ready"].get():
-            self.set_info("Waiting for {}...".format(self.tabname))
+            msg = "Waiting for {}...".format(self.tabname)
         else:
-            self.set_info("Displaying {}".format(self.tabname))
+            msg = "Displaying {}".format(self.tabname)
+        logger.debug(msg)
+        self.set_info(msg)

     # DISPLAY OPTIONS BAR
     def add_options(self):

@@ -177,8 +200,9 @@ class DisplayOptionalPage(DisplayPage):

     def add_option_save(self):
         """ Add save button to save page output to file """
+        logger.debug("Adding save option")
         btnsave = ttk.Button(self.optsframe,
-                             image=Images().icons["save"],
+                             image=get_images().icons["save"],
                              command=self.save_items)
         btnsave.pack(padx=2, side=tk.RIGHT)
         Tooltip(btnsave,

@@ -187,6 +211,7 @@ class DisplayOptionalPage(DisplayPage):

     def add_option_enable(self):
         """ Add checkbutton to enable/disable page """
+        logger.debug("Adding enable option")
         chkenable = ttk.Checkbutton(self.optsframe,
                                     variable=self.vars["enabled"],
                                     text="Enable {}".format(self.tabname),

@@ -202,6 +227,7 @@ class DisplayOptionalPage(DisplayPage):

     def on_chkenable_change(self):
         """ Update the display immediately on a checkbutton change """
+        logger.debug("Enabled checkbox changed")
         if self.vars["enabled"].get():
             self.subnotebook_show()
         else:

@@ -213,6 +239,7 @@ class DisplayOptionalPage(DisplayPage):
         if not self.runningtask.get():
             return
         if self.vars["enabled"].get():
+            logger.trace("Updating page")
             self.display_item_set()
             self.load_display()
         self.after(waittime, lambda t=waittime: self.update_page(t))

@@ -225,6 +252,7 @@ class DisplayOptionalPage(DisplayPage):
         """ Load the display """
         if not self.display_item:
             return
+        logger.debug("Loading display")
         self.display_item_process()
         self.vars["ready"].set(True)
         self.set_info_text()
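The three info-bar states above reduce to a small pure function, shown here as a sketch for clarity (the real method reads tkinter variables rather than plain booleans):

def info_message(tabname, enabled, ready):
    """ Mirror of set_info_text's three states. """
    if not enabled:
        return "{} disabled".format(tabname.title())
    if enabled and not ready:
        return "Waiting for {}...".format(tabname)
    return "Displaying {}".format(tabname)

assert info_message("preview", False, False) == "Preview disabled"
assert info_message("preview", True, False) == "Waiting for preview..."
assert info_message("preview", True, True) == "Displaying preview"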
lib/gui/menu.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+#!/usr/bin python3
+""" The Menu Bars for faceswap GUI """
+
+import logging
+import os
+import sys
+import tkinter as tk
+
+from importlib import import_module
+
+from lib.Serializer import JSONSerializer
+
+from .utils import get_config
+from .popup_configure import popup_config
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+class MainMenuBar(tk.Menu):
+    """ GUI Main Menu Bar """
+    def __init__(self, master=None):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__(master)
+        self.root = master
+        self.config = get_config()
+
+        self.file_menu = tk.Menu(self, tearoff=0)
+        self.recent_menu = tk.Menu(self.file_menu, tearoff=0, postcommand=self.refresh_recent_menu)
+        self.edit_menu = tk.Menu(self, tearoff=0)
+
+        self.build_file_menu()
+        self.build_edit_menu()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def build_file_menu(self):
+        """ Add the file menu to the menu bar """
+        logger.debug("Building File menu")
+        self.file_menu.add_command(
+            label="Load full config...", underline=0, command=self.config.load)
+        self.file_menu.add_command(
+            label="Save full config...", underline=0, command=self.config.save)
+        self.file_menu.add_separator()
+        self.file_menu.add_cascade(label="Open recent", underline=6, menu=self.recent_menu)
+        self.file_menu.add_separator()
+        self.file_menu.add_command(
+            label="Reset all to default", underline=0, command=self.config.cli_opts.reset)
+        self.file_menu.add_command(
+            label="Clear all", underline=0, command=self.config.cli_opts.clear)
+        self.file_menu.add_separator()
+        self.file_menu.add_command(label="Quit", underline=0, command=self.root.close_app)
+        self.add_cascade(label="File", menu=self.file_menu, underline=0)
+        logger.debug("Built File menu")
+
+    def build_recent_menu(self):
+        """ Load recent files into menu bar """
+        logger.debug("Building Recent Files menu")
+        serializer = JSONSerializer
+        menu_file = os.path.join(self.config.pathcache, ".recent.json")
+        if not os.path.isfile(menu_file):
+            self.clear_recent_files(serializer, menu_file)
+        with open(menu_file, "rb") as inp:
+            recent_files = serializer.unmarshal(inp.read().decode("utf-8"))
+            logger.debug("Loaded recent files: %s", recent_files)
+        for recent_item in recent_files:
+            filename, command = recent_item
+            logger.debug("processing: ('%s', %s)", filename, command)
+            if not os.path.isfile(filename):
+                logger.debug("File does not exist")
+                continue
+            lbl_command = command if command else "All"
+            self.recent_menu.add_command(
+                label="{} ({})".format(filename, lbl_command.title()),
+                command=lambda fnm=filename, cmd=command: self.config.load(cmd, fnm))
+        self.recent_menu.add_separator()
+        self.recent_menu.add_command(
+            label="Clear recent files",
+            underline=0,
+            command=lambda srl=serializer, mnu=menu_file: self.clear_recent_files(srl, mnu))
+
+        logger.debug("Built Recent Files menu")
+
+    @staticmethod
+    def clear_recent_files(serializer, menu_file):
+        """ Creates or clears recent file list """
+        logger.debug("clearing recent files list: '%s'", menu_file)
+        recent_files = serializer.marshal(list())
+        with open(menu_file, "wb") as out:
+            out.write(recent_files.encode("utf-8"))
+
+    def refresh_recent_menu(self):
+        """ Refresh recent menu on save/load of files """
+        self.recent_menu.delete(0, "end")
+        self.build_recent_menu()
+
+    def build_edit_menu(self):
+        """ Add the edit menu to the menu bar """
+        logger.debug("Building Edit menu")
+        edit_menu = tk.Menu(self, tearoff=0)
+
+        configs = self.scan_for_configs()
+        for name in sorted(list(configs.keys())):
+            label = "Configure {} Plugins...".format(name.title())
+            config = configs[name]
+            edit_menu.add_command(
+                label=label,
+                underline=10,
+                command=lambda conf=(name, config), root=self.root: popup_config(conf, root))
+        self.add_cascade(label="Edit", menu=edit_menu, underline=0)
+        logger.debug("Built Edit menu")
+
+    def scan_for_configs(self):
+        """ Scan for config.ini file locations """
+        root_path = os.path.abspath(os.path.dirname(sys.argv[0]))
+        plugins_path = os.path.join(root_path, "plugins")
+        logger.debug("Scanning path: '%s'", plugins_path)
+        configs = dict()
+        for dirpath, _, filenames in os.walk(plugins_path):
+            if "_config.py" in filenames:
+                plugin_type = os.path.split(dirpath)[-1]
+                config = self.load_config(plugin_type)
+                configs[plugin_type] = config
+        logger.debug("Configs loaded: %s", sorted(list(configs.keys())))
+        return configs
+
+    @staticmethod
+    def load_config(plugin_type):
+        """ Load the config to generate config file if it doesn't exist and get filename """
+        # Load config to generate default if doesn't exist
+        mod = ".".join(("plugins", plugin_type, "_config"))
+        module = import_module(mod)
+        config = module.Config(None)
+        logger.debug("Found '%s' config at '%s'", plugin_type, config.configfile)
+        return config
@@ -1,14 +1,13 @@
 #!/usr/bin python3
-""" Cli Options and Config functions for the GUI """
+""" Cli Options for the GUI """
 import inspect
 from argparse import SUPPRESS
 import logging
 from tkinter import ttk

 from lib import cli
-from lib.Serializer import JSONSerializer
 import tools.cli as ToolsCli
-from .utils import FileHandler, Images
+from .utils import get_images

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -93,8 +92,7 @@ class CliOptions():
                 logger.trace("Skipping suppressed option: %s", opt)
                 continue
             ctl, sysbrowser, filetypes, action_option = self.set_control(opt)
-            opt["control_title"] = self.set_control_title(
-                opt.get("opts", ""))
+            opt["control_title"] = self.set_control_title(opt.get("opts", ""))
             opt["control"] = ctl
             opt["filesystem_browser"] = sysbrowser
             opt["filetypes"] = filetypes

@@ -126,6 +124,8 @@ class CliOptions():
             sysbrowser, filetypes = self.set_sysbrowser(action,
                                                         filetypes,
                                                         action_option)
+        elif option.get("min_max", None):
+            ctl = ttk.Scale
         elif option.get("choices", "") != "":
             ctl = ttk.Combobox
         elif option.get("action", "") == "store_true":

@@ -226,7 +226,7 @@ class CliOptions():
             optval = str(option.get("value", "").get())
             opt = option["opts"][0]
             if command in ("extract", "convert") and opt == "-o":
-                Images().pathoutput = optval
+                get_images().pathoutput = optval
             if optval in ("False", ""):
                 continue
             elif optval == "True":

@@ -238,59 +238,3 @@ class CliOptions():
             else:
                 opt = (opt, optval)
             yield opt
-
-
-class Config():
-    """ Actions for loading and saving Faceswap GUI command configurations """
-
-    def __init__(self, cli_opts, tk_vars):
-        logger.debug("Initializing %s", self.__class__.__name__)
-        self.cli_opts = cli_opts
-        self.serializer = JSONSerializer
-        self.tk_vars = tk_vars
-        logger.debug("Initialized %s", self.__class__.__name__)
-
-    def load(self, command=None):
-        """ Load a saved config file """
-        logger.debug("Loading config: (command: '%s')", command)
-        cfgfile = FileHandler("open", "config").retfile
-        if not cfgfile:
-            return
-        cfg = self.serializer.unmarshal(cfgfile.read())
-        opts = self.get_command_options(cfg, command) if command else cfg
-        for cmd, opts in opts.items():
-            self.set_command_args(cmd, opts)
-        logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile)
-
-    def get_command_options(self, cfg, command):
-        """ Return the saved options for the requested
-            command, if not loading global options """
-        opts = cfg.get(command, None)
-        if not opts:
-            self.tk_vars["consoleclear"].set(True)
-            print("No {} section found in file".format(command))
-            logger.info("No %s section found in file", command)
-        retval = {command: opts}
-        logger.debug(retval)
-        return retval
-
-    def set_command_args(self, command, options):
-        """ Pass the saved config items back to the CliOptions """
-        if not options:
-            return
-        for srcopt, srcval in options.items():
-            optvar = self.cli_opts.get_one_option_variable(command, srcopt)
-            if not optvar:
-                continue
-            optvar.set(srcval)
-
-    def save(self, command=None):
-        """ Save the current GUI state to a config file in json format """
-        logger.debug("Saving config: (command: '%s')", command)
-        cfgfile = FileHandler("save", "config").retfile
-        if not cfgfile:
-            return
-        cfg = self.cli_opts.get_option_values(command)
-        cfgfile.write(self.serializer.marshal(cfg))
-        cfgfile.close()
-        logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile)
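The removed Config class wrote and read a plain JSON mapping of command name to option/value pairs. A sketch of that shape, with option names that are illustrative rather than taken from the diff:

import json

cfg = {"train": {"input-A": "/faces/a",
                 "input-B": "/faces/b",
                 "batch-size": "64"}}

saved = json.dumps(cfg)           # what Config.save() marshalled to file
loaded = json.loads(saved)        # what Config.load() unmarshalled
opts = loaded.get("train", None)  # mirrors get_command_options()
assert opts["batch-size"] == "64"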
lib/gui/popup_configure.py (new file, 348 lines)
@@ -0,0 +1,348 @@
+#!/usr/bin python3
+""" Configure Plugins popup of the Faceswap GUI """
+
+from configparser import ConfigParser
+import logging
+import tkinter as tk
+
+from tkinter import ttk
+
+from .tooltip import Tooltip
+from .utils import get_config, ContextMenu, set_slider_rounding
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+POPUP = dict()
+
+
+def popup_config(config, root):
+    """ Close any open popup and open requested popup """
+    if POPUP:
+        p_key = list(POPUP.keys())[0]
+        logger.debug("Closing open popup: '%s'", p_key)
+        POPUP[p_key].destroy()
+        del POPUP[p_key]
+    window = ConfigurePlugins(config, root)
+    POPUP[config[0]] = window
+
+
+class ConfigurePlugins(tk.Toplevel):
+    """ Pop up for configuring plugin settings """
+    def __init__(self, config, root):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__()
+        name, self.config = config
+        self.title("{} Plugins".format(name.title()))
+        self.set_geometry(root)
+
+        self.page_frame = ttk.Frame(self)
+        self.page_frame.pack(fill=tk.BOTH, expand=True)
+
+        self.plugin_info = dict()
+        self.config_dict_gui = self.get_config()
+        self.build()
+        self.update()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def set_geometry(self, root):
+        """ Set pop-up geometry """
+        scaling_factor = get_config().scaling_factor
+        pos_x = root.winfo_x() + 80
+        pos_y = root.winfo_y() + 80
+        width = int(720 * scaling_factor)
+        height = int(400 * scaling_factor)
+        logger.debug("Pop up Geometry: %sx%s, %s+%s", width, height, pos_x, pos_y)
+        self.geometry("{}x{}+{}+{}".format(width, height, pos_x, pos_y))
+
+    def get_config(self):
+        """ Format config into useful format for GUI and pull default value if a value has not
+            been supplied """
+        logger.debug("Formatting Config for GUI")
+        conf = dict()
+        for section in self.config.config.sections():
+            self.config.section = section
+            category = section.split(".")[0]
+            options = self.config.defaults[section]
+            conf.setdefault(category, dict())[section] = options
+            for key in options.keys():
+                if key == "helptext":
+                    self.plugin_info[section] = options[key]
+                    continue
+                options[key]["value"] = self.config.config_dict.get(key, options[key]["default"])
+        logger.debug("Formatted Config for GUI: %s", conf)
+        return conf
+
+    def build(self):
+        """ Build the config popup """
+        logger.debug("Building plugin config popup")
+        container = ttk.Notebook(self.page_frame)
+        container.pack(fill=tk.BOTH, expand=True)
+        categories = sorted(list(key for key in self.config_dict_gui.keys()))
+        if "global" in categories:  # Move global to first item
+            categories.insert(0, categories.pop(categories.index("global")))
+        for category in categories:
+            page = self.build_page(container, category)
+            container.add(page, text=category.title())
+
+        self.add_frame_separator()
+        self.add_actions()
+        logger.debug("Built plugin config popup")
+
+    def build_page(self, container, category):
+        """ Build a plugin config page """
+        logger.debug("Building plugin config page: '%s'", category)
+        plugins = sorted(list(key for key in self.config_dict_gui[category].keys()))
+        if any(plugin != category for plugin in plugins):
+            page = ttk.Notebook(container)
+            page.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+            for plugin in plugins:
+                frame = ConfigFrame(page,
+                                    self.config_dict_gui[category][plugin],
+                                    self.plugin_info[plugin])
+                title = plugin[plugin.rfind(".") + 1:]
+                title = title.replace("_", " ").title()
+                page.add(frame, text=title)
+        else:
+            page = ConfigFrame(container,
+                               self.config_dict_gui[category][plugins[0]],
+                               self.plugin_info[plugins[0]])
+
+        logger.debug("Built plugin config page: '%s'", category)
+
+        return page
+
+    def add_frame_separator(self):
+        """ Add a separator between top and bottom frames """
+        logger.debug("Add frame separator")
+        sep = ttk.Frame(self.page_frame, height=2, relief=tk.RIDGE)
+        sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM)
+        logger.debug("Added frame separator")
+
+    def add_actions(self):
+        """ Add Action buttons """
+        logger.debug("Add action buttons")
+        frame = ttk.Frame(self.page_frame)
+        frame.pack(fill=tk.BOTH, padx=5, pady=5, side=tk.BOTTOM)
+        btn_cls = ttk.Button(frame, text="Cancel", width=10, command=self.destroy)
+        btn_cls.pack(padx=2, side=tk.RIGHT)
+        btn_ok = ttk.Button(frame, text="OK", width=10, command=self.save_config)
+        btn_ok.pack(padx=2, side=tk.RIGHT)
+        logger.debug("Added action buttons")
+
+    def save_config(self):
+        """ Save the config file """
+        logger.debug("Saving config")
+        options = {sect: opts
+                   for value in self.config_dict_gui.values()
+                   for sect, opts in value.items()}
+
+        new_config = ConfigParser(allow_no_value=True)
+        for section, items in self.config.defaults.items():
+            logger.debug("Adding section: '%s')", section)
+            self.config.insert_config_section(section, items["helptext"], config=new_config)
+            for item, def_opt in items.items():
+                if item == "helptext":
+                    continue
+                new_opt = options[section][item]
+                logger.debug("Adding option: (item: '%s', default: '%s' new: '%s'",
+                             item, def_opt, new_opt)
+                helptext = def_opt["helptext"]
+                helptext += self.config.set_helptext_choices(def_opt)
+                helptext += "\n[Default: {}]".format(def_opt["default"])
+                helptext = self.config.format_help(helptext, is_section=False)
+                new_config.set(section, helptext)
+                new_config.set(section, item, str(new_opt["selected"].get()))
+        self.config.config = new_config
+        self.config.save_config()
+        print("Saved config: '{}'".format(self.config.configfile))
+        self.destroy()
+        logger.debug("Saved config")
+
+
+class ConfigFrame(ttk.Frame):  # pylint: disable=too-many-ancestors
+    """ Config Frame - Holds the Options for config """
+
+    def __init__(self, parent, options, plugin_info):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        ttk.Frame.__init__(self, parent)
+        self.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+
+        self.options = options
+        self.plugin_info = plugin_info
+
+        self.canvas = tk.Canvas(self, bd=0, highlightthickness=0)
+        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+
+        self.optsframe = ttk.Frame(self.canvas)
+        self.optscanvas = self.canvas.create_window((0, 0), window=self.optsframe, anchor=tk.NW)
+
+        self.build_frame()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def build_frame(self):
+        """ Build the options frame for this command """
+        logger.debug("Add Config Frame")
+        self.add_scrollbar()
+        self.canvas.bind("<Configure>", self.resize_frame)
+
+        self.add_info()
+        for key, val in self.options.items():
+            if key == "helptext":
+                continue
+            OptionControl(key, val, self.optsframe)
+        logger.debug("Added Config Frame")
+
+    def add_scrollbar(self):
+        """ Add a scrollbar to the options frame """
+        logger.debug("Add Config Scrollbar")
+        scrollbar = ttk.Scrollbar(self, command=self.canvas.yview)
+        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+        self.canvas.config(yscrollcommand=scrollbar.set)
+        self.optsframe.bind("<Configure>", self.update_scrollbar)
+        logger.debug("Added Config Scrollbar")
+
+    def update_scrollbar(self, event):  # pylint: disable=unused-argument
+        """ Update the options frame scrollbar """
+        self.canvas.configure(scrollregion=self.canvas.bbox("all"))
+
+    def resize_frame(self, event):
+        """ Resize the options frame to fit the canvas """
+        logger.debug("Resize Config Frame")
+        canvas_width = event.width
+        self.canvas.itemconfig(self.optscanvas, width=canvas_width)
+        logger.debug("Resized Config Frame")
+
+    def add_info(self):
+        """ Plugin information """
+        info_frame = ttk.Frame(self.optsframe)
+        info_frame.pack(fill=tk.X, expand=True)
+        lbl = ttk.Label(info_frame, text="About:", width=20, anchor=tk.W)
+        lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
+        info = ttk.Label(info_frame, text=self.plugin_info)
+        info.pack(padx=5, pady=5, fill=tk.X, expand=True)
+
+
+class OptionControl():
+    """ Build the correct control for the option parsed and place it on the
+        frame """
+
+    def __init__(self, title, values, option_frame):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        self.title = title
+        self.values = values
+        self.option_frame = option_frame
+
+        self.control = self.set_control()
+        self.control_frame = self.set_control_frame()
+        self.tk_var = self.set_tk_var()
+
+        self.build_full_control()
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def helptext(self):
+        """ Format the help text for tooltips """
+        logger.debug("Format control help: '%s'", self.title)
+        helptext = self.values.get("helptext", "")
+        helptext = helptext.replace("\n\t", "\n  - ").replace("%%", "%")
+        helptext = self.title + " - " + helptext
+        logger.debug("Formatted control help: (title: '%s', help: '%s'", self.title, helptext)
+        return helptext
+
+    def set_control(self):
+        """ Set the correct control type for this option """
+        dtype = self.values["type"]
+        choices = self.values["choices"]
+        if choices:
+            control = ttk.Combobox
+        elif dtype == bool:
+            control = ttk.Checkbutton
+        elif dtype in (int, float):
+            control = ttk.Scale
+        else:
+            control = ttk.Entry
+        logger.debug("Setting control '%s' to %s", self.title, control)
+        return control
+
+    def set_control_frame(self):
+        """ Frame to hold control and its label """
+        logger.debug("Build config control frame")
+        frame = ttk.Frame(self.option_frame)
+        frame.pack(fill=tk.X, expand=True)
+        logger.debug("Built config control frame")
+        return frame
+
+    def set_tk_var(self):
+        """ Correct variable type for control """
+        logger.debug("Setting config variable type: '%s'", self.title)
+        var = tk.BooleanVar if self.control == ttk.Checkbutton else tk.StringVar
+        var = var(self.control_frame)
+        logger.debug("Set config variable type: ('%s': %s", self.title, type(var))
+        return var
+
+    def build_full_control(self):
+        """ Build the correct control type for the option passed through """
+        logger.debug("Build config option control")
+        self.build_control_label()
+        self.build_one_control()
+        self.values["selected"] = self.tk_var
+        logger.debug("Built option control")
+
+    def build_control_label(self):
+        """ Label for control """
+        logger.debug("Build config control label: '%s'", self.title)
+        title = self.title.replace("_", " ").title()
+        lbl = ttk.Label(self.control_frame, text=title, width=20, anchor=tk.W)
+        lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N)
+        logger.debug("Built config control label: '%s'", self.title)
+
+    def build_one_control(self):
+        """ Build and place the option controls """
+        logger.debug("Build control: (title: '%s', values: %s)", self.title, self.values)
+        self.tk_var.set(self.values["value"])
+
+        if self.control == ttk.Scale:
+            self.slider_control()
+        else:
+            self.control_to_optionsframe()
+        logger.debug("Built control: '%s'", self.title)
+
+    def slider_control(self):
+        """ A slider control with corresponding Entry box """
+        logger.debug("Add slider control to Config Options Frame: %s", self.control)
+        d_type = self.values["type"]
+        rnd = self.values["rounding"]
+        min_max = self.values["min_max"]
+
+        tbox = ttk.Entry(self.control_frame, width=8, textvariable=self.tk_var, justify=tk.RIGHT)
+        tbox.pack(padx=(0, 5), side=tk.RIGHT)
+        ctl = self.control(
+            self.control_frame,
+            variable=self.tk_var,
+            command=lambda val, var=self.tk_var, dt=d_type, rn=rnd, mm=min_max:
+            set_slider_rounding(val, var, dt, rn, mm))
+        ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
+        rc_menu = ContextMenu(ctl)
+        rc_menu.cm_bind()
+        ctl["from_"] = min_max[0]
+        ctl["to"] = min_max[1]
+
+        Tooltip(ctl, text=self.helptext, wraplength=720)
+        Tooltip(tbox, text=self.helptext, wraplength=720)
+        logger.debug("Added slider control to Options Frame: %s", self.control)
+
+    def control_to_optionsframe(self):
+        """ Standard non-check buttons sit in the main options frame """
+        logger.debug("Add control to Options Frame: %s", self.control)
+        choices = self.values["choices"]
+        if self.control == ttk.Checkbutton:
+            ctl = self.control(self.control_frame, variable=self.tk_var, text=None)
+        else:
+            ctl = self.control(self.control_frame, textvariable=self.tk_var)
+        ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
+        rc_menu = ContextMenu(ctl)
+        rc_menu.cm_bind()
+        if choices:
+            logger.debug("Adding combo choices: %s", choices)
+            ctl["values"] = [choice for choice in choices]
+        Tooltip(ctl, text=self.helptext, wraplength=720)
+        logger.debug("Added control to Options Frame: %s", self.control)
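The widget-selection rule in OptionControl.set_control, restated as a standalone sketch so the precedence can be checked without a tkinter display (strings stand in for the ttk classes):

def control_name(dtype, choices):
    """ Same precedence as set_control: choices win, then bool, then numeric. """
    if choices:
        return "Combobox"
    if dtype == bool:
        return "Checkbutton"
    if dtype in (int, float):
        return "Scale"
    return "Entry"

assert control_name(str, ["none", "components"]) == "Combobox"
assert control_name(bool, None) == "Checkbutton"
assert control_name(float, None) == "Scale"
assert control_name(str, None) == "Entry"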
lib/gui/stats.py (560 changed lines)
@@ -9,14 +9,14 @@ import warnings
 from math import ceil, sqrt

 import numpy as np
+import tensorflow as tf
-from lib.Serializer import PickleSerializer
+from lib.Serializer import JSONSerializer

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 def convert_time(timestamp):
-    """ Convert time stamp to total hours, mins and second """
+    """ Convert timestamp to total hours, minutes and seconds """
     hrs = int(timestamp // 3600)
     if hrs < 10:
         hrs = "{0:02d}".format(hrs)

@@ -25,164 +25,279 @@ def convert_time(timestamp):
     return hrs, mins, secs


-class SavedSessions():
-    """ Saved Training Session """
-    def __init__(self, sessions_data):
-        self.serializer = PickleSerializer
-        self.sessions = self.load_sessions(sessions_data)
-
-    def load_sessions(self, filename):
-        """ Load previously saved sessions """
-        stats = list()
-        if os.path.isfile(filename):
-            with open(filename, self.serializer.roptions) as sessions:
-                stats = self.serializer.unmarshal(sessions.read())
-        return stats
-
-    def save_sessions(self, filename):
-        """ Save the session file """
-        with open(filename, self.serializer.woptions) as session:
-            session.write(self.serializer.marshal(self.sessions))
-        logger.info("Saved session stats to: %s", filename)
+class TensorBoardLogs():
+    """ Parse and return data from TensorBoard logs """
+    def __init__(self, logs_folder):
+        self.folder_base = logs_folder
+        self.log_filenames = self.set_log_filenames()
+
+    def set_log_filenames(self):
+        """ Set the TensorBoard log filenames for all existing sessions """
+        logger.debug("Loading log filenames. base_dir: '%s'", self.folder_base)
+        log_filenames = dict()
+        for dirpath, _, filenames in os.walk(self.folder_base):
+            if not any(filename.startswith("events.out.tfevents") for filename in filenames):
+                continue
+            logfiles = [filename for filename in filenames
+                        if filename.startswith("events.out.tfevents")]
+            # Take the last logfile, in case of previous crash
+            logfile = os.path.join(dirpath, sorted(logfiles)[-1])
+            side, session = os.path.split(dirpath)
+            side = os.path.split(side)[1]
+            session = int(session[session.rfind("_") + 1:])
+            log_filenames.setdefault(session, dict())[side] = logfile
+        logger.debug("logfiles: %s", log_filenames)
+        return log_filenames
+
+    def get_loss(self, side=None, session=None):
+        """ Read the loss from the TensorBoard logs
+            Specify a side or a session or leave at None for all
+        """
+        logger.debug("Getting loss: (side: %s, session: %s)", side, session)
+        all_loss = dict()
+        for sess, sides in self.log_filenames.items():
+            if session is not None and sess != session:
+                logger.debug("Skipping session: %s", sess)
+                continue
+            loss = dict()
+            for sde, logfile in sides.items():
+                if side is not None and sde != side:
+                    logger.debug("Skipping side: %s", sde)
+                    continue
+                for event in tf.train.summary_iterator(logfile):
+                    for summary in event.summary.value:
+                        if "loss" not in summary.tag:
+                            continue
+                        tag = summary.tag.replace("batch_", "")
+                        loss.setdefault(tag,
+                                        dict()).setdefault(sde,
+                                                           list()).append(summary.simple_value)
+            all_loss[sess] = loss
+        return all_loss
+
+    def get_timestamps(self, session=None):
+        """ Read the timestamps from the TensorBoard logs
+            Specify a session or leave at None for all
+            NB: For all intents and purposes timestamps are the same for
+            both sides, so just read from one side """
+        logger.debug("Getting timestamps")
+        all_timestamps = dict()
+        for sess, sides in self.log_filenames.items():
+            if session is not None and sess != session:
+                logger.debug("Skipping session: %s", sess)
+                continue
+            for logfile in sides.values():
+                timestamps = [event.wall_time
+                              for event in tf.train.summary_iterator(logfile)]
+                logger.debug("Total timestamps for session %s: %s", sess, len(timestamps))
+                all_timestamps[sess] = timestamps
+                break  # break after first file read
+        return all_timestamps
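How set_log_filenames splits a log path into side and session can be seen in isolation. The folder layout below is inferred from the parsing logic above (`<model_name>_logs/<side>/session_<n>/events.out.tfevents.*`); it is an assumption, not stated elsewhere in the diff:

import os

dirpath = os.path.join("original_logs", "A", "session_2")
side_dir, session_dir = os.path.split(dirpath)           # ".../A", "session_2"
side = os.path.split(side_dir)[1]                        # "A"
session = int(session_dir[session_dir.rfind("_") + 1:])  # 2
assert (side, session) == ("A", 2)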

-class CurrentSession():
-    """ The current training session """
-    def __init__(self):
-        self.stats = {"iterations": 0,
-                      "batchsize": None,  # Set and reset by wrapper
-                      "timestamps": [],
-                      "loss": [],
-                      "losskeys": []}
-        self.timestats = {"start": None,
-                          "elapsed": None}
-        self.modeldir = None  # Set and reset by wrapper
-        self.filename = None
-        self.historical = None
-
-    def initialise_session(self, currentloss):
-        """ Initialise the training session """
-        self.load_historical()
-        for item in currentloss:
-            self.stats["losskeys"].append(item[0])
-            self.stats["loss"].append(list())
-        self.timestats["start"] = time.time()
-
-    def load_historical(self):
-        """ Load historical data and add current session to the end """
-        self.filename = os.path.join(self.modeldir, "trainingstats.fss")
-        self.historical = SavedSessions(self.filename)
-        self.historical.sessions.append(self.stats)
-
-    def add_loss(self, currentloss):
-        """ Add a loss item from the training process """
-        if self.stats["iterations"] == 0:
-            self.initialise_session(currentloss)
-
-        self.stats["iterations"] += 1
-        self.add_timestats()
-
-        for idx, item in enumerate(currentloss):
-            self.stats["loss"][idx].append(float(item[1]))
-
-    def add_timestats(self):
-        """ Add timestats to loss dict and timestats """
-        now = time.time()
-        self.stats["timestamps"].append(now)
-        elapsed_time = now - self.timestats["start"]
-        hrs, mins, secs = convert_time(elapsed_time)
-        self.timestats["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
-
-    def save_session(self):
-        """ Save the session file to the modeldir """
-        if self.stats["iterations"] > 0:
-            logger.info("Saving session stats...")
-            self.historical.save_sessions(self.filename)
+class Session():
+    """ The loaded or current training session """
+    def __init__(self, model_dir=None, model_name=None):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        self.serializer = JSONSerializer
+        self.state = None
+        self.modeldir = model_dir    # Set and reset by wrapper for training sessions
+        self.modelname = model_name  # Set and reset by wrapper for training sessions
+        self.tb_logs = None
+        self.initialized = False
+        self.session_id = None  # Set to specific session_id or current training session
+        self.summary = SessionsSummary(self)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def batchsize(self):
+        """ Return the session batchsize """
+        return self.session["batchsize"]
+
+    @property
+    def config(self):
+        """ Return config and other information """
+        retval = {key: val for key, val in self.state["config"]}
+        retval["training_size"] = self.state["training_size"]
+        retval["input_size"] = [val[0] for key, val in self.state["inputs"].items()
+                                if key.startswith("face")][0]
+        return retval
+
+    @property
+    def full_summary(self):
+        """ Return all sessions' summary data """
+        return self.summary.compile_stats()
+
+    @property
+    def iterations(self):
+        """ Return session iterations """
+        return self.session["iterations"]
+
+    @property
+    def logging_disabled(self):
+        """ Return whether logging is disabled for this session """
+        return self.session["no_logs"]
+
+    @property
+    def loss(self):
+        """ Return loss from logs for current session """
+        loss_dict = self.tb_logs.get_loss(session=self.session_id)[self.session_id]
+        return loss_dict
+
+    @property
+    def loss_keys(self):
+        """ Return list of unique session loss keys """
+        if self.session_id is None:
+            loss_keys = self.total_loss_keys
+        else:
+            loss_keys = set(loss_key for side_keys in self.session["loss_names"].values()
+                            for loss_key in side_keys)
+        return list(loss_keys)
+
+    @property
+    def lowest_loss(self):
+        """ Return the lowest average loss per save iteration seen """
+        return self.state["lowest_avg_loss"]
+
+    @property
+    def session(self):
+        """ Return current session dictionary """
+        return self.state["sessions"][str(self.session_id)]
+
+    @property
+    def session_ids(self):
+        """ Return sorted list of all existing session ids in the state file """
+        return sorted([int(key) for key in self.state["sessions"].keys()])
+
+    @property
+    def timestamps(self):
+        """ Return timestamps from logs for current session """
+        ts_dict = self.tb_logs.get_timestamps(session=self.session_id)
+        return ts_dict[self.session_id]
+
+    @property
+    def total_batchsize(self):
+        """ Return all session batch sizes """
+        return {int(sess_id): sess["batchsize"]
+                for sess_id, sess in self.state["sessions"].items()}
+
+    @property
+    def total_iterations(self):
+        """ Return total iterations across all sessions """
+        return self.state["iterations"]
+
+    @property
+    def total_loss(self):
+        """ Return collated loss for all sessions """
+        loss_dict = dict()
+        for sess in self.tb_logs.get_loss().values():
+            for loss_key, side_loss in sess.items():
+                for side, loss in side_loss.items():
+                    loss_dict.setdefault(loss_key, dict()).setdefault(side, list()).extend(loss)
+        return loss_dict
+
+    @property
+    def total_loss_keys(self):
+        """ Return list of unique session loss keys across all sessions """
+        loss_keys = set(loss_key
+                        for session in self.state["sessions"].values()
+                        for loss_keys in session["loss_names"].values()
+                        for loss_key in loss_keys)
+        return list(loss_keys)
+
+    @property
+    def total_timestamps(self):
+        """ Return timestamps from logs separated per session for all sessions """
+        return self.tb_logs.get_timestamps()
+
+    def initialize_session(self, is_training=False, session_id=None):
+        """ Initialize the training session """
+        logger.debug("Initializing session: (is_training: %s, session_id: %s)",
+                     is_training, session_id)
+        self.load_state_file()
+        self.tb_logs = TensorBoardLogs(os.path.join(self.modeldir,
+                                                    "{}_logs".format(self.modelname)))
+        if is_training:
+            self.session_id = max(int(key) for key in self.state["sessions"].keys())
+        else:
+            self.session_id = session_id
+        self.initialized = True
+        logger.debug("Initialized session")
+
+    def load_state_file(self):
+        """ Load the current state file """
+        state_file = os.path.join(self.modeldir, "{}_state.json".format(self.modelname))
+        logger.debug("Loading State: '%s'", state_file)
+        try:
+            with open(state_file, "rb") as inp:
+                state = self.serializer.unmarshal(inp.read().decode("utf-8"))
+                self.state = state
+                logger.debug("Loaded state: %s", state)
+        except IOError as err:
+            logger.warning("Unable to load state file. Graphing disabled: %s", str(err))
||||||
|
|
||||||
|
|
||||||
 class SessionsSummary():
     """ Calculations for analysis summary stats """
-    def __init__(self, raw_data):
-        self.summary = list()
-        self.summary_stats_compile(raw_data)
+    def __init__(self, session):
+        logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session)
+        self.session = session
+        logger.debug("Initialized %s", self.__class__.__name__)

-    def summary_stats_compile(self, raw_data):
-        """ Compile summary stats """
-        raw_summaries = list()
-        for idx, session in enumerate(raw_data):
-            raw_summaries.append(self.summarise_session(idx, session))
-        totals_summary = self.summarise_totals(raw_summaries)
-        raw_summaries.append(totals_summary)
-        self.format_summaries(raw_summaries)
+    @property
+    def iterations(self):
+        """ Return session iterations sizes """
+        return {int(sess_id): sess["iterations"]
+                for sess_id, sess in self.session.state["sessions"].items()}
+
+    @property
+    def time_stats(self):
+        """ Return session time stats """
+        ts_data = self.session.tb_logs.get_timestamps()
+        time_stats = {sess_id: {"start_time": min(timestamps),
+                                "end_time": max(timestamps)}
+                      for sess_id, timestamps in ts_data.items()}
+        return time_stats

-    # Compile Session Summaries
-    @staticmethod
-    def summarise_session(idx, session):
-        """ Compile stats for session passed in """
-        starttime = session["timestamps"][0]
-        endtime = session["timestamps"][-1]
-        elapsed = endtime - starttime
-        # Bump elapsed to 0.1s if no time is recorded
-        # to hack around div by zero error
-        elapsed = 0.1 if elapsed == 0 else elapsed
-        rate = (session["batchsize"] * session["iterations"]) / elapsed
-        return {"session": idx + 1,
-                "start": starttime,
-                "end": endtime,
-                "elapsed": elapsed,
-                "rate": rate,
-                "batch": session["batchsize"],
-                "iterations": session["iterations"]}
+    @property
+    def sessions_stats(self):
+        """ Return compiled stats """
+        compiled = list()
+        for sess_idx, ts_data in self.time_stats.items():
+            elapsed = ts_data["end_time"] - ts_data["start_time"]
+            batchsize = self.session.total_batchsize[sess_idx]
+            iterations = self.iterations[sess_idx]
+            compiled.append({"session": sess_idx,
+                             "start": ts_data["start_time"],
+                             "end": ts_data["end_time"],
+                             "elapsed": elapsed,
+                             "rate": (batchsize * iterations) / elapsed,
+                             "batch": batchsize,
+                             "iterations": iterations})
+        return compiled
+
+    def compile_stats(self):
+        """ Compile sessions stats with totals, format and return """
+        logger.debug("Compiling sessions summary data")
+        compiled_stats = self.sessions_stats
+        logger.debug("sessions_stats: %s", compiled_stats)
+        total_stats = self.total_stats(compiled_stats)
+        compiled_stats.append(total_stats)
+        compiled_stats = self.format_stats(compiled_stats)
+        logger.debug("Final stats: %s", compiled_stats)
+        return compiled_stats

     @staticmethod
-    def summarise_totals(raw_summaries):
-        """ Compile the stats for all sessions combined """
+    def total_stats(sessions_stats):
+        """ Return total stats """
+        logger.debug("Compiling Totals")
         elapsed = 0
         rate = 0
         batchset = set()
         iterations = 0
-        total_summaries = len(raw_summaries)
-        for idx, summary in enumerate(raw_summaries):
+        total_summaries = len(sessions_stats)
+        for idx, summary in enumerate(sessions_stats):
             if idx == 0:
                 starttime = summary["start"]
             if idx == total_summaries - 1:
@@ -192,150 +307,170 @@ class SessionsSummary():
             batchset.add(summary["batch"])
             iterations += summary["iterations"]
         batch = ",".join(str(bs) for bs in batchset)
-        return {"session": "Total",
-                "start": starttime,
-                "end": endtime,
-                "elapsed": elapsed,
-                "rate": rate / total_summaries,
-                "batch": batch,
-                "iterations": iterations}
+        totals = {"session": "Total",
+                  "start": starttime,
+                  "end": endtime,
+                  "elapsed": elapsed,
+                  "rate": rate / total_summaries,
+                  "batch": batch,
+                  "iterations": iterations}
+        logger.debug(totals)
+        return totals

-    def format_summaries(self, raw_summaries):
-        """ Format the summaries nicely for display """
-        for summary in raw_summaries:
-            summary["start"] = time.strftime("%x %X",
-                                             time.gmtime(summary["start"]))
-            summary["end"] = time.strftime("%x %X",
-                                           time.gmtime(summary["end"]))
+    @staticmethod
+    def format_stats(compiled_stats):
+        """ Format for display """
+        logger.debug("Formatting stats")
+        for summary in compiled_stats:
             hrs, mins, secs = convert_time(summary["elapsed"])
+            summary["start"] = time.strftime("%x %X", time.gmtime(summary["start"]))
+            summary["end"] = time.strftime("%x %X", time.gmtime(summary["end"]))
             summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
             summary["rate"] = "{0:.1f}".format(summary["rate"])
-        self.summary = raw_summaries
+        return compiled_stats
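
A quick worked example of the per-session rate compiled above, with hypothetical numbers: a batch size of 64 over 10000 iterations in 2000 seconds works out to 320 faces per second.

batchsize, iterations, elapsed = 64, 10000, 2000.0  # hypothetical session
rate = (batchsize * iterations) / elapsed
print(rate)  # 320.0
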
 class Calculations():
-    """ Class to hold calculations against raw session data """
-    def __init__(self,
-                 session,
-                 display="loss",
-                 selections=["raw"],
-                 avg_samples=10,
-                 flatten_outliers=False,
-                 is_totals=False):
+    """ Class to pull raw data for given session(s) and perform calculations """
+    def __init__(self, session, display="loss", loss_keys=["loss"], selections=["raw"],
+                 avg_samples=10, flatten_outliers=False, is_totals=False):
+        logger.debug("Initializing %s: (session: %s, display: %s, loss_keys: %s, selections: %s, "
+                     "avg_samples: %s, flatten_outliers: %s, is_totals: %s",
+                     self.__class__.__name__, session, display, loss_keys, selections, avg_samples,
+                     flatten_outliers, is_totals)

         warnings.simplefilter("ignore", np.RankWarning)

         self.session = session
-        if display.lower() == "loss":
-            display = self.session["losskeys"]
-        else:
-            display = [display]
-        self.args = {"display": display,
-                     "selections": selections,
-                     "avg_samples": int(avg_samples),
-                     "flatten_outliers": flatten_outliers,
-                     "is_totals": is_totals}
+        self.display = display
+        self.loss_keys = loss_keys
+        self.selections = selections
+        self.is_totals = is_totals
+        self.args = {"avg_samples": int(avg_samples),
+                     "flatten_outliers": flatten_outliers}
         self.iterations = 0
         self.stats = None
         self.refresh()
+        logger.debug("Initialized %s", self.__class__.__name__)

     def refresh(self):
         """ Refresh the stats """
+        logger.debug("Refreshing")
+        if not self.session.initialized:
+            logger.warning("Session data is not initialized. Not refreshing")
+            return
         self.iterations = 0
         self.stats = self.get_raw()
         self.get_calculations()
         self.remove_raw()
+        logger.debug("Refreshed")

     def get_raw(self):
         """ Add raw data to stats dict """
-        raw = dict()
-        for idx, item in enumerate(self.args["display"]):
-            if item.lower() == "rate":
-                data = self.calc_rate(self.session)
-            else:
-                data = self.session["loss"][idx][:]
-
-            if self.args["flatten_outliers"]:
-                data = self.flatten_outliers(data)
-
-            if self.iterations == 0:
-                self.iterations = len(data)
-
-            raw["raw_{}".format(item)] = data
-        return raw
+        logger.debug("Getting Raw Data")
+        raw = dict()
+        iterations = set()
+        if self.display.lower() == "loss":
+            loss_dict = self.session.total_loss if self.is_totals else self.session.loss
+            for loss_name, side_loss in loss_dict.items():
+                if loss_name not in self.loss_keys:
+                    continue
+                for side, loss in side_loss.items():
+                    if self.args["flatten_outliers"]:
+                        loss = self.flatten_outliers(loss)
+                    iterations.add(len(loss))
+                    raw["raw_{}_{}".format(loss_name, side)] = loss
+
+            self.iterations = 0 if not iterations else min(iterations)
+            if len(iterations) > 1:
+                # Crop all losses to the same number of items
+                if self.iterations == 0:
+                    raw = {lossname: list() for lossname in raw.keys()}
+                else:
+                    raw = {lossname: loss[:self.iterations]
+                           for lossname, loss in raw.items()}
+
+        else:  # Rate calculation
+            data = self.calc_rate_total() if self.is_totals else self.calc_rate()
+            if self.args["flatten_outliers"]:
+                data = self.flatten_outliers(data)
+            self.iterations = len(data)
+            raw = {"raw_rate": data}
+
+        logger.debug("Got Raw Data")
+        return raw

     def remove_raw(self):
         """ Remove raw values from stats if not requested """
-        if "raw" in self.args["selections"]:
+        if "raw" in self.selections:
             return
+        logger.debug("Removing Raw Data from output")
         for key in list(self.stats.keys()):
             if key.startswith("raw"):
                 del self.stats[key]
+        logger.debug("Removed Raw Data from output")

-    def calc_rate(self, data):
+    def calc_rate(self):
+        """ Calculate rate per iteration """
+        logger.debug("Calculating rate")
+        batchsize = self.session.batchsize
+        timestamps = self.session.timestamps
+        iterations = range(len(timestamps) - 1)
+        rate = [batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations]
+        logger.debug("Calculated rate: Item_count: %s", len(rate))
+        return rate
+
+    def calc_rate_total(self):
         """ Calculate rate per iteration
         NB: For totals, gaps between sessions can be large
-        so time diffeence has to be reset for each session's
+        so time difference has to be reset for each session's
         rate calculation """
-        batchsize = data["batchsize"]
-        if self.args["is_totals"]:
-            split = data["split"]
-        else:
-            batchsize = [batchsize]
-            split = [len(data["timestamps"])]
-
-        prev_split = 0
-        rate = list()
-        for idx, current_split in enumerate(split):
-            prev_time = data["timestamps"][prev_split]
-            timestamp_chunk = data["timestamps"][prev_split:current_split]
-            for item in timestamp_chunk:
-                current_time = item
-                timediff = current_time - prev_time
-                iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff
-                rate.append(iter_rate)
-                prev_time = current_time
-            prev_split = current_split
-
-        if self.args["flatten_outliers"]:
-            rate = self.flatten_outliers(rate)
+        logger.debug("Calculating totals rate")
+        batchsizes = self.session.total_batchsize
+        total_timestamps = self.session.total_timestamps
+        rate = list()
+        for sess_id in sorted(total_timestamps.keys()):
+            batchsize = batchsizes[sess_id]
+            timestamps = total_timestamps[sess_id]
+            iterations = range(len(timestamps) - 1)
+            rate.extend([batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations])
+        logger.debug("Calculated totals rate: Item_count: %s", len(rate))
         return rate

     @staticmethod
     def flatten_outliers(data):
         """ Remove the outliers from a provided list """
+        logger.debug("Flattening outliers")
         retdata = list()
         samples = len(data)
         mean = (sum(data) / samples)
         limit = sqrt(sum([(item - mean)**2 for item in data]) / samples)
+        logger.debug("samples: %s, mean: %s, limit: %s", samples, mean, limit)
-        for item in data:
+        for idx, item in enumerate(data):
             if (mean - limit) <= item <= (mean + limit):
                 retdata.append(item)
             else:
+                logger.debug("Item idx: %s, value: %s flattened to %s", idx, item, mean)
                 retdata.append(mean)
+        logger.debug("Flattened outliers")
         return retdata

     def get_calculations(self):
         """ Perform the required calculations """
-        for selection in self.get_selections():
-            if selection[0] == "raw":
+        for selection in self.selections:
+            if selection == "raw":
                 continue
-            method = getattr(self, "calc_{}".format(selection[0]))
-            key = "{}_{}".format(selection[0], selection[1])
-            raw = self.stats["raw_{}".format(selection[1])]
-            self.stats[key] = method(raw)
-
-    def get_selections(self):
-        """ Compile a list of data to be calculated """
-        for summary in self.args["selections"]:
-            for item in self.args["display"]:
-                yield summary, item
+            logger.debug("Calculating: %s", selection)
+            method = getattr(self, "calc_{}".format(selection))
+            raw_keys = [key for key in self.stats.keys() if key.startswith("raw_")]
+            for key in raw_keys:
+                selected_key = "{}_{}".format(selection, key.replace("raw_", ""))
+                self.stats[selected_key] = method(self.stats[key])

     def calc_avg(self, data):
         """ Calculate rolling average """
+        logger.debug("Calculating Average")
         avgs = list()
         presample = ceil(self.args["avg_samples"] / 2)
         postsample = self.args["avg_samples"] - presample
@@ -353,11 +488,13 @@ class Calculations():
             avg = sum(data[idx - presample:idx + postsample]) \
                 / self.args["avg_samples"]
             avgs.append(avg)
+        logger.debug("Calculated Average")
         return avgs

     @staticmethod
     def calc_trend(data):
         """ Compile trend data """
+        logger.debug("Calculating Trend")
         points = len(data)
         if points < 10:
             dummy = [None for i in range(points)]
@@ -366,4 +503,5 @@ class Calculations():
         fit = np.polyfit(x_range, data, 3)
         poly = np.poly1d(fit)
         trend = poly(x_range)
+        logger.debug("Calculated Trend")
         return trend
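
The outlier flattening above clamps anything beyond one standard deviation of the mean back to the mean. A self-contained sketch with toy data:

from math import sqrt

data = [1.0, 1.1, 0.9, 9.0]  # toy loss values; 9.0 is an outlier
mean = sum(data) / len(data)
limit = sqrt(sum((item - mean) ** 2 for item in data) / len(data))
flattened = [item if (mean - limit) <= item <= (mean + limit) else mean
             for item in data]
print(flattened)  # [1.0, 1.1, 0.9, 3.0] - the spike is replaced by the mean
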
@@ -42,7 +42,7 @@ class Tooltip:
                  waittime=400,
                  wraplength=250):

-        self.waittime = waittime  # in miliseconds, originally 500
+        self.waittime = waittime  # in milliseconds, originally 500
         self.wraplength = wraplength  # in pixels, originally 180
         self.widget = widget
         self.text = text
@@ -115,7 +115,7 @@ class Tooltip:
         # No further checks will be done.

         # TIP:
-        # A further mod might automagically augment the
+        # A further mod might auto-magically augment the
         # wraplength when the tooltip is too high to be
         # kept inside the screen.
         y_1 = 0
lib/gui/utils.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 """ Utility functions for the GUI """
 import logging
-
 import os
 import platform
 import sys
@@ -10,24 +9,50 @@ import tkinter as tk
 from tkinter import filedialog, ttk
 from PIL import Image, ImageTk

+from lib.Serializer import JSONSerializer
+
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+_CONFIG = None
+_IMAGES = None


-class Singleton(type):
-    """ Instigate a singleton.
-        From: https://stackoverflow.com/questions/6760685
-
-        Singletons are often frowned upon.
-        Feel free to instigate a better solution """
-
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton,
-                                        cls).__call__(*args,
-                                                      **kwargs)
-        return cls._instances[cls]
+def initialize_config(cli_opts, scaling_factor, pathcache, statusbar, session):
+    """ Initialize the config and add to global constant """
+    global _CONFIG  # pylint: disable=global-statement
+    if _CONFIG is not None:
+        return
+    logger.debug("Initializing config: (cli_opts: %s, tk_vars: %s, pathcache: %s, statusbar: %s, "
+                 "session: %s)", cli_opts, scaling_factor, pathcache, statusbar, session)
+    _CONFIG = Config(cli_opts, scaling_factor, pathcache, statusbar, session)
+
+
+def get_config():
+    """ Return the _CONFIG constant """
+    return _CONFIG
+
+
+def initialize_images():
+    """ Initialize the images and add to global constant """
+    global _IMAGES  # pylint: disable=global-statement
+    if _IMAGES is not None:
+        return
+    logger.debug("Initializing images")
+    _IMAGES = Images()
+
+
+def get_images():
+    """ Return the _IMAGES constant """
+    return _IMAGES
+
+
+def set_slider_rounding(value, var, d_type, round_to, min_max):
+    """ Set the underlying variable to the correct number based on slider rounding """
+    if d_type == float:
+        var.set(round(float(value), round_to))
+    else:
+        steps = range(min_max[0], min_max[1] + round_to, round_to)
+        value = min(steps, key=lambda x: abs(x - int(float(value))))
+        var.set(value)
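
The Singleton metaclass gives way to module-level constants behind initialize_*/get_* accessors. The same pattern in isolation (hypothetical names, for illustration only):

_THING = None

def initialize_thing(value):
    """ Create the shared instance once; later calls are no-ops """
    global _THING
    if _THING is None:
        _THING = {"value": value}

def get_thing():
    """ Return the shared instance from anywhere in the package """
    return _THING

initialize_thing(42)
initialize_thing(99)          # ignored: already initialized
print(get_thing()["value"])   # 42
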
 class FileHandler():
@@ -52,8 +77,8 @@ class FileHandler():
                      ("PNG", "*.png"),
                      ("TIFF", "*.tif", "*.tiff"),
                      all_files),
+            "state": (("State files", "*.json"), all_files),
             "log": (("Log files", "*.log"), all_files),
-            "session": (("Faceswap session files", "*.fss"), all_files),
             "video": (("Audio Video Interleave", "*.avi"),
                       ("Flash Video", "*.flv"),
                       ("Matroska", "*.mkv"),
@@ -164,11 +189,15 @@ class FileHandler():
             return
-class Images(metaclass=Singleton):
-    """ Holds locations of images and actual images """
-
-    def __init__(self, pathcache=None):
-        logger.debug("Initializing %s: (pathcache: '%s'", self.__class__.__name__, pathcache)
+class Images():
+    """ Holds locations of images and actual images
+
+    Don't call directly. Call get_images()
+    """
+
+    def __init__(self):
+        logger.debug("Initializing %s", self.__class__.__name__)
+        pathcache = get_config().pathcache
         self.pathicons = os.path.join(pathcache, "icons")
         self.pathpreview = os.path.join(pathcache, "preview")
         self.pathoutput = None
@@ -194,7 +223,7 @@ class Images(metaclass=Singleton):
         """ Delete the preview files """
         logger.debug("Deleting previews")
         for item in os.listdir(self.pathpreview):
-            if item.startswith(".gui_preview_") and item.endswith(".jpg"):
+            if item.startswith(".gui_training_preview") and item.endswith(".jpg"):
                 fullitem = os.path.join(self.pathpreview, item)
                 logger.debug("Deleting: '%s'", fullitem)
                 os.remove(fullitem)
@@ -210,34 +239,34 @@ class Images(metaclass=Singleton):
     @staticmethod
     def get_images(imgpath):
         """ Get the images stored within the given directory """
-        logger.debug("Getting images: '%s'", imgpath)
+        logger.trace("Getting images: '%s'", imgpath)
         if not os.path.isdir(imgpath):
             logger.debug("Folder does not exist")
             return None
         files = [os.path.join(imgpath, f)
                  for f in os.listdir(imgpath) if f.endswith((".png", ".jpg"))]
-        logger.debug("Image files: %s", files)
+        logger.trace("Image files: %s", files)
         return files

     def load_latest_preview(self):
         """ Load the latest preview image for extract and convert """
-        logger.debug("Loading preview image")
+        logger.trace("Loading preview image")
         imagefiles = self.get_images(self.pathoutput)
         if not imagefiles or len(imagefiles) == 1:
             logger.debug("No preview to display")
             self.previewoutput = None
             return
-        # Get penultimate file so we don't accidently
+        # Get penultimate file so we don't accidentally
         # load a file that is being saved
         show_file = sorted(imagefiles, key=os.path.getctime)[-2]
         img = Image.open(show_file)
         img.thumbnail((768, 432))
-        logger.debug("Displaying preview: '%s'", show_file)
+        logger.trace("Displaying preview: '%s'", show_file)
         self.previewoutput = (img, ImageTk.PhotoImage(img))

     def load_training_preview(self):
         """ Load the training preview images """
-        logger.debug("Loading Training preview images")
+        logger.trace("Loading Training preview images")
         imagefiles = self.get_images(self.pathpreview)
         modified = None
         if not imagefiles:
@@ -250,7 +279,7 @@ class Images(metaclass=Singleton):
             name = os.path.splitext(name)[0]
             name = name[name.rfind("_") + 1:].title()
             try:
-                logger.debug("Displaying preview: '%s'", img)
+                logger.trace("Displaying preview: '%s'", img)
                 size = self.get_current_size(name)
                 self.previewtrain[name] = [Image.open(img), None, modified]
                 self.resize_image(name, size)
@@ -270,20 +299,20 @@ class Images(metaclass=Singleton):
     def get_current_size(self, name):
         """ Return the size of the currently displayed image """
-        logger.debug("Getting size: '%s'", name)
+        logger.trace("Getting size: '%s'", name)
         if not self.previewtrain.get(name, None):
             return None
         img = self.previewtrain[name][1]
         if not img:
             return None
-        logger.debug("Got size: (name: '%s', width: '%s', height: '%s')",
+        logger.trace("Got size: (name: '%s', width: '%s', height: '%s')",
                      name, img.width(), img.height())
         return img.width(), img.height()

     def resize_image(self, name, framesize):
         """ Resize the training preview image
            based on the passed in frame size """
-        logger.debug("Resizing image: (name: '%s', framesize: %s", name, framesize)
+        logger.trace("Resizing image: (name: '%s', framesize: %s", name, framesize)
         displayimg = self.previewtrain[name][0]
         if framesize:
             frameratio = float(framesize[0]) / float(framesize[1])
@@ -295,7 +324,7 @@ class Images(metaclass=Singleton):
             else:
                 scale = framesize[1] / float(displayimg.size[1])
                 size = (int(displayimg.size[0] * scale), framesize[1])
-            logger.debug("Scaling: (scale: %s, size: %s", scale, size)
+            logger.trace("Scaling: (scale: %s, size: %s", scale, size)

         # Hacky fix to force a reload if it happens to find corrupted
         # data, probably due to reading the image whilst it is partially
@@ -335,7 +364,9 @@ class ContextMenu(tk.Menu):  # pylint: disable=too-many-ancestors
         """ Bind the menu to the widget's Right Click event """
         button = "<Button-2>" if platform.system() == "Darwin" else "<Button-3>"
         logger.debug("Binding '%s' to '%s'", button, self.widget.winfo_class())
-        self.widget.bind(button, lambda event: self.tk_popup(event.x_root, event.y_root, 0))
+        x_offset = int(34 * get_config().scaling_factor)
+        self.widget.bind(button,
+                         lambda event: self.tk_popup(event.x_root + x_offset, event.y_root, 0))

     def select_all(self):
         """ Select all for Text or Entry widgets """
@@ -351,16 +382,16 @@ class ContextMenu(tk.Menu):  # pylint: disable=too-many-ancestors
 class ConsoleOut(ttk.Frame):  # pylint: disable=too-many-ancestors
     """ The Console out section of the GUI """

-    def __init__(self, parent, debug, tk_vars):
-        logger.debug("Initializing %s: (parent: %s, debug: %s, tk_vars: %s)",
-                     self.__class__.__name__, parent, debug, tk_vars)
+    def __init__(self, parent, debug):
+        logger.debug("Initializing %s: (parent: %s, debug: %s)",
+                     self.__class__.__name__, parent, debug)
         ttk.Frame.__init__(self, parent)
         self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0),
                   fill=tk.BOTH, expand=True)
         self.console = tk.Text(self)
         rc_menu = ContextMenu(self.console)
         rc_menu.cm_bind()
-        self.console_clear = tk_vars['consoleclear']
+        self.console_clear = get_config().tk_vars['consoleclear']
         self.set_console_clear_var_trace()
         self.debug = debug
         self.build_console()
@@ -395,7 +426,7 @@ class ConsoleOut(ttk.Frame):  # pylint: disable=too-many-ancestors
         sys.stderr = SysOutRouter(console=self.console, out_type="stderr")
         logger.debug("Redirected console")

-    def clear(self, *args):
+    def clear(self, *args):  # pylint: disable=unused-argument
         """ Clear the console output screen """
         logger.debug("Clear console")
         if not self.console_clear.get():
@@ -427,3 +458,146 @@ class SysOutRouter():
     def flush():
         """ If flush is forced, send it to normal terminal """
         sys.__stdout__.flush()
+class Config():
+    """ Global configuration settings
+
+    Don't call directly. Call get_config()
+    """
+
+    def __init__(self, cli_opts, scaling_factor, pathcache, statusbar, session):
+        logger.debug("Initializing %s: (cli_opts: %s, scaling_factor: %s, pathcache: %s, "
+                     "statusbar: %s, session: %s)", self.__class__.__name__, cli_opts,
+                     scaling_factor, pathcache, statusbar, session)
+        self.cli_opts = cli_opts
+        self.scaling_factor = scaling_factor
+        self.pathcache = pathcache
+        self.statusbar = statusbar
+        self.serializer = JSONSerializer
+        self.tk_vars = self.set_tk_vars()
+        self.command_notebook = None  # set in command.py
+        self.session = session
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    @property
+    def command_tabs(self):
+        """ Return dict of command tab titles with their IDs """
+        return {self.command_notebook.tab(tab_id, "text").lower(): tab_id
+                for tab_id in range(0, self.command_notebook.index("end"))}
+
+    @staticmethod
+    def set_tk_vars():
+        """ TK Variables to be triggered to indicate
+            what state various parts of the GUI should be in """
+        display = tk.StringVar()
+        display.set(None)
+
+        runningtask = tk.BooleanVar()
+        runningtask.set(False)
+
+        actioncommand = tk.StringVar()
+        actioncommand.set(None)
+
+        generatecommand = tk.StringVar()
+        generatecommand.set(None)
+
+        consoleclear = tk.BooleanVar()
+        consoleclear.set(False)
+
+        refreshgraph = tk.BooleanVar()
+        refreshgraph.set(False)
+
+        updatepreview = tk.BooleanVar()
+        updatepreview.set(False)
+
+        tk_vars = {"display": display,
+                   "runningtask": runningtask,
+                   "action": actioncommand,
+                   "generate": generatecommand,
+                   "consoleclear": consoleclear,
+                   "refreshgraph": refreshgraph,
+                   "updatepreview": updatepreview}
+        logger.debug(tk_vars)
+        return tk_vars
+
+    def load(self, command=None, filename=None):
+        """ Pop up load dialog for a saved config file """
+        logger.debug("Loading config: (command: '%s')", command)
+        if filename:
+            with open(filename, "r") as cfgfile:
+                cfg = self.serializer.unmarshal(cfgfile.read())
+        else:
+            cfgfile = FileHandler("open", "config").retfile
+            if not cfgfile:
+                return
+            cfg = self.serializer.unmarshal(cfgfile.read())
+
+        if not command and len(cfg.keys()) == 1:
+            command = list(cfg.keys())[0]
+
+        opts = self.get_command_options(cfg, command) if command else cfg
+        if not opts:
+            return
+
+        for cmd, opts in opts.items():
+            self.set_command_args(cmd, opts)
+
+        if command:
+            self.command_notebook.select(self.command_tabs[command])
+
+        self.add_to_recent(cfgfile.name, command)
+        logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile)
+
+    def get_command_options(self, cfg, command):
+        """ Return the saved options for the requested
+            command, if not loading global options """
+        opts = cfg.get(command, None)
+        retval = {command: opts}
+        if not opts:
+            self.tk_vars["consoleclear"].set(True)
+            print("No {} section found in file".format(command))
+            logger.info("No %s section found in file", command)
+            retval = None
+        logger.debug(retval)
+        return retval
+
+    def set_command_args(self, command, options):
+        """ Pass the saved config items back to the CliOptions """
+        if not options:
+            return
+        for srcopt, srcval in options.items():
+            optvar = self.cli_opts.get_one_option_variable(command, srcopt)
+            if not optvar:
+                continue
+            optvar.set(srcval)
+
+    def save(self, command=None):
+        """ Save the current GUI state to a config file in json format """
+        logger.debug("Saving config: (command: '%s')", command)
+        cfgfile = FileHandler("save", "config").retfile
+        if not cfgfile:
+            return
+        cfg = self.cli_opts.get_option_values(command)
+        cfgfile.write(self.serializer.marshal(cfg))
+        cfgfile.close()
+        self.add_to_recent(cfgfile.name, command)
+        logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile)
+
+    def add_to_recent(self, filename, command):
+        """ Add to recent files """
+        recent_filename = os.path.join(self.pathcache, ".recent.json")
+        logger.debug("Adding to recent files '%s': (%s, %s)", recent_filename, filename, command)
+        with open(recent_filename, "rb") as inp:
+            recent_files = self.serializer.unmarshal(inp.read().decode("utf-8"))
+        logger.debug("Initial recent files: %s", recent_files)
+        filenames = [recent[0] for recent in recent_files]
+        if filename in filenames:
+            idx = filenames.index(filename)
+            del recent_files[idx]
+        recent_files.insert(0, (filename, command))
+        recent_files = recent_files[:20]
+        logger.debug("Final recent files: %s", recent_files)
+        recent_json = self.serializer.marshal(recent_files)
+        with open(recent_filename, "wb") as out:
+            out.write(recent_json.encode("utf-8"))
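
The recent-files bookkeeping above de-duplicates, promotes the newest entry and caps the history at 20 entries. The same logic on toy data:

recent_files = [("b.json", "train"), ("c.json", "convert")]  # toy history
filename, command = "b.json", "train"
filenames = [recent[0] for recent in recent_files]
if filename in filenames:
    del recent_files[filenames.index(filename)]   # de-duplicate
recent_files.insert(0, (filename, command))       # most recent first
recent_files = recent_files[:20]                  # cap the history at 20
print(recent_files)  # [('b.json', 'train'), ('c.json', 'convert')]
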
@@ -1,57 +1,40 @@
 #!/usr/bin python3
 """ Process wrapper for underlying faceswap commands for the GUI """
 import os
+import logging
 import re
 import signal
 from subprocess import PIPE, Popen, TimeoutExpired
 import sys
-import tkinter as tk
 from threading import Thread
 from time import time

 import psutil

-from .utils import Images
+from .utils import get_config, get_images
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 class ProcessWrapper():
     """ Builds command, launches and terminates the underlying
         faceswap process. Updates GUI display depending on state """

-    def __init__(self, statusbar, session=None, pathscript=None, cliopts=None):
-        self.tk_vars = self.set_tk_vars()
-        self.session = session
+    def __init__(self, pathscript=None):
+        logger.debug("Initializing %s: (pathscript: %s)", self.__class__.__name__, pathscript)
+        self.tk_vars = get_config().tk_vars
+        self.set_callbacks()
         self.pathscript = pathscript
-        self.cliopts = cliopts
         self.command = None
-        self.statusbar = statusbar
+        self.statusbar = get_config().statusbar
         self.task = FaceswapControl(self)
+        logger.debug("Initialized %s", self.__class__.__name__)

-    def set_tk_vars(self):
-        """ TK Variables to be triggered by ProcessWrapper to indicate
-            what state various parts of the GUI should be in """
-        display = tk.StringVar()
-        display.set(None)
-
-        runningtask = tk.BooleanVar()
-        runningtask.set(False)
-
-        actioncommand = tk.StringVar()
-        actioncommand.set(None)
-        actioncommand.trace("w", self.action_command)
-
-        generatecommand = tk.StringVar()
-        generatecommand.set(None)
-        generatecommand.trace("w", self.generate_command)
-
-        consoleclear = tk.BooleanVar()
-        consoleclear.set(False)
-
-        return {"display": display,
-                "runningtask": runningtask,
-                "action": actioncommand,
-                "generate": generatecommand,
-                "consoleclear": consoleclear}
+    def set_callbacks(self):
+        """ Set the tk variable callbacks """
+        logger.debug("Setting tk variable traces")
+        self.tk_vars["action"].trace("w", self.action_command)
+        self.tk_vars["generate"].trace("w", self.generate_command)
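
set_callbacks relies on Tkinter variable tracing: writing to a traced variable fires the registered callback. A minimal sketch of the mechanism (hypothetical variable and callback names):

import tkinter as tk

root = tk.Tk()  # a root window must exist before variables can be created
action = tk.StringVar()

def on_action(*args):  # trace callbacks receive (name, index, mode)
    print("action changed to:", action.get())

action.trace("w", on_action)  # "w" fires on every write to the variable
action.set("train")           # triggers on_action
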
     def action_command(self, *args):
         """ The action to perform when the action button is pressed """
@@ -74,28 +57,30 @@ class ProcessWrapper():
         category, command = self.tk_vars["generate"].get().split(",")
         args = self.build_args(category, command=command, generate=True)
         self.tk_vars["consoleclear"].set(True)
+        logger.debug(" ".join(args))
         print(" ".join(args))
         self.tk_vars["generate"].set(None)

     def prepare(self, category):
         """ Prepare the environment for execution """
+        logger.debug("Preparing for execution")
         self.tk_vars["runningtask"].set(True)
         self.tk_vars["consoleclear"].set(True)
         print("Loading...")

-        self.statusbar.status_message.set("Executing - "
-                                          + self.command + ".py")
-        mode = "indeterminate" if self.command in ("effmpeg",
-                                                   "train") else "determinate"
+        self.statusbar.status_message.set("Executing - {}.py".format(self.command))
+        mode = "indeterminate" if self.command in ("effmpeg", "train") else "determinate"
         self.statusbar.progress_start(mode)

         args = self.build_args(category)
         self.tk_vars["display"].set(self.command)
+        logger.debug("Prepared for execution")
         return args

     def build_args(self, category, command=None, generate=False):
         """ Build the faceswap command and arguments list """
+        logger.debug("Build cli arguments: (category: %s, command: %s, generate: %s)",
+                     category, command, generate)
         command = self.command if not command else command
         script = "{}.{}".format(category, "py")
         pathexecscript = os.path.join(self.pathscript, script)
@@ -103,50 +88,60 @@ class ProcessWrapper():
         args = [sys.executable] if generate else [sys.executable, "-u"]
         args.extend([pathexecscript, command])

-        for cliopt in self.cliopts.gen_cli_arguments(command):
+        cli_opts = get_config().cli_opts
+        for cliopt in cli_opts.gen_cli_arguments(command):
             args.extend(cliopt)
             if command == "train" and not generate:
-                self.set_session_stats(cliopt)
-        if command == "train" and not generate:
-            args.append("-gui")  # Embed the preview pane
+                self.init_training_session(cliopt)
+        if not generate:
+            args.append("-gui")  # Indicate to Faceswap that we are running the GUI
+        logger.debug("Built cli arguments: (%s)", args)
         return args

-    def set_session_stats(self, cliopt):
-        """ Set the session stats for batchsize and modeldir """
-        if cliopt[0] == "-bs":
-            self.session.stats["batchsize"] = int(cliopt[1])
-        if cliopt[0] == "-m":
-            self.session.modeldir = cliopt[1]
+    @staticmethod
+    def init_training_session(cliopt):
+        """ Set the session stats for disable logging, model folder and model name """
+        session = get_config().session
+        if cliopt[0] == "-t":
+            session.modelname = cliopt[1].lower().replace("-", "_")
+            logger.debug("modelname: '%s'", session.modelname)
+        if cliopt[0] == "-m":
+            session.modeldir = cliopt[1]
+            logger.debug("modeldir: '%s'", session.modeldir)
     def terminate(self, message):
-        """ Finalise wrapper when process has exited """
+        """ Finalize wrapper when process has exited """
+        logger.debug("Terminating Faceswap processes")
         self.tk_vars["runningtask"].set(False)
         self.statusbar.progress_stop()
         self.statusbar.status_message.set(message)
         self.tk_vars["display"].set(None)
-        Images().delete_preview()
-        if self.command == "train":
-            self.session.save_session()
-        self.session.__init__()
+        get_images().delete_preview()
+        get_config().session.__init__()
         self.command = None
+        logger.debug("Terminated Faceswap processes")
         print("Process exited.")
 class FaceswapControl():
     """ Control the underlying Faceswap tasks """
     def __init__(self, wrapper):
+        logger.debug("Initializing %s", self.__class__.__name__)
         self.wrapper = wrapper
-        self.statusbar = wrapper.statusbar
+        self.statusbar = get_config().statusbar
         self.command = None
         self.args = None
         self.process = None
+        self.train_stats = {"iterations": 0, "timestamp": None}
         self.consoleregex = {
             "loss": re.compile(r"([a-zA-Z_]+):.*?(\d+\.\d+)"),
-            "tqdm": re.compile(r"(\d+%|\d+/\d+|\d+:\d+|\d+\.\d+[a-zA-Z/]+)")}
+            "tqdm": re.compile(r".*?(?P<pct>\d+%).*?(?P<itm>\d+/\d+)\W\["
+                               r"(?P<tme>\d+:\d+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]")}
+        logger.debug("Initialized %s", self.__class__.__name__)

     def execute_script(self, command, args):
         """ Execute the requested Faceswap Script """
+        logger.debug("Executing Faceswap: (command: '%s', args: %s)", command, args)
         self.command = command
         kwargs = {"stdout": PIPE,
                   "stderr": PIPE,
@@ -156,10 +151,12 @@ class FaceswapControl():
         self.process = Popen(args, **kwargs, stdin=PIPE)
         self.thread_stdout()
         self.thread_stderr()
+        logger.debug("Executed Faceswap")

     def read_stdout(self):
         """ Read stdout from the subprocess. If training, pass the loss
             values to Queue """
+        logger.debug("Opening stdout reader")
         while True:
             try:
                 output = self.process.stdout.readline()
@@ -173,14 +170,19 @@ class FaceswapControl():
                 if (self.command == "train" and self.capture_loss(output)) or (
                         self.command != "train" and self.capture_tqdm(output)):
                     continue
+                if self.command == "train" and output.strip().endswith("saved models"):
+                    logger.debug("Trigger update preview")
+                    self.wrapper.tk_vars["updatepreview"].set(True)
                 print(output.strip())
         returncode = self.process.poll()
         message = self.set_final_status(returncode)
         self.wrapper.terminate(message)
+        logger.debug("Terminated stdout reader. returncode: %s", returncode)

     def read_stderr(self):
-        """ Read stdout from the subprocess. If training, pass the loss
+        """ Read stderr from the subprocess. If training, pass the loss
             values to Queue """
+        logger.debug("Opening stderr reader")
         while True:
             try:
                 output = self.process.stderr.readline()
@@ -194,81 +196,125 @@ class FaceswapControl():
                 if self.command != "train" and self.capture_tqdm(output):
                     continue
                 print(output.strip(), file=sys.stderr)
+        logger.debug("Terminated stderr reader")

     def thread_stdout(self):
         """ Put the subprocess stdout so that it can be read without
             blocking """
+        logger.debug("Threading stdout")
         thread = Thread(target=self.read_stdout)
         thread.daemon = True
         thread.start()
+        logger.debug("Threaded stdout")

     def thread_stderr(self):
         """ Put the subprocess stderr so that it can be read without
             blocking """
+        logger.debug("Threading stderr")
         thread = Thread(target=self.read_stderr)
         thread.daemon = True
         thread.start()
+        logger.debug("Threaded stderr")
     def capture_loss(self, string):
         """ Capture loss values from stdout """
+        logger.trace("Capturing loss")
         if not str.startswith(string, "["):
+            logger.trace("Not loss message. Returning False")
             return False

         loss = self.consoleregex["loss"].findall(string)
         if len(loss) < 2:
+            logger.trace("Not loss message. Returning False")
             return False

-        self.wrapper.session.add_loss(loss)
-
         message = ""
         for item in loss:
             message += "{}: {} ".format(item[0], item[1])
         if not message:
+            logger.trace("Error creating loss message. Returning False")
             return False

-        elapsed = self.wrapper.session.timestats["elapsed"]
-        iterations = self.wrapper.session.stats["iterations"]
+        iterations = self.train_stats["iterations"]
+
+        if iterations == 0:
+            # Initialize session stats and set initial timestamp
+            self.train_stats["timestamp"] = time()
+
+        if not get_config().session.initialized and iterations > 0:
+            # Don't initialize session until after the first iteration as state
+            # file must exist first
+            get_config().session.initialize_session(is_training=True)
+            self.wrapper.tk_vars["refreshgraph"].set(True)
+
+        iterations += 1
+        if iterations % 100 == 0:
+            self.wrapper.tk_vars["refreshgraph"].set(True)
+        self.train_stats["iterations"] = iterations
+
+        elapsed = self.calc_elapsed()
         message = "Elapsed: {} Iteration: {} {}".format(elapsed,
-                                                        iterations,
-                                                        message)
+                                                        self.train_stats["iterations"], message)
         self.statusbar.progress_update(message, 0, False)
+        logger.trace("Successfully captured loss: %s", message)
         return True

+    def calc_elapsed(self):
+        """ Calculate and format time since training started """
+        now = time()
+        elapsed_time = now - self.train_stats["timestamp"]
+        try:
+            hrs = int(elapsed_time // 3600)
+            if hrs < 10:
+                hrs = "{0:02d}".format(hrs)
+            mins = "{0:02d}".format((int(elapsed_time % 3600) // 60))
+            secs = "{0:02d}".format((int(elapsed_time % 3600) % 60))
+        except ZeroDivisionError:
+            hrs = "00"
+            mins = "00"
+            secs = "00"
+        return "{}:{}:{}".format(hrs, mins, secs)
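
calc_elapsed formats a duration in seconds as zero-padded HH:MM:SS. A worked example with a hypothetical 3725 seconds:

elapsed_time = 3725.0  # hypothetical seconds since the stored timestamp
hrs = "{0:02d}".format(int(elapsed_time // 3600))
mins = "{0:02d}".format(int(elapsed_time % 3600) // 60)
secs = "{0:02d}".format(int(elapsed_time % 3600) % 60)
print("{}:{}:{}".format(hrs, mins, secs))  # 01:02:05
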
     def capture_tqdm(self, string):
         """ Capture tqdm output for progress bar """
-        tqdm = self.consoleregex["tqdm"].findall(string)
-        if len(tqdm) != 5:
+        logger.trace("Capturing tqdm")
+        tqdm = self.consoleregex["tqdm"].match(string)
+        if not tqdm:
             return False
-        percent = tqdm[0]
-        processed = tqdm[1]
-        processtime = "Elapsed: {} Remaining: {}".format(tqdm[2], tqdm[3])
-        rate = tqdm[4]
+        tqdm = tqdm.groupdict()
+        if any("?" in val for val in tqdm.values()):
+            logger.trace("tqdm initializing. Skipping")
+            return True
+        processtime = "Elapsed: {} Remaining: {}".format(tqdm["tme"].split("<")[0],
+                                                         tqdm["tme"].split("<")[1])
         message = "{} | {} | {} | {}".format(processtime,
-                                             rate,
-                                             processed,
-                                             percent)
+                                             tqdm["rte"],
+                                             tqdm["itm"],
+                                             tqdm["pct"])

-        current, total = processed.split("/")
+        current, total = tqdm["itm"].split("/")
         position = int((float(current) / float(total)) * 1000)

         self.statusbar.progress_update(message, position, True)
+        logger.trace("Successfully captured tqdm message: %s", message)
         return True
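
The reworked tqdm regex captures percentage, item count, elapsed<remaining time and rate as named groups. Checking it against a representative tqdm progress line (sample input for illustration):

import re

TQDM = re.compile(r".*?(?P<pct>\d+%).*?(?P<itm>\d+/\d+)\W\["
                  r"(?P<tme>\d+:\d+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]")

line = "100%|##########| 1000/1000 [00:06<00:00, 157.47it/s]"
groups = TQDM.match(line).groupdict()
print(groups["pct"], groups["itm"], groups["tme"], groups["rte"])
# 100% 1000/1000 00:06<00:00 157.47it/s
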
     def terminate(self):
         """ Terminate the subprocess """
-        if self.command != "train":
+        logger.debug("Terminating wrapper")
+        if self.command == "train":
+            logger.debug("Sending Exit Signal")
             print("Sending Exit Signal", flush=True)
             try:
                 now = time()
                 if os.name == "nt":
                     try:
+                        logger.debug("Sending carriage return to process")
                         self.process.communicate(input="\n", timeout=60)
                     except TimeoutExpired:
                         raise ValueError("Timeout reached sending Exit Signal")
                 else:
+                    logger.debug("Sending SIGINT to process")
                     self.process.send_signal(signal.SIGINT)
                     while True:
                         timeelapsed = time() - now
@@ -278,30 +324,37 @@ class FaceswapControl():
                     raise ValueError("Timeout reached sending Exit Signal")
                 return
             except ValueError as err:
+                logger.error("Error terminating process", exc_info=True)
                 print(err)
         else:
+            logger.debug("Terminating Process...")
             print("Terminating Process...")
             children = psutil.Process().children(recursive=True)
             for child in children:
                 child.terminate()
             _, alive = psutil.wait_procs(children, timeout=10)
             if not alive:
+                logger.debug("Terminated")
                 print("Terminated")
                 return

+            logger.debug("Termination timed out. Killing Process...")
             print("Termination timed out. Killing Process...")
             for child in alive:
                 child.kill()
             _, alive = psutil.wait_procs(alive, timeout=10)
             if not alive:
+                logger.debug("Killed")
                 print("Killed")
             else:
                 for child in alive:
-                    print("Process {} survived SIGKILL. "
-                          "Giving up".format(child))
+                    msg = "Process {} survived SIGKILL. Giving up".format(child)
+                    logger.debug(msg)
+                    print(msg)
def set_final_status(self, returncode):
|
def set_final_status(self, returncode):
|
||||||
""" Set the status bar output based on subprocess return code """
|
""" Set the status bar output based on subprocess return code """
|
||||||
|
logger.debug("Setting final status. returncode: %s", returncode)
|
||||||
if returncode in (0, 3221225786):
|
if returncode in (0, 3221225786):
|
||||||
status = "Ready"
|
status = "Ready"
|
||||||
elif returncode == -15:
|
elif returncode == -15:
|
||||||
|
@ -311,6 +364,6 @@ class FaceswapControl():
|
||||||
elif returncode == -6:
|
elif returncode == -6:
|
||||||
status = "Aborted - {}.py".format(self.command)
|
status = "Aborted - {}.py".format(self.command)
|
||||||
else:
|
else:
|
||||||
status = "Failed - {}.py. Return Code: {}".format(self.command,
|
status = "Failed - {}.py. Return Code: {}".format(self.command, returncode)
|
||||||
returncode)
|
logger.debug("Set final status: %s", status)
|
||||||
return status
|
return status
|
||||||
|
|
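
A side note on the "ready" codes above, as a sketch (not part of the patch): 3221225786 is 0xC000013A, the NTSTATUS code Windows reports when a console process exits via Ctrl+C, so it is grouped with a clean exit, while POSIX processes killed by a signal surface as negative return codes:

    # Sketch: where the magic return codes come from
    import signal
    assert 3221225786 == 0xC000013A  # STATUS_CONTROL_C_EXIT on Windows (Ctrl+C)
    assert -15 == -signal.SIGTERM    # POSIX: terminated by SIGTERM
    assert -6 == -signal.SIGABRT     # POSIX: aborted
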
92  lib/keypress.py  Normal file
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Source: http://home.wlu.edu/~levys/software/kbhit.py
A Python class implementing KBHIT, the standard keyboard-interrupt poller.
Works transparently on Windows and Posix (Linux, Mac OS X). Doesn't work
with IDLE.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
"""

import os

# Windows
if os.name == "nt":
    import msvcrt  # pylint: disable=import-error

# Posix (Linux, OS X)
else:
    import sys
    import termios
    import atexit
    from select import select


class KBHit:
    """ Creates a KBHit object that you can call to do various keyboard things. """
    def __init__(self, is_gui=False):
        self.is_gui = is_gui
        if os.name == "nt" or self.is_gui:
            pass
        else:
            # Save the terminal settings
            self.file_desc = sys.stdin.fileno()
            self.new_term = termios.tcgetattr(self.file_desc)
            self.old_term = termios.tcgetattr(self.file_desc)

            # New terminal setting unbuffered
            self.new_term[3] = (self.new_term[3] & ~termios.ICANON & ~termios.ECHO)
            termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.new_term)

            # Support normal-terminal reset at exit
            atexit.register(self.set_normal_term)

    def set_normal_term(self):
        """ Resets to normal terminal. On Windows this is a no-op. """
        if os.name == "nt" or self.is_gui:
            pass
        else:
            termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.old_term)

    @staticmethod
    def getch():
        """ Returns a keyboard character after kbhit() has been called.
            Should not be called in the same program as getarrow(). """
        if os.name == "nt":
            return msvcrt.getch().decode("utf-8")
        return sys.stdin.read(1)

    @staticmethod
    def getarrow():
        """ Returns an arrow-key code after kbhit() has been called. Codes are
        0 : up
        1 : right
        2 : down
        3 : left
        Should not be called in the same program as getch(). """

        if os.name == "nt":
            msvcrt.getch()  # skip 0xE0
            char = msvcrt.getch().decode("utf-8")  # msvcrt returns bytes; decode before ord()
            vals = [72, 77, 80, 75]
        else:
            char = sys.stdin.read(3)[2]  # stdin.read already returns str
            vals = [65, 67, 66, 68]

        return vals.index(ord(char))

    @staticmethod
    def kbhit():
        """ Returns True if keyboard character was hit, False otherwise. """
        if os.name == "nt":
            return msvcrt.kbhit()
        d_r, _, _ = select([sys.stdin], [], [], 0)
        return d_r != []
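
For context, a minimal polling loop built on the class above (a sketch, not part of the diff; the import path assumes the file lands at lib/keypress.py as shown):

    # Sketch: non-blocking keyboard poll; press ESC to stop
    from lib.keypress import KBHit

    kbd = KBHit(is_gui=False)
    try:
        while True:
            if kbd.kbhit():
                if ord(kbd.getch()) == 27:  # ESC
                    break
    finally:
        kbd.set_normal_term()
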
@@ -44,9 +44,15 @@ class MultiProcessingLogger(logging.Logger):


class FaceswapFormatter(logging.Formatter):
    """ Override formatter to strip newlines and multiple spaces from logger

        Messages that begin with "R|" should be handled as is
    """
    def format(self, record):
        if record.msg.startswith("R|"):
            record.msg = record.msg[2:]
            record.strip_spaces = False
        elif record.strip_spaces:
            record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r"))
        return super().format(record)


@@ -92,7 +98,7 @@ def file_handler(loglevel, logfile, log_format, command):
        filename = logfile
    else:
        filename = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap")
    # Windows has issues sharing the log file with subprocesses, so log GUI separately
    filename += "_gui.log" if command == "gui" else ".log"

    should_rotate = os.path.isfile(filename)

@@ -152,6 +158,18 @@ def crash_log():
    return filename


# Add a flag to logging.LogRecord to not strip formatting from particular records
old_factory = logging.getLogRecordFactory()


def faceswap_logrecord(*args, **kwargs):
    record = old_factory(*args, **kwargs)
    record.strip_spaces = True
    return record


logging.setLogRecordFactory(faceswap_logrecord)

# Set logger class to custom logger
logging.setLoggerClass(MultiProcessingLogger)
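
To illustrate the new "R|" convention (a sketch assuming a handler that uses FaceswapFormatter together with the record factory above): messages prefixed with "R|" keep their formatting, everything else is flattened onto one line:

    # Sketch: the R| prefix opts a record out of whitespace stripping
    logger.info("R|col one\ncol two")    # newlines preserved, "R|" stripped
    logger.info("plain\nmessage")        # logged as "plain\\nmessage"
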
0   lib/model/__init__.py  Normal file
81  lib/model/initializers.py  Normal file
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
""" Custom Initializers for faceswap.py
    Initializers from:
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""

import sys
import inspect
import tensorflow as tf
from keras import initializers
from keras.utils.generic_utils import get_custom_objects


def icnr_keras(shape, dtype=None):
    """
    Custom initializer for subpix upscaling
    From https://github.com/kostyaev/ICNR
    Note: upscale factor is fixed to 2, and the base initializer is fixed to random normal.
    """
    # TODO Roll this into ICNR_init when porting GAN 2.2
    shape = list(shape)
    scale = 2
    initializer = tf.keras.initializers.RandomNormal(0, 0.02)

    new_shape = shape[:3] + [int(shape[3] / (scale ** 2))]
    var_x = initializer(new_shape, dtype)
    var_x = tf.transpose(var_x, perm=[2, 0, 1, 3])
    var_x = tf.image.resize_nearest_neighbor(var_x, size=(shape[0] * scale, shape[1] * scale))
    var_x = tf.space_to_depth(var_x, block_size=scale)
    var_x = tf.transpose(var_x, perm=[1, 2, 0, 3])
    return var_x


class ICNR(initializers.Initializer):  # pylint: disable=invalid-name
    '''
    ICNR initializer for checkerboard artifact free sub pixel convolution

    Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/pdf/1707.02937.pdf https://distill.pub/2016/deconv-checkerboard/

    Parameters:
        initializer: initializer used for sub kernels (orthogonal, glorot uniform, etc.)
        scale: scale factor of sub pixel convolution (upsampling from 8x8 to 16x16 is scale 2)
    Return:
        The modified kernel weights
    Example:
        x = conv2d(... weights_initializer=ICNR(initializer=he_uniform(), scale=2))
    '''

    def __init__(self, initializer, scale=2):
        self.scale = scale
        self.initializer = initializer

    def __call__(self, shape, dtype='float32'):  # tf needs partition_info=None
        shape = list(shape)
        if self.scale == 1:
            return self.initializer(shape)
        new_shape = shape[:3] + [shape[3] // (self.scale ** 2)]
        if type(self.initializer) is dict:
            self.initializer = initializers.deserialize(self.initializer)
        var_x = self.initializer(new_shape, dtype)
        var_x = tf.transpose(var_x, perm=[2, 0, 1, 3])
        var_x = tf.image.resize_nearest_neighbor(
            var_x,
            size=(shape[0] * self.scale, shape[1] * self.scale),
            align_corners=True)
        var_x = tf.space_to_depth(var_x, block_size=self.scale, data_format='NHWC')
        var_x = tf.transpose(var_x, perm=[1, 2, 0, 3])
        return var_x

    def get_config(self):
        config = {'scale': self.scale,
                  'initializer': self.initializer
                 }
        base_config = super(ICNR, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Update initializers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
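
Beyond the docstring's one-liner, here is a fuller Keras sketch (he_uniform is one possible base initializer, not mandated by the patch):

    # Sketch: checkerboard-free 2x sub-pixel upscale convolution
    from keras.initializers import he_uniform
    from keras.layers import Conv2D

    conv = Conv2D(256, kernel_size=3, padding="same",
                  kernel_initializer=ICNR(initializer=he_uniform(), scale=2))
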
338  lib/model/layers.py  Normal file
@@ -0,0 +1,338 @@
#!/usr/bin/env python3
""" Custom Layers for faceswap.py
    Layers from:
        the original https://www.reddit.com/r/deepfakes/ code sample + contribs
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""

from __future__ import absolute_import

import sys
import inspect

import tensorflow as tf
import keras.backend as K

from keras.engine import InputSpec, Layer
from keras.utils import conv_utils
from keras.utils.generic_utils import get_custom_objects
from keras import initializers
from keras.layers import ZeroPadding2D


class PixelShuffler(Layer):
    """ PixelShuffler layer for Keras
        by t-ae: https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 """
    # pylint: disable=C0103
    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        self.data_format = K.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    def call(self, inputs, **kwargs):

        input_shape = K.int_shape(inputs)
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape:', str(input_shape))

        if self.data_format == 'channels_first':
            batch_size, c, h, w = input_shape
            if batch_size is None:
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            out = K.reshape(out, (batch_size, oc, oh, ow))
        elif self.data_format == 'channels_last':
            batch_size, h, w, c = input_shape
            if batch_size is None:
                batch_size = -1
            rh, rw = self.size
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
            out = K.reshape(out, (batch_size, oh, ow, oc))
        return out

    def compute_output_shape(self, input_shape):

        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape:', str(input_shape))

        if self.data_format == 'channels_first':
            height = None
            width = None
            if input_shape[2] is not None:
                height = input_shape[2] * self.size[0]
            if input_shape[3] is not None:
                width = input_shape[3] * self.size[1]
            channels = input_shape[1] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[1]:
                raise ValueError('channels of input and size are incompatible')

            retval = (input_shape[0],
                      channels,
                      height,
                      width)
        elif self.data_format == 'channels_last':
            height = None
            width = None
            if input_shape[1] is not None:
                height = input_shape[1] * self.size[0]
            if input_shape[2] is not None:
                width = input_shape[2] * self.size[1]
            channels = input_shape[3] // self.size[0] // self.size[1]

            if channels * self.size[0] * self.size[1] != input_shape[3]:
                raise ValueError('channels of input and size are incompatible')

            retval = (input_shape[0],
                      height,
                      width,
                      channels)
        return retval

    def get_config(self):
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(PixelShuffler, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))
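
The reshape/permute/reshape in PixelShuffler is easier to follow on concrete numbers. A numpy sketch of the channels_last branch for a 2x2 shuffle (illustration only):

    # Sketch: (1, 2, 2, 4) -> (1, 4, 4, 1), mirroring the channels_last path
    import numpy as np
    h = w = rh = rw = 2
    x = np.arange(h * w * rh * rw).reshape(1, h, w, rh * rw)
    out = x.reshape(1, h, w, rh, rw, 1)          # split channels into the shuffle grid
    out = out.transpose(0, 1, 3, 2, 4, 5)        # interleave rows with columns
    out = out.reshape(1, h * rh, w * rw, 1)
    assert out.shape == (1, 4, 4, 1)
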
class Scale(Layer):
    """
    GAN Custom Scale Layer
    Code borrows from https://github.com/flyyufelix/cnn_finetune
    """
    def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):
        self.axis = axis
        self.gamma_init = initializers.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]

        # Compatibility with TensorFlow >= 1.0.0
        self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name))
        self.trainable_weights = [self.gamma]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        return self.gamma * x

    def get_config(self):
        config = {"axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class SubPixelUpscaling(Layer):
    # pylint: disable=C0103
    """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time
    Single Image and Video Super-Resolution Using an Efficient Sub-Pixel
    Convolutional Neural Network" (https://arxiv.org/abs/1609.05158).
    This layer requires a Convolution2D prior to it, having output filters
    computed according to the formula :
        filters = k * (scale_factor * scale_factor)
        where k = a user defined number of filters (generally larger than 32)
              scale_factor = the upscaling factor (generally 2)
    This layer performs the depth to space operation on the convolution
    filters, and returns a tensor with the size as defined below.
    # Example :
    ```python
        # A standard subpixel upscaling block
        x = Convolution2D(256, 3, 3, padding="same", activation="relu")(...)
        u = SubPixelUpscaling(scale_factor=2)(x)
        [Optional]
        x = Convolution2D(256, 3, 3, padding="same", activation="relu")(u)
    ```
    In practice, it is useful to have a second convolution layer after the
    SubPixelUpscaling layer to speed up the learning process.
    However, if you are stacking multiple SubPixelUpscaling blocks,
    it may increase the number of parameters greatly, so the Convolution
    layer after SubPixelUpscaling layer can be removed.
    # Arguments
        scale_factor: Upscaling factor.
        data_format: Can be None, "channels_first" or "channels_last".
    # Input shape
        4D tensor with shape:
        `(samples, k * (scale_factor * scale_factor) channels, rows, cols)`
        if data_format="channels_first"
        or 4D tensor with shape:
        `(samples, rows, cols, k * (scale_factor * scale_factor) channels)`
        if data_format="channels_last".
    # Output shape
        4D tensor with shape:
        `(samples, k channels, rows * scale_factor, cols * scale_factor))`
        if data_format="channels_first"
        or 4D tensor with shape:
        `(samples, rows * scale_factor, cols * scale_factor, k channels)`
        if data_format="channels_last".
    """

    def __init__(self, scale_factor=2, data_format=None, **kwargs):
        super(SubPixelUpscaling, self).__init__(**kwargs)

        self.scale_factor = scale_factor
        self.data_format = K.normalize_data_format(data_format)

    def build(self, input_shape):
        pass

    def call(self, x, mask=None):
        y = self.depth_to_space(x, self.scale_factor, self.data_format)
        return y

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_first":
            b, k, r, c = input_shape
            return (b,
                    k // (self.scale_factor ** 2),
                    r * self.scale_factor,
                    c * self.scale_factor)
        b, r, c, k = input_shape
        return (b,
                r * self.scale_factor,
                c * self.scale_factor,
                k // (self.scale_factor ** 2))

    @classmethod
    def depth_to_space(cls, ipt, scale, data_format=None):
        """ Uses phase shift algorithm to convert channels/depth
            for spatial resolution """
        if data_format is None:
            data_format = K.image_data_format()
        data_format = data_format.lower()
        ipt = cls._preprocess_conv2d_input(ipt, data_format)
        out = tf.depth_to_space(ipt, scale)
        out = cls._postprocess_conv2d_output(out, data_format)
        return out

    @staticmethod
    def _postprocess_conv2d_output(x, data_format):
        """Transpose and cast the output from conv2d if needed.
        # Arguments
            x: A tensor.
            data_format: string, `"channels_last"` or `"channels_first"`.
        # Returns
            A tensor.
        """

        if data_format == "channels_first":
            x = tf.transpose(x, (0, 3, 1, 2))

        if K.floatx() == "float64":
            x = tf.cast(x, "float64")
        return x

    @staticmethod
    def _preprocess_conv2d_input(x, data_format):
        """Transpose and cast the input before the conv2d.
        # Arguments
            x: input tensor.
            data_format: string, `"channels_last"` or `"channels_first"`.
        # Returns
            A tensor.
        """
        if K.dtype(x) == "float64":
            x = tf.cast(x, "float32")
        if data_format == "channels_first":
            # TF uses the last dimension as channel dimension,
            # instead of the 2nd one.
            # TH input shape: (samples, input_depth, rows, cols)
            # TF input shape: (samples, rows, cols, input_depth)
            x = tf.transpose(x, (0, 2, 3, 1))
        return x

    def get_config(self):
        config = {"scale_factor": self.scale_factor,
                  "data_format": self.data_format}
        base_config = super(SubPixelUpscaling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
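
SubPixelUpscaling performs the same channel-to-space rearrangement as PixelShuffler; this variant just delegates to TensorFlow's built-in op. A shape sketch:

    # Sketch: depth_to_space maps (b, h, w, c) -> (b, 2h, 2w, c // 4) at scale 2
    import tensorflow as tf
    var_x = tf.zeros((1, 8, 8, 64))
    var_y = tf.depth_to_space(var_x, 2)  # shape (1, 16, 16, 16)
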
class ReflectionPadding2D(Layer):
    def __init__(self, stride=2, kernel_size=5, **kwargs):
        '''
        # Arguments
            stride: stride of following convolution (2)
            kernel_size: kernel size of following convolution (5,5)
        '''
        self.stride = stride
        self.kernel_size = kernel_size
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        super(ReflectionPadding2D, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """ If you are using "channels_last" configuration """
        input_shape = self.input_spec[0].shape
        in_width, in_height = input_shape[2], input_shape[1]
        kernel_width, kernel_height = self.kernel_size, self.kernel_size

        if (in_height % self.stride) == 0:
            padding_height = max(kernel_height - self.stride, 0)
        else:
            padding_height = max(kernel_height - (in_height % self.stride), 0)
        if (in_width % self.stride) == 0:
            padding_width = max(kernel_width - self.stride, 0)
        else:
            padding_width = max(kernel_width - (in_width % self.stride), 0)

        return (input_shape[0],
                input_shape[1] + padding_height,
                input_shape[2] + padding_width,
                input_shape[3])

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        in_width, in_height = input_shape[2], input_shape[1]
        kernel_width, kernel_height = self.kernel_size, self.kernel_size

        if (in_height % self.stride) == 0:
            padding_height = max(kernel_height - self.stride, 0)
        else:
            padding_height = max(kernel_height - (in_height % self.stride), 0)
        if (in_width % self.stride) == 0:
            padding_width = max(kernel_width - self.stride, 0)
        else:
            padding_width = max(kernel_width - (in_width % self.stride), 0)

        padding_top = padding_height // 2
        padding_bot = padding_height - padding_top
        padding_left = padding_width // 2
        padding_right = padding_width - padding_left

        return tf.pad(x, [[0, 0],
                          [padding_top, padding_bot],
                          [padding_left, padding_right],
                          [0, 0]],
                      'REFLECT')

    def get_config(self):
        config = {'stride': self.stride,
                  'kernel_size': self.kernel_size}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Update layers into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
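
Putting the pieces together, the usual sub-pixel upscaling pattern is a convolution that multiplies the channel count by scale squared, followed by a shuffle. A sketch (filter counts are illustrative, not taken from the model plugins):

    # Sketch: conv + PixelShuffler as a 2x upscale block
    from keras.layers import Conv2D

    def upscale(inp, filters):
        var_x = Conv2D(filters * 4, kernel_size=3, padding="same")(inp)
        return PixelShuffler()(var_x)
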
844  lib/model/losses.py  Normal file
@@ -0,0 +1,844 @@
#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py
    Losses from:
        keras.contrib
        dfaker: https://github.com/dfaker/df
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""

from __future__ import absolute_import

import numpy  # used by scharr_edges below
import keras.backend as K
from keras.layers import Lambda, concatenate
import tensorflow as tf
from tensorflow.contrib.distributions import Beta

from .normalization import InstanceNormalization


class DSSIMObjective():
    """ DSSIM Loss Function

    Code copied and pasted, with minor amendments from:
    https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py

    MIT License

    Copyright (c) 2017 Fariz Rahman

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE. """
    # pylint: disable=C0103
    def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0):
        """
        Difference of Structural Similarity (DSSIM loss function). Clipped
        between 0 and 0.5
        Note : You should add a regularization term like a l2 loss in
               addition to this one.
        Note : In theano, the `kernel_size` must be a factor of the output
               size. So 3 could not be the `kernel_size` for an output of 32.
        # Arguments
            k1: Parameter of the SSIM (default 0.01)
            k2: Parameter of the SSIM (default 0.03)
            kernel_size: Size of the sliding window (default 3)
            max_value: Max value of the output (default 1.0)
        """
        self.__name__ = 'DSSIMObjective'
        self.kernel_size = kernel_size
        self.k1 = k1
        self.k2 = k2
        self.max_value = max_value
        self.c1 = (self.k1 * self.max_value) ** 2
        self.c2 = (self.k2 * self.max_value) ** 2
        self.dim_ordering = K.image_data_format()
        self.backend = K.backend()

    @staticmethod
    def __int_shape(x):
        return K.int_shape(x)

    def __call__(self, y_true, y_pred):
        # There are additional parameters for this function
        # Note: some of the 'modes' for edge behavior do not yet have a
        # gradient definition in the Theano tree and cannot be used for
        # learning

        kernel = [self.kernel_size, self.kernel_size]
        y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:]))
        y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:]))

        patches_pred = self.extract_image_patches(y_pred,
                                                  kernel,
                                                  kernel,
                                                  'valid',
                                                  self.dim_ordering)
        patches_true = self.extract_image_patches(y_true,
                                                  kernel,
                                                  kernel,
                                                  'valid',
                                                  self.dim_ordering)

        # Reshape to get the var in the cells
        _, w, h, c1, c2, c3 = self.__int_shape(patches_pred)
        patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3])
        patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3])
        # Get mean
        u_true = K.mean(patches_true, axis=-1)
        u_pred = K.mean(patches_pred, axis=-1)
        # Get variance
        var_true = K.var(patches_true, axis=-1)
        var_pred = K.var(patches_pred, axis=-1)
        # Get std dev
        covar_true_pred = K.mean(
            patches_true * patches_pred, axis=-1) - u_true * u_pred

        ssim = (2 * u_true * u_pred + self.c1) * (
            2 * covar_true_pred + self.c2)
        denom = (K.square(u_true) + K.square(u_pred) + self.c1) * (
            var_pred + var_true + self.c2)
        ssim /= denom  # no need for clipping, c1 + c2 make the denom non-zero
        return K.mean((1.0 - ssim) / 2.0)

    @staticmethod
    def _preprocess_padding(padding):
        """Convert keras' padding to tensorflow's padding.
        # Arguments
            padding: string, `"same"` or `"valid"`.
        # Returns
            a string, `"SAME"` or `"VALID"`.
        # Raises
            ValueError: if `padding` is invalid.
        """
        if padding == 'same':
            padding = 'SAME'
        elif padding == 'valid':
            padding = 'VALID'
        else:
            raise ValueError('Invalid padding:', padding)
        return padding

    def extract_image_patches(self, x, ksizes, ssizes, padding='same',
                              data_format='channels_last'):
        '''
        Extract the patches from an image
        # Parameters
            x : The input image
            ksizes : 2-d tuple with the kernel size
            ssizes : 2-d tuple with the strides size
            padding : 'same' or 'valid'
            data_format : 'channels_last' or 'channels_first'
        # Returns
            The (k_w, k_h) patches extracted
                TF ==> (batch_size, w, h, k_w, k_h, c)
                TH ==> (batch_size, w, h, c, k_w, k_h)
        '''
        kernel = [1, ksizes[0], ksizes[1], 1]
        strides = [1, ssizes[0], ssizes[1], 1]
        padding = self._preprocess_padding(padding)
        if data_format == 'channels_first':
            x = K.permute_dimensions(x, (0, 2, 3, 1))
        _, _, _, ch_i = K.int_shape(x)
        patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
                                           padding)
        # Reshaping to fit Theano
        _, w, h, ch = K.int_shape(patches)
        patches = tf.reshape(tf.transpose(tf.reshape(patches,
                                                     [-1, w, h,
                                                      tf.floordiv(ch, ch_i),
                                                      ch_i]),
                                          [0, 1, 2, 4, 3]),
                             [-1, w, h, ch_i, ksizes[0], ksizes[1]])
        if data_format == 'channels_last':
            patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
        return patches
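
Since the class implements __call__(y_true, y_pred), it drops into Keras like any loss function. A sketch (the model name is hypothetical):

    # Sketch: compile a model against the DSSIM objective
    autoencoder.compile(optimizer="adam", loss=DSSIMObjective())
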
# <<< START: from Dfaker >>> #
class PenalizedLoss():  # pylint: disable=too-few-public-methods
    """ Penalized Loss
        from: https://github.com/dfaker/df """
    def __init__(self, mask, loss_func, mask_prop=1.0):
        self.mask = mask
        self.loss_func = loss_func
        self.mask_prop = mask_prop
        self.mask_as_k_inv_prop = 1 - mask_prop

    def __call__(self, y_true, y_pred):
        # pylint: disable=invalid-name
        tro, tgo, tbo = tf.split(y_true, 3, 3)
        pro, pgo, pbo = tf.split(y_pred, 3, 3)

        tr = tro
        tg = tgo
        tb = tbo

        pr = pro
        pg = pgo
        pb = pbo
        m = self.mask

        m = m * self.mask_prop
        m += self.mask_as_k_inv_prop
        tr *= m
        tg *= m
        tb *= m

        pr *= m
        pg *= m
        pb *= m

        y = tf.concat([tr, tg, tb], 3)
        p = tf.concat([pr, pg, pb], 3)

        # yo = tf.stack([tro, tgo, tbo], 3)
        # po = tf.stack([pro, pgo, pbo], 3)

        return self.loss_func(y, p)
# <<< END: from Dfaker >>> #
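
PenalizedLoss curries a mask tensor into an ordinary loss function, weighting each colour channel by the mask before the wrapped loss is applied; this appears to be how the mask-enabled trainers emphasise facial pixels. A sketch (mask_input is a hypothetical mask tensor fed in alongside the images):

    # Sketch: weight an L1 loss by the face mask
    from keras.losses import mean_absolute_error

    masked_l1 = PenalizedLoss(mask_input, mean_absolute_error, mask_prop=1.0)
    model.compile(optimizer="adam", loss=masked_l1)
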
# <<< START: from Shaoanlu GAN >>> #
def first_order(var_x, axis=1):
    """ First Order Function from Shaoanlu GAN """
    img_nrows = var_x.shape[1]
    img_ncols = var_x.shape[2]
    if axis == 1:
        return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :])
    if axis == 2:
        return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :])
    return None


def calc_loss(pred, target, loss='l2'):
    """ Calculate Loss from Shaoanlu GAN """
    if loss.lower() == "l2":
        return K.mean(K.square(pred - target))
    if loss.lower() == "l1":
        return K.mean(K.abs(pred - target))
    if loss.lower() == "cross_entropy":
        return -K.mean(K.log(pred + K.epsilon()) * target +
                       K.log(1 - pred + K.epsilon()) * (1 - target))
    raise ValueError('Received an unknown loss type: {}.'.format(loss))


def cyclic_loss(net_g1, net_g2, real1):
    """ Cyclic Loss Function from Shaoanlu GAN """
    fake2 = net_g2(real1)[-1]  # fake2 ABGR
    fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2)  # fake2 BGR
    cyclic1 = net_g1(fake2)[-1]  # cyclic1 ABGR
    cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1)  # cyclic1 BGR
    loss = calc_loss(cyclic1, real1, loss='l1')
    return loss


def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
    """ Adversarial Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    if gan_training == "mixup_LSGAN":
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
        pred_fake = net_d(concatenate([fake, distorted]))
        pred_mixup = net_d(mixup)
        loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
        mixup2 = lam * concatenate([real,
                                    distorted]) + (1 - lam) * concatenate([fake_bgr,
                                                                           distorted])
        pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
        pred_mixup2 = net_d(mixup2)
        loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
        loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        real_pred = net_d(concatenate([real, distorted]))
        fake_pred = net_d(concatenate([fake, distorted]))
        loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred))) / 2
        loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred))) / 2
        loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))

        fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
        loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                  K.ones_like(fake_pred2))) / 2
        loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                  K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                                   K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                                   K.ones_like(fake_pred2))) / 2
    else:
        # NB: fixed from a plain string that was never formatted
        raise ValueError("Received an unknown GAN training method: {}".format(gan_training))
    return loss_d, loss_g


def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights):
    """ Reconstruction Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)

    loss_g = 0
    loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1")
    loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes * (fake_bgr - real)))

    for out in model_outputs[:-1]:
        out_size = out.get_shape().as_list()
        resized_real = tf.image.resize_images(real, out_size[1:3])
        loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1")
    return loss_g


def edge_loss(real, fake_abgr, mask_eyes, **weights):
    """ Edge Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)

    loss_g = 0
    loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1),
                                            first_order(real, axis=1), "l1")
    loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2),
                                            first_order(real, axis=2), "l1")
    shape_mask_eyes = mask_eyes.get_shape().as_list()
    resized_mask_eyes = tf.image.resize_images(mask_eyes,
                                               [shape_mask_eyes[1] - 1, shape_mask_eyes[2] - 1])
    loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
                                               (first_order(fake_bgr, axis=1) -
                                                first_order(real, axis=1))))
    loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
                                               (first_order(fake_bgr, axis=2) -
                                                first_order(real, axis=2))))
    return loss_g


def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights):
    """ Perceptual Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    def preprocess_vggface(var_x):
        var_x = (var_x + 1) / 2 * 255  # channel order: BGR
        var_x -= [91.4953, 103.8827, 131.0912]
        return var_x

    real_sz224 = tf.image.resize_images(real, [224, 224])
    real_sz224 = Lambda(preprocess_vggface)(real_sz224)
    dist = Beta(0.2, 0.2)
    lam = dist.sample()  # use mixup trick here to reduce forward pass from 2 times to 1.
    mixup = lam * fake_bgr + (1 - lam) * fake
    fake_sz224 = tf.image.resize_images(mixup, [224, 224])
    fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
    real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
    fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)

    # Apply instance norm on VGG(ResNet) features
    # From MUNIT https://github.com/NVlabs/MUNIT
    loss_g = 0

    def instnorm():
        return InstanceNormalization()

    loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7),
                                             instnorm()(real_feat7), "l2")
    loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28),
                                             instnorm()(real_feat28), "l2")
    loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55),
                                             instnorm()(real_feat55), "l2")
    loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112),
                                             instnorm()(real_feat112), "l2")
    return loss_g

# <<< END: from Shaoanlu GAN >>> #
def generalized_loss_function(y_true, y_pred, a=1.0, c=1.0/255.0):
    '''
    Generalized function used to return a large variety of mathematical loss functions.
    The primary benefit is a smooth, differentiable version of L1 loss.

    Barron, J. A More General Robust Loss Function
    https://arxiv.org/pdf/1701.03077.pdf

    Parameters:
        a: penalty factor. Larger numbers give larger weight to large deviations
        c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256)

    Return:
        A loss value from the results of function(y_pred - y_true)

    Example:
        a=1.0, x>>c, c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss
        a=1.999999 (lim as a->2), c=1.0/255.0 will give L2 / RMSE loss
    '''
    x = y_pred - y_true
    loss = (K.abs(2.0 - a) / a) * (K.pow(K.pow(x / c, 2.0) / K.abs(2.0 - a) + 1.0, a / 2.0) - 1.0)
    return K.mean(loss, axis=-1) * c


def staircase_loss(y_true, y_pred, a=16.0, c=1.0/255.0):
    """ Smooth approximation to a staircase (quantised L1) loss, with step width
        and height both equal to c; a controls the sharpness of the steps """
    h = c
    w = c
    x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0)
    loss = h * (K.tanh(a * ((x / w) - tf.floor(x / w) - 0.5)) / (2.0 * K.tanh(a / 2.0)) + 0.5 + tf.floor(x / w))
    loss += 1e-10
    return K.mean(loss, axis=-1)
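
The a -> 2 limit quoted in the docstring checks out on paper: with x = y_pred - y_true, (|2-a|/a) * (((x/c)^2 / |2-a| + 1)^(a/2) - 1) tends to (x/c)^2 / 2, a scaled L2 penalty. A numpy sketch of the same check (illustration only):

    # Sketch: the generalized loss approaches scaled squared error as a -> 2
    import numpy as np
    x, c, a = 0.1, 1.0 / 255.0, 1.999999
    lhs = (abs(2.0 - a) / a) * (((x / c) ** 2 / abs(2.0 - a) + 1.0) ** (a / 2.0) - 1.0)
    rhs = (x / c) ** 2 / 2.0
    assert abs(lhs - rhs) / rhs < 1e-3
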
def gradient_loss(y_true, y_pred):
    '''
    Calculates the first and second order gradient differences between pixels of an image
    in the x and y dimensions. These gradients are then compared between the ground truth
    and the predicted image and the difference is taken. The difference used is a smooth
    L1 norm (approximates MAE but is differentiable at zero). When used as a loss, its
    minimisation will drive predicted images towards the same level of sharpness /
    blurriness as the ground truth.

    TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration,
    Chengwu Lu & Hua Huang, 2014
    (http://downloads.hindawi.com/journals/mpe/2014/790547.pdf)

    Parameters:
        y_true: The ground truth frames at each scale
        y_pred: The predicted frames at each scale

    Return:
        The GD loss.
    '''

    assert 4 == K.ndim(y_true)
    # NB: the input resolution is currently hard-coded
    y_true.set_shape([None, 80, 80, 3])
    y_pred.set_shape([None, 80, 80, 3])
    TV_weight = 1.0
    TV2_weight = 1.0
    loss = 0.0

    def diff_x(X):
        Xleft = X[:, :, 1, :] - X[:, :, 0, :]
        Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2)
        Xright = X[:, :, -1, :] - X[:, :, -2, :]
        Xout = [Xleft] + Xinner + [Xright]
        Xout = tf.stack(Xout, axis=2)
        return Xout * 0.5

    def diff_y(X):
        Xtop = X[:, 1, :, :] - X[:, 0, :, :]
        Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1)
        Xbot = X[:, -1, :, :] - X[:, -2, :, :]
        Xout = [Xtop] + Xinner + [Xbot]
        Xout = tf.stack(Xout, axis=1)
        return Xout * 0.5

    def diff_xx(X):
        Xleft = X[:, :, 1, :] + X[:, :, 0, :]
        Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2)
        Xright = X[:, :, -1, :] + X[:, :, -2, :]
        Xout = [Xleft] + Xinner + [Xright]
        Xout = tf.stack(Xout, axis=2)
        return Xout - 2.0 * X

    def diff_yy(X):
        Xtop = X[:, 1, :, :] + X[:, 0, :, :]
        Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1)
        Xbot = X[:, -1, :, :] + X[:, -2, :, :]
        Xout = [Xtop] + Xinner + [Xbot]
        Xout = tf.stack(Xout, axis=1)
        return Xout - 2.0 * X

    def diff_xy(X):
        # Xout1
        top_left = X[:, 1, 1, :] + X[:, 0, 0, :]
        inner_left = tf.unstack(X[:, 2:, 1, :] + X[:, :-2, 0, :], axis=1)
        bot_left = X[:, -1, 1, :] + X[:, -2, 0, :]
        X_left = [top_left] + inner_left + [bot_left]
        X_left = tf.stack(X_left, axis=1)

        top_mid = X[:, 1, 2:, :] + X[:, 0, :-2, :]
        mid_mid = tf.unstack(X[:, 2:, 2:, :] + X[:, :-2, :-2, :], axis=1)
        bot_mid = X[:, -1, 2:, :] + X[:, -2, :-2, :]
        X_mid = [top_mid] + mid_mid + [bot_mid]
        X_mid = tf.stack(X_mid, axis=1)

        top_right = X[:, 1, -1, :] + X[:, 0, -2, :]
        inner_right = tf.unstack(X[:, 2:, -1, :] + X[:, :-2, -2, :], axis=1)
        bot_right = X[:, -1, -1, :] + X[:, -2, -2, :]
        X_right = [top_right] + inner_right + [bot_right]
        X_right = tf.stack(X_right, axis=1)

        X_mid = tf.unstack(X_mid, axis=2)
        Xout1 = [X_left] + X_mid + [X_right]
        Xout1 = tf.stack(Xout1, axis=2)

        # Xout2
        top_left = X[:, 0, 1, :] + X[:, 1, 0, :]
        inner_left = tf.unstack(X[:, :-2, 1, :] + X[:, 2:, 0, :], axis=1)
        bot_left = X[:, -2, 1, :] + X[:, -1, 0, :]
        X_left = [top_left] + inner_left + [bot_left]
        X_left = tf.stack(X_left, axis=1)

        top_mid = X[:, 0, 2:, :] + X[:, 1, :-2, :]
        mid_mid = tf.unstack(X[:, :-2, 2:, :] + X[:, 2:, :-2, :], axis=1)
        bot_mid = X[:, -2, 2:, :] + X[:, -1, :-2, :]
        X_mid = [top_mid] + mid_mid + [bot_mid]
        X_mid = tf.stack(X_mid, axis=1)

        top_right = X[:, 0, -1, :] + X[:, 1, -2, :]
        inner_right = tf.unstack(X[:, :-2, -1, :] + X[:, 2:, -2, :], axis=1)
        bot_right = X[:, -2, -1, :] + X[:, -1, -2, :]
        X_right = [top_right] + inner_right + [bot_right]
        X_right = tf.stack(X_right, axis=1)

        X_mid = tf.unstack(X_mid, axis=2)
        Xout2 = [X_left] + X_mid + [X_right]
        Xout2 = tf.stack(Xout2, axis=2)

        return (Xout1 - Xout2) * 0.25

    loss += TV_weight * (generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) +
                         generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999))

    loss += TV2_weight * (generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) +
                          generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) +
                          2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred), a=1.999999))

    return loss / (TV_weight + TV2_weight)
def scharr_edges(image, magnitude):
    '''
    Returns a tensor holding modified Scharr edge maps.
    Arguments:
        image: Image tensor with shape [batch_size, h, w, d] and type float32.
            The image(s) must be 2x2 or larger.
        magnitude: Boolean to determine if the edge magnitude or edge direction is returned
    Returns:
        Tensor holding edge maps for each channel. Returns a tensor with shape
        [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]],
        [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter.
    '''

    # Define vertical and horizontal Scharr filters.
    static_image_shape = image.get_shape()
    image_shape = tf.shape(image)
    '''
    # modified 3x3 Scharr
    kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]],
               [[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]]
    '''
    # 5x5 Scharr
    kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0],
                [-1.0, -2.0, -6.0, -2.0, -1.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 2.0, 6.0, 2.0, 1.0],
                [1.0, 2.0, 3.0, 2.0, 1.0]],
               [[-1.0, -1.0, 0.0, 1.0, 1.0],
                [-2.0, -2.0, 0.0, 2.0, 2.0],
                [-3.0, -6.0, 0.0, 6.0, 3.0],
                [-2.0, -2.0, 0.0, 2.0, 2.0],
                [-1.0, -1.0, 0.0, 1.0, 1.0]]]
    num_kernels = len(kernels)
    kernels = numpy.transpose(numpy.asarray(kernels), (1, 2, 0))
    kernels = numpy.expand_dims(kernels, -2) / numpy.sum(numpy.abs(kernels))
    kernels_tf = tf.constant(kernels, dtype=image.dtype)
    kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters')

    # Use depth-wise convolution to calculate edge maps per channel.
    pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
    padded = tf.pad(image, pad_sizes, mode='REFLECT')

    # Output tensor has shape [batch_size, h, w, d * num_kernels].
    strides = [1, 1, 1, 1]
    output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID')

    # Reshape to [batch_size, h, w, d, num_kernels].
    shape = tf.concat([image_shape, [num_kernels]], 0)
    output = tf.reshape(output, shape=shape)
    output.set_shape(static_image_shape.concatenate([num_kernels]))

    if magnitude:  # magnitude of edges
        output = tf.sqrt(tf.reduce_sum(tf.square(output), axis=-1))
    else:  # direction of edges
        # NB: fixed from tf.div(output[..., 0] / output[..., 1]), which passed a single argument
        output = tf.atan(tf.squeeze(tf.div(output[:, :, :, :, 0], output[:, :, :, :, 1])))

    return output
def gmsd_loss(y_true, y_pred):
    '''
    Improved image quality metric over MS-SSIM with easier calc
    http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm
    https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf
    '''
    true_edge_mag = scharr_edges(y_true, True)
    pred_edge_mag = scharr_edges(y_pred, True)
    c = 0.002
    upper = 2.0 * tf.multiply(true_edge_mag, pred_edge_mag) + c
    lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c
    GMS = tf.div(upper, lower)
    _mean, _var = tf.nn.moments(GMS, axes=[1, 2], keep_dims=True)
    GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1)  # single metric value per image in tensor [?, 1, 1]
    # need to expand to [?, height, width] dimensions for Keras ... modify to not be hard-coded
    return K.tile(GMSD, [1, 64, 64])
def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)):
|
||||||
|
'''
|
||||||
|
Computes the MS-SSIM between img1 and img2.
|
||||||
|
This function assumes that `img1` and `img2` are image batches, i.e. the last
|
||||||
|
three dimensions are [height, width, channels].
|
||||||
|
Note: The true SSIM is only defined on grayscale. This function does not
|
||||||
|
perform any colorspace transform. (If input is already YUV, then it will
|
||||||
|
compute YUV SSIM average.)
|
||||||
|
Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
|
||||||
|
structural similarity for image quality assessment." Signals, Systems and
|
||||||
|
Computers, 2004.
|
||||||
|
Arguments:
|
||||||
|
img1: First image batch.
|
||||||
|
img2: Second image batch. Must have the same rank as img1.
|
||||||
|
max_val: The dynamic range of the images (i.e., the difference between the
|
||||||
|
maximum the and minimum allowed values).
|
||||||
|
power_factors: Iterable of weights for each of the scales. The number of
|
||||||
|
scales used is the length of the list. Index 0 is the unscaled
|
||||||
|
resolution's weight and each increasing scale corresponds to the image
|
||||||
|
being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363,
|
||||||
|
0.1333), which are the values obtained in the original paper.
|
||||||
|
Returns:
|
||||||
|
A tensor containing an MS-SSIM value for each image in batch. The values
|
||||||
|
are in range [0, 1]. Returns a tensor with shape:
|
||||||
|
broadcast(img1.shape[:-3], img2.shape[:-3]).
|
||||||
|
'''
|
||||||
|
|
||||||
|
def _verify_compatible_image_shapes(img1, img2):
|
||||||
|
'''
|
||||||
|
Checks if two image tensors are compatible for applying SSIM or PSNR.
|
||||||
|
This function checks if two sets of images have ranks at least 3, and if the
|
||||||
|
last three dimensions match.
|
||||||
|
Args:
|
||||||
|
img1: Tensor containing the first image batch.
|
||||||
|
img2: Tensor containing the second image batch.
|
||||||
|
Returns:
|
||||||
|
A tuple containing: the first tensor shape, the second tensor shape, and a
|
||||||
|
list of control_flow_ops.Assert() ops implementing the checks.
|
||||||
|
Raises:
|
||||||
|
ValueError: When static shape check fails.
|
||||||
|
'''
|
||||||
|
shape1 = img1.get_shape().with_rank_at_least(3)
|
||||||
|
shape2 = img2.get_shape().with_rank_at_least(3)
|
||||||
|
shape1[-3:].assert_is_compatible_with(shape2[-3:])
|
||||||
|
|
||||||
|
if shape1.ndims is not None and shape2.ndims is not None:
|
||||||
|
for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])):
|
||||||
|
if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
|
||||||
|
raise ValueError('Two images are not compatible: %s and %s' % (shape1, shape2))
|
||||||
|
|
||||||
|
# Now assign shape tensors.
|
||||||
|
shape1, shape2 = tf.shape_n([img1, img2])
|
||||||
|
|
||||||
|
# TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable.
|
||||||
|
checks = []
|
||||||
|
checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),[shape1, shape2], summarize=10))
|
||||||
|
checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),[shape1, shape2], summarize=10))
|
||||||
|
|
||||||
|
return shape1, shape2, checks
|
||||||
|
|
||||||
|
def _ssim_per_channel(img1, img2, max_val=1.0):
|
||||||
|
'''
|
||||||
|
Computes SSIM index between img1 and img2 per color channel.
|
||||||
|
This function matches the standard SSIM implementation from:
|
||||||
|
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
|
||||||
|
quality assessment: from error visibility to structural similarity. IEEE
|
||||||
|
transactions on image processing.
|
||||||
|
Details:
|
||||||
|
- 11x11 Gaussian filter of width 1.5 is used.
|
||||||
|
- k1 = 0.01, k2 = 0.03 as in the original paper.
|
||||||
|
Args:
|
||||||
|
img1: First image batch.
|
||||||
|
img2: Second image batch.
|
||||||
|
max_val: The dynamic range of the images (i.e., the difference between the
|
||||||
|
maximum the and minimum allowed values).
|
||||||
|
Returns:
|
||||||
|
A pair of tensors containing and channel-wise SSIM and contrast-structure
|
||||||
|
values. The shape is [..., channels].
|
||||||
|
'''
|
||||||
|
|
||||||
|
def _fspecial_gauss(size, sigma):
|
||||||
|
'''
|
||||||
|
Function to mimic the 'fspecial' gaussian MATLAB function.
|
||||||
|
'''
|
||||||
|
size = tf.convert_to_tensor(size, 'int32')
|
||||||
|
sigma = tf.convert_to_tensor(sigma)
|
||||||
|
|
||||||
|
coords = tf.cast(tf.range(size), sigma.dtype)
|
||||||
|
coords -= tf.cast(size - 1, sigma.dtype) / 2.0
|
||||||
|
|
||||||
|
g = tf.square(coords)
|
||||||
|
g *= -0.5 / tf.square(sigma)
|
||||||
|
|
||||||
|
g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
|
||||||
|
g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax().
|
||||||
|
g = tf.nn.softmax(g)
|
||||||
|
return tf.reshape(g, shape=[size, size, 1, 1])
|
||||||
|
|
||||||
|
        def _ssim_helper(x, y, max_val, kernel, compensation=1.0):
            '''
            Helper function for computing SSIM.

            SSIM estimates covariances with weighted sums. The default parameters
            use a biased estimate of the covariance:
            Suppose `reducer` is a weighted sum, then the mean estimators are
                \mu_x = \sum_i w_i x_i,
                \mu_y = \sum_i w_i y_i,
            where the w_i's are the weighted-sum weights, and the covariance
            estimator is
                cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
            with the assumption \sum_i w_i = 1. This covariance estimator is
            biased, since
                E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
            For an SSIM measure with unbiased covariance estimators, pass
            (1 - \sum_i w_i ^ 2) as the `compensation` argument.

            Arguments:
                x: First set of images.
                y: Second set of images.
                kernel: Weighted-sum kernel used by the internal `reducer`. For the
                    non-convolutional version the reducer is usually
                    tf.reduce_mean(x, [1, 2]); for the convolutional version it is
                    usually tf.nn.avg_pool or tf.nn.conv2d with a weighted-sum
                    kernel.
                max_val: The dynamic range (i.e., the difference between the
                    maximum possible allowed value and the minimum allowed value).
                compensation: Compensation factor. See above.

            Returns:
                A pair containing the luminance measure and the contrast-structure
                measure.
            '''
            def reducer(x, kernel):
                shape = tf.shape(x)
                x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0))
                y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
                                           padding='VALID')
                return tf.reshape(y, tf.concat([shape[:-3], tf.shape(y)[1:]], 0))

            _SSIM_K1 = 0.01
            _SSIM_K2 = 0.03

            c1 = (_SSIM_K1 * max_val) ** 2
            c2 = (_SSIM_K2 * max_val) ** 2

            # SSIM luminance measure is
            # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
            mean0 = reducer(x, kernel)
            mean1 = reducer(y, kernel)
            num0 = mean0 * mean1 * 2.0
            den0 = tf.square(mean0) + tf.square(mean1)
            luminance = (num0 + c1) / (den0 + c1)

            # SSIM contrast-structure measure is
            # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
            # Note that `reducer` is a weighted sum with weights w_i, \sum_i w_i = 1, so
            # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
            #          = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
            num1 = reducer(x * y, kernel) * 2.0
            den1 = reducer(tf.square(x) + tf.square(y), kernel)
            c2 *= compensation
            cs = (num1 - num0 + c2) / (den1 - den0 + c2)

            # SSIM score is the product of the luminance and contrast-structure measures.
            return luminance, cs
        filter_size = tf.constant(9, dtype='int32')  # changed from 11 to 9 due
        filter_sigma = tf.constant(1.5, dtype=img1.dtype)

        shape1, shape2 = tf.shape_n([img1, img2])
        checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),
                            [shape1, filter_size], summarize=8),
                  tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),
                            [shape2, filter_size], summarize=8)]

        # Enforce the checks to run before the computation.
        with tf.control_dependencies(checks):
            img1 = tf.identity(img1)

        # TODO(sjhwang): Try to cache kernels and compensation factor.
        kernel = _fspecial_gauss(filter_size, filter_sigma)
        kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1])

        # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
        # but to match the MATLAB implementation of MS-SSIM, we use 1.0 instead.
        compensation = 1.0

        # TODO(sjhwang): Try FFT.
        # TODO(sjhwang): The Gaussian kernel is separable in space. Consider applying
        #                1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.

        luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation)

        # Average over the second and the third from the last: height, width.
        axes = tf.constant([-3, -2], dtype='int32')
        ssim_val = tf.reduce_mean(luminance * cs, axes)
        cs = tf.reduce_mean(cs, axes)
        return ssim_val, cs
    def do_pad(images, remainder):
        padding = tf.expand_dims(remainder, -1)
        padding = tf.pad(padding, [[1, 0], [1, 0]])
        return [tf.pad(x, padding, mode='SYMMETRIC') for x in images]

    # Shape checking.
    shape1 = img1.get_shape().with_rank_at_least(3)
    shape2 = img2.get_shape().with_rank_at_least(3)
    shape1[-3:].merge_with(shape2[-3:])
    with tf.name_scope(None, 'MS-SSIM', [img1, img2]):
        shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
        with tf.control_dependencies(checks):
            img1 = tf.identity(img1)

        # Need to convert the images to float32. Scale max_val accordingly so that
        # SSIM is computed correctly.
        max_val = tf.cast(max_val, img1.dtype)
        max_val = tf.image.convert_image_dtype(max_val, 'float32')
        img1 = tf.image.convert_image_dtype(img1, 'float32')
        img2 = tf.image.convert_image_dtype(img2, 'float32')

        imgs = [img1, img2]
        shapes = [shape1, shape2]

        # img1 and img2 are assumed to be a (multi-dimensional) batch of
        # 3-dimensional images (height, width, channels). `heads` contain the
        # batch dimensions, and `tails` contain the image dimensions.
        heads = [s[:-3] for s in shapes]
        tails = [s[-3:] for s in shapes]

        divisor = [1, 2, 2, 1]
        divisor_tensor = tf.constant(divisor[1:], dtype='int32')

        mcs = []
        for k in range(len(power_factors)):
            with tf.name_scope(None, 'Scale%d' % k, imgs):
                if k > 0:
                    # Avg pool takes rank 4 tensors. Flatten leading dimensions.
                    flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0))
                                 for x, t in zip(imgs, tails)]

                    remainder = tails[0] % divisor_tensor
                    need_padding = tf.reduce_any(tf.not_equal(remainder, 0))
                    padded = tf.cond(need_padding,
                                     lambda: do_pad(flat_imgs, remainder),
                                     lambda: flat_imgs)

                    downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor,
                                                 padding='VALID')
                                  for x in padded]
                    tails = [x[1:] for x in tf.shape_n(downscaled)]
                    imgs = [tf.reshape(x, tf.concat([h, t], 0))
                            for x, h, t in zip(downscaled, heads, tails)]

                # Overwrite the previous ssim value since we only need the last one.
                ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val)
                mcs.append(tf.nn.relu(cs))

        # Remove the cs score for the last scale. In the MS-SSIM calculation,
        # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p).
        mcs.pop()
        mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)], axis=-1)
        # Take the weighted geometric mean across the scale axis.
        ms_ssim = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors), [-1])

        return tf.reduce_mean(ms_ssim, [-1])  # Average over color channels.
def ms_ssim_loss(y_true, y_pred):
    MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred), axis=-1), axis=-1)
    # Needs expanding to [1, height, width] dimensions for Keras. TODO: modify this
    # so it is not hard-coded.
    return K.tile(MSSSIM, [1, 64, 64])
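For reference, a minimal sketch of how ms_ssim_loss could be wired into a Keras model. The toy model and the module path are illustrative (not part of the commit), and the 64x64x3 output matches the hard-coded K.tile shape above:

    from keras.layers import Conv2D, Input
    from keras.models import Model
    from lib.model.losses import ms_ssim_loss  # module path assumed

    inp = Input(shape=(64, 64, 3))
    out = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(inp)
    toy_model = Model(inp, out)
    # Only valid for 64x64 outputs until the K.tile shape is parameterised
    toy_model.compile(optimizer='adam', loss=ms_ssim_loss)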
101 lib/model/masks.py (Normal file)
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
""" Mask functions for faceswap.py
    Masks from:
        dfaker: https://github.com/dfaker/df """

import logging

import cv2
import numpy as np

from lib.umeyama import umeyama

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def dfaker(landmarks, face, channels=4):
    """ Dfaker model mask
    Embeds the mask into the face alpha channel

    channels: 1, 3 or 4:
        1 - Return a single channel mask
        3 - Return a 3 channel mask
        4 - Return the original image with the mask in the alpha channel
    """
    padding = int(face.shape[0] * 0.1875)
    coverage = face.shape[0] - (padding * 2)
    logger.trace("face_shape: %s, coverage: %s, landmarks: %s", face.shape, coverage, landmarks)

    mat = umeyama(landmarks[17:], True)[0:2]
    mat = np.array(mat.ravel()).reshape(2, 3)
    mat = mat * coverage
    mat[:, 2] += padding

    points = np.array(landmarks).reshape((-1, 2))
    facepoints = np.array(points).reshape((-1, 2))

    mask = np.zeros_like(face, dtype=np.uint8)

    hull = cv2.convexHull(facepoints.astype(int))  # pylint: disable=no-member
    hull = cv2.transform(hull.reshape(1, -1, 2),  # pylint: disable=no-member
                         mat).reshape(-1, 2).astype(int)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))  # pylint: disable=no-member

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))  # pylint: disable=no-member
    mask = cv2.dilate(mask,  # pylint: disable=no-member
                      kernel,
                      iterations=1,
                      borderType=cv2.BORDER_REFLECT)  # pylint: disable=no-member
    mask = mask[:, :, :1]

    return merge_mask(face, mask, channels)
def dfl_full(landmarks, face, channels=4):
    """ DFL Face Full Mask

    channels: 1, 3 or 4:
        1 - Return a single channel mask
        3 - Return a 3 channel mask
        4 - Return the original image with the mask in the alpha channel
    """
    logger.trace("face_shape: %s, landmarks: %s", face.shape, landmarks)
    mask = np.zeros(face.shape[0:2] + (1, ), dtype=np.float32)
    jaw = cv2.convexHull(np.concatenate((  # pylint: disable=no-member
        landmarks[0:17],    # jawline
        landmarks[48:68],   # mouth
        [landmarks[0]],     # temple
        [landmarks[8]],     # chin
        [landmarks[16]])))  # temple
    nose_ridge = cv2.convexHull(np.concatenate((  # pylint: disable=no-member
        landmarks[27:31],   # nose line
        [landmarks[33]])))  # nose point
    eyes = cv2.convexHull(np.concatenate((  # pylint: disable=no-member
        landmarks[17:27],   # eyebrows
        [landmarks[0]],     # temple
        [landmarks[27]],    # nose top
        [landmarks[16]],    # temple
        [landmarks[33]])))  # nose point

    cv2.fillConvexPoly(mask, jaw, (255, 255, 255))  # pylint: disable=no-member
    cv2.fillConvexPoly(mask, nose_ridge, (255, 255, 255))  # pylint: disable=no-member
    cv2.fillConvexPoly(mask, eyes, (255, 255, 255))  # pylint: disable=no-member
    return merge_mask(face, mask, channels)
def merge_mask(image, mask, channels):
    """ Return the mask in the requested shape """
    logger.trace("image_shape: %s, mask_shape: %s, channels: %s",
                 image.shape, mask.shape, channels)
    assert channels in (1, 3, 4), "Channels should be 1, 3 or 4"
    assert mask.shape[2] == 1 and mask.ndim == 3, "Input mask should be 3 dimensional with 1 channel"

    if channels == 3:
        retval = np.tile(mask, 3)
    elif channels == 4:
        retval = np.concatenate((image, mask), -1)
    else:
        retval = mask

    logger.trace("Final mask shape: %s", retval.shape)
    return retval
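As a usage illustration (not part of the commit), a mask can be generated for an aligned face like so; the image and landmark sources here are hypothetical:

    import cv2
    import numpy as np
    from lib.model import masks

    face = cv2.imread("aligned_face.png")      # hypothetical 256x256 aligned face
    landmarks = np.load("face_landmarks.npy")  # hypothetical 68-point landmark array

    with_alpha = masks.dfl_full(landmarks, face, channels=4)  # face + mask alpha channel
    mask_3ch = masks.dfl_full(landmarks, face, channels=3)    # mask only, 3 channels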
279 lib/model/nn_blocks.py (Normal file)
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
""" Neural Network Blocks for faceswap.py
    Blocks from:
        the original https://www.reddit.com/r/deepfakes/ code sample + contribs
        dfaker: https://github.com/dfaker/df
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """

import logging
import tensorflow as tf
import keras.backend as K

from keras.layers import (add, Add, BatchNormalization, concatenate, Lambda, regularizers,
                          Permute, Reshape, SeparableConv2D, Softmax, UpSampling2D)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation
from keras.initializers import he_uniform, Constant
from .initializers import ICNR
from .layers import PixelShuffler, Scale, SubPixelUpscaling, ReflectionPadding2D
from .normalization import GroupNormalization, InstanceNormalization

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
class NNBlocks():
    """ Blocks to use for creating models """
    def __init__(self, use_subpixel=False, use_icnr_init=False, use_reflect_padding=False):
        logger.debug("Initializing %s: (use_subpixel: %s, use_icnr_init: %s, "
                     "use_reflect_padding: %s)",
                     self.__class__.__name__, use_subpixel, use_icnr_init, use_reflect_padding)
        self.use_subpixel = use_subpixel
        self.use_icnr_init = use_icnr_init
        self.use_reflect_padding = use_reflect_padding
        logger.debug("Initialized %s", self.__class__.__name__)

    @staticmethod
    def update_kwargs(kwargs):
        """ Set the default kernel initializer to he_uniform() """
        kwargs["kernel_initializer"] = kwargs.get("kernel_initializer", he_uniform())
        return kwargs

    # <<< Original Model Blocks >>> #
    def conv(self, inp, filters, kernel_size=5, strides=2, padding='same',
             use_instance_norm=False, res_block_follows=False, **kwargs):
        """ Convolution Layer """
        logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, use_instance_norm: %s, "
                     "kwargs: %s", inp, filters, kernel_size, strides, use_instance_norm, kwargs)
        kwargs = self.update_kwargs(kwargs)
        if self.use_reflect_padding:
            inp = ReflectionPadding2D(stride=strides, kernel_size=kernel_size)(inp)
            padding = 'valid'
        var_x = Conv2D(filters,
                       kernel_size=kernel_size,
                       strides=strides,
                       padding=padding,
                       **kwargs)(inp)
        if use_instance_norm:
            var_x = InstanceNormalization()(var_x)
        if not res_block_follows:
            var_x = LeakyReLU(0.1)(var_x)
        return var_x
    def upscale(self, inp, filters, kernel_size=3, padding='same',
                use_instance_norm=False, res_block_follows=False, **kwargs):
        """ Upscale Layer """
        logger.debug("inp: %s, filters: %s, kernel_size: %s, use_instance_norm: %s, kwargs: %s",
                     inp, filters, kernel_size, use_instance_norm, kwargs)
        kwargs = self.update_kwargs(kwargs)
        if self.use_reflect_padding:
            inp = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(inp)
            padding = 'valid'
        if self.use_icnr_init:
            kwargs["kernel_initializer"] = ICNR(initializer=kwargs["kernel_initializer"])
        var_x = Conv2D(filters * 4,
                       kernel_size=kernel_size,
                       padding=padding,
                       **kwargs)(inp)
        if use_instance_norm:
            var_x = InstanceNormalization()(var_x)
        if not res_block_follows:
            var_x = LeakyReLU(0.1)(var_x)
        if self.use_subpixel:
            var_x = SubPixelUpscaling()(var_x)
        else:
            var_x = PixelShuffler()(var_x)
        return var_x
    # <<< DFaker Model Blocks >>> #
    def res_block(self, inp, filters, kernel_size=3, padding='same', **kwargs):
        """ Residual block """
        logger.debug("inp: %s, filters: %s, kernel_size: %s, kwargs: %s",
                     inp, filters, kernel_size, kwargs)
        kwargs = self.update_kwargs(kwargs)
        var_x = LeakyReLU(alpha=0.2)(inp)
        if self.use_reflect_padding:
            var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
            padding = 'valid'
        var_x = Conv2D(filters,
                       kernel_size=kernel_size,
                       padding=padding,
                       **kwargs)(var_x)
        var_x = LeakyReLU(alpha=0.2)(var_x)
        if self.use_reflect_padding:
            var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
            padding = 'valid'
        var_x = Conv2D(filters,
                       kernel_size=kernel_size,
                       padding=padding,
                       **kwargs)(var_x)
        var_x = Scale(gamma_init=Constant(value=0.1))(var_x)
        var_x = Add()([var_x, inp])
        var_x = LeakyReLU(alpha=0.2)(var_x)
        return var_x
    # <<< Unbalanced Model Blocks >>> #
    def conv_sep(self, inp, filters, kernel_size=5, strides=2, **kwargs):
        """ Separable Convolution Layer """
        logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s",
                     inp, filters, kernel_size, strides, kwargs)
        kwargs = self.update_kwargs(kwargs)
        var_x = SeparableConv2D(filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same',
                                **kwargs)(inp)
        var_x = Activation("relu")(var_x)
        return var_x
# <<< GAN V2.2 Blocks >>> #
# TODO Merge these into the NNBlocks class when porting GAN 2.2


# GAN Constants:
GAN22_CONV_INIT = "he_normal"
GAN22_REGULARIZER = 1e-4


# GAN Blocks:
def normalization(inp, norm='none', group=16):
    """ GAN Normalization """
    if norm == 'layernorm':
        var_x = GroupNormalization(group=group)(inp)
    elif norm == 'batchnorm':
        var_x = BatchNormalization()(inp)
    elif norm == 'groupnorm':
        var_x = GroupNormalization(group=16)(inp)
    elif norm == 'instancenorm':
        var_x = InstanceNormalization()(inp)
    elif norm == 'hybrid':
        if group % 2 == 1:
            raise ValueError("Output channels must be an even number for hybrid norm, "
                             "received {}.".format(group))
        filt = group
        var_x_0 = Lambda(lambda var_x: var_x[..., :filt // 2])(inp)
        var_x_1 = Lambda(lambda var_x: var_x[..., filt // 2:])(inp)
        var_x_0 = Conv2D(filt // 2,
                         kernel_size=1,
                         kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                         kernel_initializer=GAN22_CONV_INIT)(var_x_0)
        var_x_1 = InstanceNormalization()(var_x_1)
        var_x = concatenate([var_x_0, var_x_1], axis=-1)
    else:
        var_x = inp
    return var_x
def upscale_ps(inp, filters, initializer, use_norm=False, norm="none"):
    """ GAN Upscaler - Pixel Shuffler """
    var_x = Conv2D(filters * 4,
                   kernel_size=3,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer=initializer,
                   padding="same")(inp)
    var_x = LeakyReLU(0.2)(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    var_x = PixelShuffler()(var_x)
    return var_x


def upscale_nn(inp, filters, use_norm=False, norm="none"):
    """ GAN Upscaler - Nearest Neighbour """
    var_x = UpSampling2D()(inp)
    var_x = reflect_padding_2d(var_x, 1)
    var_x = Conv2D(filters,
                   kernel_size=3,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer="he_normal")(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    return var_x


def reflect_padding_2d(inp, pad=1):
    """ GAN Reflect Padding (2D) """
    var_x = Lambda(lambda var_x: tf.pad(var_x,
                                        [[0, 0], [pad, pad], [pad, pad], [0, 0]],
                                        mode="REFLECT"))(inp)
    return var_x
def conv_gan(inp, filters, use_norm=False, strides=2, norm='none'):
    """ GAN Conv Block """
    var_x = Conv2D(filters,
                   kernel_size=3,
                   strides=strides,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer=GAN22_CONV_INIT,
                   use_bias=False,
                   padding="same")(inp)
    var_x = Activation("relu")(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    return var_x


def conv_d_gan(inp, filters, use_norm=False, norm='none'):
    """ GAN Discriminator Conv Block """
    var_x = inp
    var_x = Conv2D(filters,
                   kernel_size=4,
                   strides=2,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer=GAN22_CONV_INIT,
                   use_bias=False,
                   padding="same")(var_x)
    var_x = LeakyReLU(alpha=0.2)(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    return var_x
def res_block_gan(inp, filters, use_norm=False, norm='none'):
    """ GAN Res Block """
    var_x = Conv2D(filters,
                   kernel_size=3,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer=GAN22_CONV_INIT,
                   use_bias=False,
                   padding="same")(inp)
    var_x = LeakyReLU(alpha=0.2)(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    var_x = Conv2D(filters,
                   kernel_size=3,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER),
                   kernel_initializer=GAN22_CONV_INIT,
                   use_bias=False,
                   padding="same")(var_x)
    var_x = add([var_x, inp])
    var_x = LeakyReLU(alpha=0.2)(var_x)
    var_x = normalization(var_x, norm, filters) if use_norm else var_x
    return var_x
def self_attn_block(inp, n_c, squeeze_factor=8):
    """ GAN Self Attention Block
    Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow
    """
    msg = "Input channels must be >= {}, received n_c={}".format(squeeze_factor, n_c)
    assert n_c // squeeze_factor > 0, msg
    var_x = inp
    shape_x = var_x.get_shape().as_list()

    var_f = Conv2D(n_c // squeeze_factor, 1,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)
    var_g = Conv2D(n_c // squeeze_factor, 1,
                   kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)
    var_h = Conv2D(n_c, 1, kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x)

    shape_f = var_f.get_shape().as_list()
    shape_g = var_g.get_shape().as_list()
    shape_h = var_h.get_shape().as_list()
    flat_f = Reshape((-1, shape_f[-1]))(var_f)
    flat_g = Reshape((-1, shape_g[-1]))(var_g)
    flat_h = Reshape((-1, shape_h[-1]))(var_h)

    var_s = Lambda(lambda var_x: K.batch_dot(var_x[0],
                                             Permute((2, 1))(var_x[1])))([flat_g, flat_f])

    beta = Softmax(axis=-1)(var_s)
    var_o = Lambda(lambda var_x: K.batch_dot(var_x[0], var_x[1]))([beta, flat_h])
    var_o = Reshape(shape_x[1:])(var_o)
    var_o = Scale()(var_o)

    out = add([var_o, inp])
    return out
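To show how these blocks compose, here is a minimal sketch of a toy encoder a model plugin might build with NNBlocks. The input size and filter counts are illustrative, not taken from any plugin in the commit:

    from keras.layers import Input
    from keras.models import Model
    from lib.model.nn_blocks import NNBlocks

    blocks = NNBlocks(use_subpixel=False, use_icnr_init=True, use_reflect_padding=False)
    inp = Input(shape=(64, 64, 3))
    var_x = blocks.conv(inp, 128)         # 64x64 -> 32x32
    var_x = blocks.conv(var_x, 256)       # 32x32 -> 16x16
    var_x = blocks.res_block(var_x, 256)  # residual refinement at 16x16
    var_x = blocks.upscale(var_x, 128)    # 16x16 -> 32x32 via PixelShuffler
    toy_encoder = Model(inp, var_x)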
289 lib/model/normalization.py (Normal file)
@@ -0,0 +1,289 @@
#!/usr/bin/env python3
""" Normalization methods for faceswap.py
    Code from:
        shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """

import sys
import inspect

from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects


def to_list(inp):
    """ Convert to list """
    if not isinstance(inp, (list, tuple)):
        return [inp]
    return list(inp)
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al., 2016; Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast
          Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        self.beta = None
        self.gamma = None
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class GroupNormalization(Layer):
    """ Group Normalization
    from: shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """

    def __init__(self, axis=-1,
                 gamma_init='one', beta_init='zero',
                 gamma_regularizer=None, beta_regularizer=None,
                 epsilon=1e-6,
                 group=32,
                 data_format=None,
                 **kwargs):
        self.beta = None
        self.gamma = None
        super(GroupNormalization, self).__init__(**kwargs)

        self.axis = to_list(axis)
        self.gamma_init = initializers.get(gamma_init)
        self.beta_init = initializers.get(beta_init)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.epsilon = epsilon
        self.group = group
        self.data_format = K.normalize_data_format(data_format)

        self.supports_masking = True
    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = [1 for _ in input_shape]
        if self.data_format == 'channels_last':
            channel_axis = -1
            shape[channel_axis] = input_shape[channel_axis]
        elif self.data_format == 'channels_first':
            channel_axis = 1
            shape[channel_axis] = input_shape[channel_axis]
        # for i in self.axis:
        #     shape[i] = input_shape[i]
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='gamma')
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='beta')
        self.built = True
    def call(self, inputs, mask=None):
        input_shape = K.int_shape(inputs)
        if len(input_shape) != 4 and len(input_shape) != 2:
            raise ValueError('Inputs should have rank 4 or 2; '
                             'Received input shape: {}'.format(input_shape))

        if len(input_shape) == 4:
            if self.data_format == 'channels_last':
                batch_size, height, width, channels = input_shape
                if batch_size is None:
                    batch_size = -1

                if channels < self.group:
                    raise ValueError('Input channels should be larger than group size'
                                     '; Received input channels: ' + str(channels) +
                                     '; Group size: ' + str(self.group))

                var_x = K.reshape(inputs, (batch_size,
                                           height,
                                           width,
                                           self.group,
                                           channels // self.group))
                mean = K.mean(var_x, axis=[1, 2, 4], keepdims=True)
                std = K.sqrt(K.var(var_x, axis=[1, 2, 4], keepdims=True) + self.epsilon)
                var_x = (var_x - mean) / std

                var_x = K.reshape(var_x, (batch_size, height, width, channels))
                retval = self.gamma * var_x + self.beta
            elif self.data_format == 'channels_first':
                batch_size, channels, height, width = input_shape
                if batch_size is None:
                    batch_size = -1

                if channels < self.group:
                    raise ValueError('Input channels should be larger than group size'
                                     '; Received input channels: ' + str(channels) +
                                     '; Group size: ' + str(self.group))

                var_x = K.reshape(inputs, (batch_size,
                                           self.group,
                                           channels // self.group,
                                           height,
                                           width))
                mean = K.mean(var_x, axis=[2, 3, 4], keepdims=True)
                std = K.sqrt(K.var(var_x, axis=[2, 3, 4], keepdims=True) + self.epsilon)
                var_x = (var_x - mean) / std

                var_x = K.reshape(var_x, (batch_size, channels, height, width))
                retval = self.gamma * var_x + self.beta

        elif len(input_shape) == 2:
            reduction_axes = list(range(0, len(input_shape)))
            del reduction_axes[0]
            batch_size, _ = input_shape
            if batch_size is None:
                batch_size = -1

            mean = K.mean(inputs, keepdims=True)
            std = K.sqrt(K.var(inputs, keepdims=True) + self.epsilon)
            var_x = (inputs - mean) / std

            retval = self.gamma * var_x + self.beta
        return retval
    def get_config(self):
        config = {'epsilon': self.epsilon,
                  'axis': self.axis,
                  'gamma_init': initializers.serialize(self.gamma_init),
                  'beta_init': initializers.serialize(self.beta_init),
                  'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                  'beta_regularizer': regularizers.serialize(self.beta_regularizer),
                  'group': self.group}
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Update normalizations into Keras custom objects
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and obj.__module__ == __name__:
        get_custom_objects().update({name: obj})
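Because the loop above pushes every class in this module into Keras' custom objects, a saved model containing these layers can be reloaded without an explicit custom_objects mapping. A minimal sketch (the model file name is hypothetical):

    from keras.models import load_model
    import lib.model.normalization  # noqa: F401 - registers the layers on import

    model = load_model("my_model.h5")  # InstanceNormalization etc. resolve automatically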
@@ -117,6 +117,8 @@ class FSThread(threading.Thread):
             self._target(*self._args, **self._kwargs)
         except Exception:  # pylint: disable=broad-except
             self.err = sys.exc_info()
+            logger.debug("Error in thread (%s): %s", self._name,
+                         self.err[1].with_traceback(self.err[2]))
         finally:
             # Avoid a refcycle if the thread is running a function with
             # an argument that has a member that points to the thread.
@@ -126,8 +128,8 @@ class FSThread(threading.Thread):
 class MultiThread():
     """ Threading for IO heavy ops
     Catches errors in thread and rethrows to parent """
-    def __init__(self, target, *args, thread_count=1, **kwargs):
-        self._name = target.__name__
+    def __init__(self, target, *args, thread_count=1, name=None, **kwargs):
+        self._name = name if name else target.__name__
         logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
                      self.__class__.__name__, self._name, thread_count)
         logger.trace("args: %s, kwargs: %s", args, kwargs)

@@ -139,6 +141,16 @@ class MultiThread():
         self._kwargs = kwargs
         logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name)

+    @property
+    def has_error(self):
+        """ Return true if a thread has errored, otherwise false """
+        return any(thread.err for thread in self._threads)
+
+    @property
+    def errors(self):
+        """ Return a list of thread errors """
+        return [thread.err for thread in self._threads]
+
     def start(self):
         """ Start a thread with the given method and args """
         logger.debug("Starting thread(s): '%s'", self._name)
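A short sketch of the new name override and error properties from a caller's perspective (the worker function and path are illustrative):

    from lib.multithreading import MultiThread

    def load_faces(folder):
        pass  # worker that may raise

    thread = MultiThread(load_faces, "/path/to/faces", thread_count=2, name="face_loader")
    thread.start()
    # Consumers can now bail out cleanly instead of hanging on a dead producer:
    if thread.has_error:
        print(thread.errors)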
@@ -9,7 +9,7 @@ import multiprocessing as mp
 import sys
 import threading

-from queue import Empty as QueueEmpty  # pylint: disable=unused-import; # noqa
+from queue import Queue, Empty as QueueEmpty  # pylint: disable=unused-import; # noqa
 from time import sleep

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -37,7 +37,7 @@ class QueueManager():
         self._log_queue = self.manager.Queue() if "gui" not in sys.argv else mp.Queue()
         logger.debug("Initialized %s", self.__class__.__name__)

-    def add_queue(self, name, maxsize=0):
+    def add_queue(self, name, maxsize=0, multiprocessing_queue=True):
         """ Add a queue to the manager

         Adds an event "shutdown" to the queue that can be used to indicate

@@ -46,7 +46,12 @@ class QueueManager():
         logger.debug("QueueManager adding: (name: '%s', maxsize: %s)", name, maxsize)
         if name in self.queues.keys():
             raise ValueError("Queue '{}' already exists.".format(name))
-        queue = self.manager.Queue(maxsize=maxsize)
+
+        if multiprocessing_queue:
+            queue = self.manager.Queue(maxsize=maxsize)
+        else:
+            queue = Queue(maxsize=maxsize)
+
         setattr(queue, "shutdown", self.shutdown)
         self.queues[name] = queue
         logger.debug("QueueManager added: (name: '%s')", name)
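The non-multiprocessing option can be exercised as below; as the training-data change that follows notes, a plain queue.Queue sidesteps the MP Manager choking on numpy arrays (queue name and payload are illustrative):

    import numpy as np
    from lib.queue_manager import queue_manager

    queue_manager.add_queue("train_a", maxsize=128, multiprocessing_queue=False)
    queue = queue_manager.get_queue("train_a")
    queue.put(np.zeros((4, 64, 64, 3), dtype="float32"))  # numpy batches are safe here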
@ -1,106 +1,401 @@
|
||||||
from random import shuffle
|
#!/usr/bin/env python3
|
||||||
import cv2
|
""" Process training data for model training """
|
||||||
import numpy
|
|
||||||
|
import logging
|
||||||
from .multithreading import BackgroundGenerator
|
|
||||||
from .umeyama import umeyama
|
from hashlib import sha1
|
||||||
|
from random import shuffle
|
||||||
class TrainingDataGenerator():
|
|
||||||
def __init__(self, random_transform_args, coverage, scale=5, zoom=1): #TODO thos default should stay in the warp function
|
import cv2
|
||||||
self.random_transform_args = random_transform_args
|
import numpy as np
|
||||||
self.coverage = coverage
|
from scipy.interpolate import griddata
|
||||||
self.scale = scale
|
|
||||||
self.zoom = zoom
|
from lib.model import masks
|
||||||
|
from lib.multithreading import MultiThread
|
||||||
def minibatchAB(self, images, batchsize, doShuffle=True):
|
from lib.queue_manager import queue_manager
|
||||||
batch = BackgroundGenerator(self.minibatch(images, batchsize, doShuffle), 1)
|
from lib.umeyama import umeyama
|
||||||
for ep1, warped_img, target_img in batch.iterator():
|
|
||||||
yield ep1, warped_img, target_img
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
# A generator function that yields epoch, batchsize of warped_img and batchsize of target_img
|
|
||||||
def minibatch(self, data, batchsize, doShuffle=True):
|
class TrainingDataGenerator():
|
||||||
length = len(data)
|
""" Generate training data for models """
|
||||||
assert length >= batchsize, "Number of images is lower than batch-size (Note that too few images may lead to bad training). # images: {}, batch-size: {}".format(length, batchsize)
|
def __init__(self, model_input_size, model_output_size, training_opts):
|
||||||
epoch = i = 0
|
logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, "
|
||||||
if doShuffle:
|
"training_opts: %s, landmarks: %s)",
|
||||||
shuffle(data)
|
self.__class__.__name__, model_input_size, model_output_size,
|
||||||
while True:
|
{key: val for key, val in training_opts.items() if key != "landmarks"},
|
||||||
size = batchsize
|
bool(training_opts.get("landmarks", None)))
|
||||||
if i+size > length:
|
self.batchsize = 0
|
||||||
if doShuffle:
|
self.model_input_size = model_input_size
|
||||||
shuffle(data)
|
self.training_opts = training_opts
|
||||||
i = 0
|
self.mask_function = self.set_mask_function()
|
||||||
epoch+=1
|
self.landmarks = self.training_opts.get("landmarks", None)
|
||||||
rtn = numpy.float32([self.read_image(img) for img in data[i:i+size]])
|
|
||||||
i+=size
|
self.processing = ImageManipulation(model_input_size,
|
||||||
yield epoch, rtn[:,0,:,:,:], rtn[:,1,:,:,:]
|
model_output_size,
|
||||||
|
training_opts.get("coverage_ratio", 0.625))
|
||||||
def color_adjust(self, img):
|
logger.debug("Initialized %s", self.__class__.__name__)
|
||||||
return img / 255.0
|
|
||||||
|
def set_mask_function(self):
|
||||||
def read_image(self, fn):
|
""" Set the mask function to use if using mask """
|
||||||
try:
|
mask_type = self.training_opts.get("mask_type", None)
|
||||||
image = self.color_adjust(cv2.imread(fn))
|
if mask_type:
|
||||||
except TypeError:
|
logger.debug("Mask type: '%s'", mask_type)
|
||||||
raise Exception("Error while reading image", fn)
|
mask_func = getattr(masks, mask_type)
|
||||||
|
else:
|
||||||
image = cv2.resize(image, (256,256))
|
mask_func = None
|
||||||
image = self.random_transform( image, **self.random_transform_args )
|
logger.debug("Mask function: %s", mask_func)
|
||||||
warped_img, target_img = self.random_warp( image, self.coverage, self.scale, self.zoom )
|
return mask_func
|
||||||
|
|
||||||
return warped_img, target_img
|
def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False):
|
||||||
|
""" Keep a queue filled to 8x Batch Size """
|
||||||
def random_transform(self, image, rotation_range, zoom_range, shift_range, random_flip):
|
logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
|
||||||
h, w = image.shape[0:2]
|
"is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse)
|
||||||
rotation = numpy.random.uniform(-rotation_range, rotation_range)
|
self.batchsize = batchsize
|
||||||
scale = numpy.random.uniform(1 - zoom_range, 1 + zoom_range)
|
q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side)
|
||||||
tx = numpy.random.uniform(-shift_range, shift_range) * w
|
q_size = batchsize * 8
|
||||||
ty = numpy.random.uniform(-shift_range, shift_range) * h
|
# Don't use a multiprocessing queue because sometimes the MP Manager borks on numpy arrays
|
||||||
mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)
|
queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False)
|
||||||
mat[:, 2] += (tx, ty)
|
load_thread = MultiThread(self.load_batches,
|
||||||
result = cv2.warpAffine(
|
images,
|
||||||
image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE)
|
q_name,
|
||||||
if numpy.random.random() < random_flip:
|
side,
|
||||||
result = result[:, ::-1]
|
is_timelapse,
|
||||||
return result
|
do_shuffle)
|
||||||
|
load_thread.start()
|
||||||
# get pair of random warped images from aligned face image
|
logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name)
|
||||||
def random_warp(self, image, coverage, scale = 5, zoom = 1):
|
return self.minibatch(q_name, load_thread)
|
||||||
assert image.shape == (256, 256, 3)
|
|
||||||
range_ = numpy.linspace(128 - coverage//2, 128 + coverage//2, 5)
|
def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True):
|
||||||
mapx = numpy.broadcast_to(range_, (5, 5))
|
""" Load the warped images and target images to queue """
|
||||||
mapy = mapx.T
|
logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', "
|
||||||
|
"is_timelapse: %s, do_shuffle: %s)",
|
||||||
mapx = mapx + numpy.random.normal(size=(5,5), scale=scale)
|
len(images), q_name, side, is_timelapse, do_shuffle)
|
||||||
mapy = mapy + numpy.random.normal(size=(5,5), scale=scale)
|
epoch = 0
|
||||||
|
queue = queue_manager.get_queue(q_name)
|
||||||
interp_mapx = cv2.resize(mapx, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
|
self.validate_samples(images)
|
||||||
interp_mapy = cv2.resize(mapy, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
|
while True:
|
||||||
|
if do_shuffle:
|
||||||
warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
|
shuffle(images)
|
||||||
|
for img in images:
|
||||||
src_points = numpy.stack([mapx.ravel(), mapy.ravel() ], axis=-1)
|
logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side)
|
||||||
dst_points = numpy.mgrid[0:65*zoom:16*zoom,0:65*zoom:16*zoom].T.reshape(-1,2)
|
queue.put(self.process_face(img, side, is_timelapse))
|
||||||
mat = umeyama(src_points, dst_points, True)[0:2]
|
epoch += 1
|
||||||
|
logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')",
|
||||||
target_image = cv2.warpAffine(image, mat, (64*zoom,64*zoom))
|
epoch, q_name, side)
|
||||||
|
|
||||||
return warped_image, target_image
|
def validate_samples(self, data):
|
||||||
|
""" Check the total number of images against batchsize and return
|
||||||
def stack_images(images):
|
the total number of images """
|
||||||
def get_transpose_axes(n):
|
length = len(data)
|
||||||
if n % 2 == 0:
|
msg = ("Number of images is lower than batch-size (Note that too few "
|
||||||
y_axes = list(range(1, n - 1, 2))
|
"images may lead to bad training). # images: {}, "
|
||||||
x_axes = list(range(0, n - 1, 2))
|
"batch-size: {}".format(length, self.batchsize))
|
||||||
else:
|
assert length >= self.batchsize, msg
|
||||||
y_axes = list(range(0, n - 1, 2))
|
|
||||||
x_axes = list(range(1, n - 1, 2))
|
def minibatch(self, q_name, load_thread):
|
||||||
return y_axes, x_axes, [n - 1]
|
""" A generator function that yields epoch, batchsize of warped_img
|
||||||
|
and batchsize of target_img from the load queue """
|
||||||
images_shape = numpy.array(images.shape)
|
logger.debug("Launching minibatch generator for queue: '%s'", q_name)
|
||||||
new_axes = get_transpose_axes(len(images_shape))
|
queue = queue_manager.get_queue(q_name)
|
||||||
new_shape = [numpy.prod(images_shape[x]) for x in new_axes]
|
while True:
|
||||||
return numpy.transpose(
|
if load_thread.has_error:
|
||||||
images,
|
logger.debug("Thread error detected")
|
||||||
axes=numpy.concatenate(new_axes)
|
break
|
||||||
).reshape(new_shape)
|
batch = list()
|
||||||
|
for _ in range(self.batchsize):
|
||||||
|
images = queue.get()
|
||||||
|
for idx, image in enumerate(images):
|
||||||
|
if len(batch) < idx + 1:
|
||||||
|
batch.append(list())
|
||||||
|
batch[idx].append(image)
|
||||||
|
batch = [np.float32(image) for image in batch]
|
||||||
|
logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s'",
|
||||||
|
len(batch), [item.shape for item in batch], q_name)
|
||||||
|
yield batch
|
||||||
|
logger.debug("Finished minibatch generator for queue: '%s'", q_name)
|
||||||
|
load_thread.join()
|
||||||
|
|
||||||
|
def process_face(self, filename, side, is_timelapse):
|
||||||
|
""" Load an image and perform transformation and warping """
|
||||||
|
logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)",
|
||||||
|
filename, side, is_timelapse)
|
||||||
|
try:
|
||||||
|
image = cv2.imread(filename) # pylint: disable=no-member
|
||||||
|
except TypeError:
|
||||||
|
raise Exception("Error while reading image", filename)
|
||||||
|
|
||||||
|
if self.mask_function or self.training_opts["warp_to_landmarks"]:
|
||||||
|
src_pts = self.get_landmarks(filename, image, side)
|
||||||
|
if self.mask_function:
|
||||||
|
image = self.mask_function(src_pts, image, channels=4)
|
||||||
|
|
||||||
|
image = self.processing.color_adjust(image)
|
||||||
|
|
||||||
|
if not is_timelapse:
|
||||||
|
image = self.processing.random_transform(image)
|
||||||
|
if not self.training_opts["no_flip"]:
|
||||||
|
image = self.processing.do_random_flip(image)
|
||||||
|
sample = image.copy()[:, :, :3]
|
||||||
|
|
||||||
|
if self.training_opts["warp_to_landmarks"]:
|
||||||
|
dst_pts = self.get_closest_match(filename, side, src_pts)
|
||||||
|
processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
|
||||||
|
else:
|
||||||
|
processed = self.processing.random_warp(image)
|
||||||
|
|
||||||
|
processed.insert(0, sample)
|
||||||
|
logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
|
||||||
|
filename, side, [img.shape for img in processed])
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def get_landmarks(self, filename, image, side):
|
||||||
|
""" Return the landmarks for this face """
|
||||||
|
logger.trace("Retrieving landmarks: (filename: '%s', side: '%s'", filename, side)
|
||||||
|
lm_key = sha1(image).hexdigest()
|
||||||
|
try:
|
||||||
|
src_points = self.landmarks[side][lm_key]
|
||||||
|
except KeyError:
|
||||||
|
raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key,
|
||||||
|
filename))
|
||||||
|
logger.trace("Returning: (src_points: %s)", src_points)
|
||||||
|
return src_points
|
||||||
|
|
||||||
|
def get_closest_match(self, filename, side, src_points):
|
||||||
|
""" Return closest matched landmarks from opposite set """
|
||||||
|
logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s'",
|
||||||
|
filename, src_points)
|
||||||
|
dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"]
|
||||||
|
dst_points = list(dst_points.values())
|
||||||
|
closest = (np.mean(np.square(src_points - dst_points),
|
||||||
|
axis=(1, 2))).argsort()[:10]
|
||||||
|
closest = np.random.choice(closest)
|
||||||
|
dst_points = dst_points[closest]
|
||||||
|
logger.trace("Returning: (dst_points: %s)", dst_points)
|
||||||
|
return dst_points
|
||||||
|
|
||||||
|
|
||||||
|
class ImageManipulation():
|
||||||
|
""" Manipulations to be performed on training images """
|
||||||
|
def __init__(self, input_size, output_size, coverage_ratio):
|
||||||
|
""" input_size: Size of the face input into the model
|
||||||
|
output_size: Size of the face that comes out of the modell
|
||||||
|
coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
|
||||||
|
"""
|
||||||
|
logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
|
||||||
|
self.__class__.__name__, input_size, output_size, coverage_ratio)
|
||||||
|
# Transform args
|
||||||
|
self.rotation_range = 10 # Range to randomly rotate the image by
|
||||||
|
self.zoom_range = 0.05 # Range to randomly zoom the image by
|
||||||
|
self.shift_range = 0.05 # Range to randomly translate the image by
|
||||||
|
self.random_flip = 0.5 # Chance to flip the image horizontally
|
||||||
|
# Transform and Warp args
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
# Warp args
|
||||||
|
self.coverage_ratio = coverage_ratio # Coverage ratio of full image. Eg: 256 * 0.625 = 160
|
||||||
|
self.scale = 5 # Normal random variable scale
|
||||||
|
logger.debug("Initialized %s", self.__class__.__name__)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def color_adjust(img):
|
||||||
|
""" Color adjust RGB image """
|
||||||
|
logger.trace("Color adjusting image")
|
||||||
|
return img.astype('float32') / 255.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def separate_mask(image):
|
||||||
|
""" Return the image and the mask from a 4 channel image """
|
||||||
|
mask = None
|
||||||
|
if image.shape[2] == 4:
|
||||||
|
logger.trace("Image contains mask")
|
||||||
|
mask = np.expand_dims(image[:, :, -1], axis=2)
|
||||||
|
image = image[:, :, :3]
|
||||||
|
else:
|
||||||
|
logger.trace("Image has no mask")
|
||||||
|
return image, mask
|
||||||
|
|
||||||
|
def get_coverage(self, image):
|
||||||
|
""" Return coverage value for given image """
|
||||||
|
coverage = int(image.shape[0] * self.coverage_ratio)
|
||||||
|
logger.trace("Coverage: %s", coverage)
|
||||||
|
return coverage
|
||||||
|
|
||||||
|
def random_transform(self, image):
|
||||||
|
""" Randomly transform an image """
|
||||||
|
logger.trace("Randomly transforming image")
|
||||||
|
height, width = image.shape[0:2]
|
||||||
|
|
||||||
|
rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
|
||||||
|
scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
|
||||||
|
tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
|
||||||
|
tny = np.random.uniform(-self.shift_range, self.shift_range) * height
|
||||||
|
|
||||||
|
mat = cv2.getRotationMatrix2D( # pylint: disable=no-member
|
||||||
|
(width // 2, height // 2), rotation, scale)
|
||||||
|
mat[:, 2] += (tnx, tny)
|
||||||
|
result = cv2.warpAffine( # pylint: disable=no-member
|
||||||
|
image, mat, (width, height),
|
||||||
|
borderMode=cv2.BORDER_REPLICATE) # pylint: disable=no-member
|
||||||
|
|
||||||
|
logger.trace("Randomly transformed image")
|
||||||
|
return result
|
||||||
|
|
||||||
|
    def do_random_flip(self, image):
        """ Perform flip on image if random number is within threshold """
        logger.trace("Randomly flipping image")
        if np.random.random() < self.random_flip:
            logger.trace("Flip within threshold. Flipping")
            retval = image[:, ::-1]
        else:
            logger.trace("Flip outside threshold. Not flipping")
            retval = image
        logger.trace("Randomly flipped image")
        return retval

    def random_warp(self, image):
        """ Get a pair of random warped images from an aligned face image """
        logger.trace("Randomly warping image")
        height, width = image.shape[0:2]
        coverage = self.get_coverage(image)
        assert height == width and height % 2 == 0

        range_ = np.linspace(height // 2 - coverage // 2,
                             height // 2 + coverage // 2,
                             5, dtype='float32')
        mapx = np.broadcast_to(range_, (5, 5)).copy()
        mapy = mapx.T
        # mapx, mapy = np.float32(np.meshgrid(range_, range_))  # instead of broadcast

        pad = int(1.25 * self.input_size)
        slices = slice(pad // 10, -pad // 10)
        dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4))
        interp = np.empty((2, self.input_size, self.input_size), dtype='float32')

        for i, map_ in enumerate([mapx, mapy]):
            map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale)
            interp[i] = cv2.resize(map_, (pad, pad))[slices, slices]  # pylint: disable=no-member

        warped_image = cv2.remap(  # pylint: disable=no-member
            image, interp[0], interp[1], cv2.INTER_LINEAR)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)

        src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
        dst_points = np.mgrid[dst_slice, dst_slice]
        mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2]
        target_image = cv2.warpAffine(  # pylint: disable=no-member
            image, mat, (self.output_size, self.output_size))
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]
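
    # A worked sketch of the control grid above, assuming (hypothetically) a
    # 256px face trained at 62.5% coverage: the 5x5 grid spans only the
    # central 160px, so hair/background outside the coverage area never warp.
    #
    #   import numpy as np
    #   height, coverage = 256, int(256 * 0.625)        # coverage = 160
    #   range_ = np.linspace(128 - 80, 128 + 80, 5)     # [48. 88. 128. 168. 208.]
    #   grid = np.broadcast_to(range_, (5, 5))          # x coords; transpose for y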
    def random_warp_landmarks(self, image, src_points=None, dst_points=None):
        """ Get a warped image, target image and target mask
            From the DFAKER plugin """
        logger.trace("Randomly warping landmarks")
        size = image.shape[0]
        coverage = self.get_coverage(image)

        p_mx = size - 1
        p_hf = (size // 2) - 1

        edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0),
                        (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]
        grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)]

        source = src_points
        destination = (dst_points.copy().astype('float32') +
                       np.random.normal(size=dst_points.shape, scale=2.0))
        destination = destination.astype('uint8')

        face_core = cv2.convexHull(np.concatenate(  # pylint: disable=no-member
            [source[17:], destination[17:]], axis=0).astype(int))

        source = [(pty, ptx) for ptx, pty in source] + edge_anchors
        destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors

        indices_to_remove = set()
        for fpl in source, destination:
            for idx, (pty, ptx) in enumerate(fpl):
                if idx > 17:
                    break
                elif cv2.pointPolygonTest(face_core,  # pylint: disable=no-member
                                          (pty, ptx),
                                          False) >= 0:
                    indices_to_remove.add(idx)

        for idx in sorted(indices_to_remove, reverse=True):
            source.pop(idx)
            destination.pop(idx)

        grid_z = griddata(destination, source, (grid_x, grid_y), method="linear")
        map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size)
        map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size)
        map_x_32 = map_x.astype('float32')
        map_y_32 = map_y.astype('float32')

        warped_image = cv2.remap(image,  # pylint: disable=no-member
                                 map_x_32,
                                 map_y_32,
                                 cv2.INTER_LINEAR,  # pylint: disable=no-member
                                 borderMode=cv2.BORDER_TRANSPARENT)  # pylint: disable=no-member
        target_image = image

        # TODO Make sure this replacement is correct
        slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2)
        # slices = slice(size // 32, size - size // 32)  # 8px on a 256px image
        warped_image = cv2.resize(  # pylint: disable=no-member
            warped_image[slices, slices, :], (self.input_size, self.input_size),
            interpolation=cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)
        target_image = cv2.resize(  # pylint: disable=no-member
            target_image[slices, slices, :], (self.output_size, self.output_size),
            interpolation=cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]
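
    # A self-contained toy of the griddata remap above (hypothetical sizes and
    # points, not the plugin's real landmarks): pin the corners, nudge a centre
    # control point, and interpolate a dense sampling grid for cv2.remap.
    #
    #   import numpy as np
    #   from scipy.interpolate import griddata
    #   size = 64
    #   grid_x, grid_y = np.mgrid[0:size - 1:complex(size), 0:size - 1:complex(size)]
    #   src = [(0, 0), (0, 63), (63, 0), (63, 63), (32, 32)]
    #   dst = [(0, 0), (0, 63), (63, 0), (63, 63), (34, 30)]   # centre point moved
    #   remap = griddata(dst, src, (grid_x, grid_y), method="linear")
    #   map_x = remap[..., 1].astype('float32')                # feed to cv2.remap
    #   map_y = remap[..., 0].astype('float32')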

def stack_images(images):
    """ Stack images """
    logger.debug("Stack images")

    def get_transpose_axes(num):
        if num % 2 == 0:
            logger.debug("Even number of images to stack")
            y_axes = list(range(1, num - 1, 2))
            x_axes = list(range(0, num - 1, 2))
        else:
            logger.debug("Odd number of images to stack")
            y_axes = list(range(0, num - 1, 2))
            x_axes = list(range(1, num - 1, 2))
        return y_axes, x_axes, [num - 1]

    images_shape = np.array(images.shape)
    new_axes = get_transpose_axes(len(images_shape))
    new_shape = [np.prod(images_shape[x]) for x in new_axes]
    logger.debug("Stacked images")
    return np.transpose(
        images,
        axes=np.concatenate(new_axes)
        ).reshape(new_shape)
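
# A quick usage sketch for stack_images (illustrative shapes only): a 2x3 grid
# of 64px RGB tiles is tiled into a single (128, 192, 3) preview sheet.
#
#   import numpy as np
#   tiles = np.zeros((2, 3, 64, 64, 3), dtype='float32')
#   sheet = stack_images(tiles)
#   assert sheet.shape == (128, 192, 3)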

lib/umeyama.py
@@ -12,8 +12,27 @@
 import numpy as np

+MEAN_FACE_X = np.array([
+    0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483,
+    0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127,
+    0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122,
+    0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132,
+    0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127,
+    0.551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532,
+    0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364,
+    0.490127, 0.42689])
+
+MEAN_FACE_Y = np.array([
+    0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891,
+    0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625,
+    0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758,
+    0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758,
+    0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578,
+    0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192,
+    0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182,
+    0.831803, 0.824182])
+
-def umeyama(src, dst, estimate_scale):
+def umeyama(src, estimate_scale, dst=None):
     """Estimate N-D similarity transformation with or without scaling.
     Parameters
     ----------
@@ -33,6 +52,8 @@ def umeyama(src, estimate_scale, dst=None):
     .. [1] "Least-squares estimation of transformation parameters between two
            point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573
     """
+    if dst is None:
+        dst = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1)
+
     num = src.shape[0]
     dim = src.shape[1]
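
# Usage sketch (hypothetical landmark data): with the new signature, passing
# dst=None aligns the 51 non-jaw landmarks against the built-in mean face,
# which is how an alignment matrix for a detected face can be derived.
#
#   import numpy as np
#   landmarks = np.random.rand(68, 2) * 256     # stand-in for detected points
#   mat = umeyama(landmarks[17:], True)[0:2]    # 2x3 similarity transform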

lib/utils.py
@@ -8,7 +8,6 @@ import warnings
 from hashlib import sha1
 from pathlib import Path
 from re import finditer
-from time import time

 import cv2
 import numpy as np
@@ -16,7 +15,6 @@ import numpy as np
 import dlib

 from lib.faces_detect import DetectedFace
-from lib.training_data import TrainingDataGenerator
 from lib.logger import get_loglevel


@@ -62,7 +60,7 @@ def get_image_paths(directory):


 def hash_image_file(filename):
-    """ Return the filename with it's sha1 hash """
+    """ Return an image file's sha1 hash """
     img = cv2.imread(filename)  # pylint: disable=no-member
     img_hash = sha1(img).hexdigest()
     logger.trace("filename: '%s', hash: %s", filename, img_hash)
@@ -107,33 +105,12 @@ def set_system_verbosity(loglevel):
     logger.debug("System Verbosity level: %s", loglevel)
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = loglevel
     if loglevel != '0':
-        for warncat in (FutureWarning, DeprecationWarning):
+        for warncat in (FutureWarning, DeprecationWarning, UserWarning):
             warnings.simplefilter(action='ignore', category=warncat)


-def add_alpha_channel(image, intensity=100):
-    """ Add an alpha channel to an image
-
-        intensity: The opacity of the alpha channel between 0 and 100
-        100 = transparent,
-        0 = solid """
-    logger.trace("Adding alpha channel: intensity: %s", intensity)
-    assert 0 <= intensity <= 100, "Invalid intensity supplied"
-    intensity = (255.0 / 100.0) * intensity
-
-    d_type = image.dtype
-    image = image.astype("float32")
-
-    ch_b, ch_g, ch_r = cv2.split(image)  # pylint: disable=no-member
-    ch_a = np.ones(ch_b.shape, dtype="float32") * intensity
-
-    image_bgra = cv2.merge(  # pylint: disable=no-member
-        (ch_b, ch_g, ch_r, ch_a))
-    logger.trace("Added alpha channel", intensity)
-    return image_bgra.astype(d_type)


 def rotate_landmarks(face, rotation_matrix):
+    # pylint: disable=c-extension-no-member
     """ Rotate the landmarks and bounding box for faces
         found in rotated images.
         Pass in a DetectedFace object, Alignments dict or DLib rectangle"""

@@ -223,80 +200,6 @@ def camel_case_split(identifier):
     return [m.group(0) for m in matches]


-class Timelapse:
-    """ Time lapse function for training """
-    @classmethod
-    def create_timelapse(cls, input_dir_a, input_dir_b, output_dir, trainer):
-        """ Create the time lapse """
-        if input_dir_a is None and input_dir_b is None and output_dir is None:
-            return None
-
-        if input_dir_a is None or input_dir_b is None:
-            raise ValueError("To enable the timelapse, you have to supply "
-                             "all the parameters (--timelapse-input-A and "
-                             "--timelapse-input-B).")
-
-        if output_dir is None:
-            output_dir = get_folder(os.path.join(trainer.model.model_dir,
-                                                 "timelapse"))
-
-        return Timelapse(input_dir_a, input_dir_b, output_dir, trainer)
-
-    def __init__(self, input_dir_a, input_dir_b, output, trainer):
-        self.output_dir = output
-        self.trainer = trainer
-
-        if not os.path.isdir(self.output_dir):
-            logger.error("'%s' does not exist", self.output_dir)
-            exit(1)
-
-        self.files_a = self.read_input_images(input_dir_a)
-        self.files_b = self.read_input_images(input_dir_b)
-
-        btchsz = min(len(self.files_a), len(self.files_b))
-
-        self.images_a = self.get_image_data(self.files_a, btchsz)
-        self.images_b = self.get_image_data(self.files_b, btchsz)
-
-    @staticmethod
-    def read_input_images(input_dir):
-        """ Get the image paths """
-        if not os.path.isdir(input_dir):
-            logger.error("'%s' does not exist", input_dir)
-            exit(1)
-
-        if not os.listdir(input_dir):
-            logger.error("'%s' contains no images", input_dir)
-            exit(1)
-
-        return get_image_paths(input_dir)
-
-    def get_image_data(self, input_images, batch_size):
-        """ Get training images """
-        random_transform_args = {
-            'rotation_range': 0,
-            'zoom_range': 0,
-            'shift_range': 0,
-            'random_flip': 0
-        }
-
-        zoom = 1
-        if hasattr(self.trainer.model, 'IMAGE_SHAPE'):
-            zoom = self.trainer.model.IMAGE_SHAPE[0] // 64
-
-        generator = TrainingDataGenerator(random_transform_args, 160, zoom)
-        batch = generator.minibatchAB(input_images, batch_size,
-                                      doShuffle=False)
-
-        return next(batch)[2]
-
-    def work(self):
-        """ Write out timelapse image """
-        image = self.trainer.show_sample(self.images_a, self.images_b)
-        cv2.imwrite(os.path.join(self.output_dir,  # pylint: disable=no-member
-                                 str(int(time())) + ".png"), image)

 def safe_shutdown():
     """ Close queues, threads and processes in event of crash """
     logger.debug("Safely shutting down")

plugins/convert/Convert_Adjust.py (deleted, 114 lines)
@@ -1,114 +0,0 @@
#!/usr/bin/env python3
""" Adjust converter for faceswap.py

    Based on the original https://www.reddit.com/r/deepfakes/ code sample
    Adjust code made by https://github.com/yangchen8710 """

import cv2
import numpy as np

from lib.utils import add_alpha_channel


class Convert():
    """ Adjust Converter """
    def __init__(self, encoder, smooth_mask=True, avg_color_adjust=True,
                 draw_transparent=False, **kwargs):
        self.encoder = encoder

        self.use_smooth_mask = smooth_mask
        self.use_avg_color_adjust = avg_color_adjust
        self.draw_transparent = draw_transparent

    def patch_image(self, frame, detected_face, size):
        """ Patch swapped face onto original image """
        # pylint: disable=no-member
        # assert image.shape == (256, 256, 3)
        padding = 48
        face_size = 256
        detected_face.load_aligned(frame, face_size, padding,
                                   align_eyes=False)
        src_face = detected_face.aligned_face

        crop = slice(padding, face_size - padding)
        process_face = src_face[crop, crop]
        old_face = process_face.copy()

        process_face = cv2.resize(process_face,
                                  (size, size),
                                  interpolation=cv2.INTER_AREA)
        process_face = np.expand_dims(process_face, 0)

        new_face = self.encoder(process_face / 255.0)[0]
        new_face = np.clip(new_face * 255, 0, 255).astype(src_face.dtype)
        new_face = cv2.resize(
            new_face,
            (face_size - padding * 2, face_size - padding * 2),
            interpolation=cv2.INTER_CUBIC)

        if self.use_avg_color_adjust:
            self.adjust_avg_color(old_face, new_face)
        if self.use_smooth_mask:
            self.smooth_mask(old_face, new_face)

        new_face = self.superpose(src_face, new_face, crop)
        new_image = frame.copy()

        if self.draw_transparent:
            new_image, new_face = self.convert_transparent(new_image,
                                                           new_face)

        cv2.warpAffine(
            new_face,
            detected_face.adjusted_matrix,
            (detected_face.frame_dims[1], detected_face.frame_dims[0]),
            new_image,
            flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_TRANSPARENT)
        return new_image

    @staticmethod
    def adjust_avg_color(old_face, new_face):
        """ Perform average color adjustment """
        for i in range(new_face.shape[-1]):
            old_avg = old_face[:, :, i].mean()
            new_avg = new_face[:, :, i].mean()
            diff_int = int(old_avg - new_avg)
            for int_h in range(new_face.shape[0]):
                for int_w in range(new_face.shape[1]):
                    temp = (new_face[int_h, int_w, i] + diff_int)
                    if temp < 0:
                        new_face[int_h, int_w, i] = 0
                    elif temp > 255:
                        new_face[int_h, int_w, i] = 255
                    else:
                        new_face[int_h, int_w, i] = temp
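
    # Hypothetical vectorized equivalent of the per-pixel loop above, shown
    # for clarity only (the plugin itself is removed in this commit):
    #
    #   import numpy as np
    #   def adjust_avg_color_fast(old_face, new_face):
    #       diff = (old_face.mean(axis=(0, 1)) - new_face.mean(axis=(0, 1))).astype(int)
    #       return np.clip(new_face.astype(int) + diff, 0, 255).astype(new_face.dtype)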
    @staticmethod
    def smooth_mask(old_face, new_face):
        """ Smooth the mask """
        width, height, _ = new_face.shape
        crop = slice(0, width)
        mask = np.zeros_like(new_face)
        mask[height // 15:-height // 15, width // 15:-width // 15, :] = 255
        mask = cv2.GaussianBlur(mask,  # pylint: disable=no-member
                                (15, 15),
                                10)
        new_face[crop, crop] = (mask / 255 * new_face +
                                (1 - mask / 255) * old_face)

    @staticmethod
    def superpose(src_face, new_face, crop):
        """ Crop Face """
        new_image = src_face.copy()
        new_image[crop, crop] = new_face
        return new_image

    @staticmethod
    def convert_transparent(image, new_face):
        """ Add alpha channels to images and change to
            transparent background """
        height, width = image.shape[:2]
        image = np.zeros((height, width, 4), dtype=np.uint8)
        new_face = add_alpha_channel(new_face, 100)
        return image, new_face

plugins/convert/Convert_Masked.py (deleted, 230 lines)
@@ -1,230 +0,0 @@
#!/usr/bin/env python3
""" Masked converter for faceswap.py
    Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955
    found on https://www.reddit.com/r/deepfakes/ """

import logging
import cv2
import numpy

from lib.aligner import get_align_mat
from lib.utils import add_alpha_channel

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Convert():
    def __init__(self, encoder, trainer,
                 blur_size=2, seamless_clone=False, mask_type="facehullandrect",
                 erosion_kernel_size=None, match_histogram=False, sharpen_image=None,
                 draw_transparent=False, **kwargs):
        self.encoder = encoder
        self.trainer = trainer
        self.erosion_kernel = None
        self.erosion_kernel_size = erosion_kernel_size
        if erosion_kernel_size is not None:
            if erosion_kernel_size > 0:
                self.erosion_kernel = cv2.getStructuringElement(
                    cv2.MORPH_ELLIPSE,
                    (erosion_kernel_size, erosion_kernel_size))
            elif erosion_kernel_size < 0:
                self.erosion_kernel = cv2.getStructuringElement(
                    cv2.MORPH_ELLIPSE,
                    (abs(erosion_kernel_size), abs(erosion_kernel_size)))
        self.blur_size = blur_size
        self.seamless_clone = seamless_clone
        self.sharpen_image = sharpen_image
        self.match_histogram = match_histogram
        self.mask_type = mask_type.lower()  # Choose in 'FaceHullAndRect', 'FaceHull', 'Rect'
        self.draw_transparent = draw_transparent

    def patch_image(self, image, face_detected, size):
        image_size = image.shape[1], image.shape[0]

        mat = numpy.array(get_align_mat(face_detected,
                                        size,
                                        should_align_eyes=False)).reshape(2, 3)

        if "GAN" not in self.trainer:
            mat = mat * size
        else:
            padding = int(48 / 256 * size)
            mat = mat * (size - 2 * padding)
            mat[:, 2] += padding

        new_face = self.get_new_face(image, mat, size)

        image_mask = self.get_image_mask(image,
                                         new_face,
                                         face_detected.landmarks_as_xy,
                                         mat,
                                         image_size)

        return self.apply_new_face(image, new_face, image_mask, mat, image_size, size)

    @staticmethod
    def convert_transparent(image, new_face, image_mask, image_size):
        """ Add alpha channels to images and change to
            transparent background """
        image = numpy.zeros((image_size[1], image_size[0], 4),
                            dtype=numpy.uint8)
        image_mask = add_alpha_channel(image_mask, 100)
        new_face = add_alpha_channel(new_face, 100)
        return image, new_face, image_mask

    def apply_new_face(self, image, new_face, image_mask, mat, image_size, size):
        if self.draw_transparent:
            image, new_face, image_mask = self.convert_transparent(image,
                                                                   new_face,
                                                                   image_mask,
                                                                   image_size)
            self.seamless_clone = False  # Alpha channel not supported in seamless
        base_image = numpy.copy(image)
        new_image = numpy.copy(image)

        cv2.warpAffine(new_face,
                       mat,
                       image_size,
                       new_image,
                       cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC,
                       cv2.BORDER_TRANSPARENT)

        if self.sharpen_image == "bsharpen":
            # Sharpening using filter2D
            kernel = numpy.ones((3, 3)) * (-1)
            kernel[1, 1] = 9
            new_image = cv2.filter2D(new_image, -1, kernel)
        elif self.sharpen_image == "gsharpen":
            # Sharpening using Weighted Method
            gaussian_blur = cv2.GaussianBlur(new_image, (0, 0), 3.0)
            new_image = cv2.addWeighted(
                new_image, 1.5, gaussian_blur, -0.5, 0, new_image)

        outimage = None
        if self.seamless_clone:
            unit_mask = numpy.clip(image_mask * 365, 0, 255).astype(numpy.uint8)
            logger.info(unit_mask.shape)
            logger.info(new_image.shape)
            logger.info(base_image.shape)
            maxregion = numpy.argwhere(unit_mask == 255)

            if maxregion.size > 0:
                miny, minx = maxregion.min(axis=0)[:2]
                maxy, maxx = maxregion.max(axis=0)[:2]
                lenx = maxx - minx
                leny = maxy - miny
                masky = int(minx + (lenx // 2))
                maskx = int(miny + (leny // 2))
                outimage = cv2.seamlessClone(new_image.astype(numpy.uint8),
                                             base_image.astype(numpy.uint8),
                                             unit_mask,
                                             (masky, maskx),
                                             cv2.NORMAL_CLONE)
                return outimage

        foreground = cv2.multiply(image_mask, new_image.astype(float))
        background = cv2.multiply(1.0 - image_mask, base_image.astype(float))
        outimage = cv2.add(foreground, background)

        return outimage
    def hist_match(self, source, template, mask=None):
        # Code borrowed from:
        # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
        masked_source = source
        masked_template = template

        if mask is not None:
            masked_source = source * mask
            masked_template = template * mask

        oldshape = source.shape
        source = source.ravel()
        template = template.ravel()
        masked_source = masked_source.ravel()
        masked_template = masked_template.ravel()
        s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True,
                                                   return_counts=True)
        t_values, t_counts = numpy.unique(template, return_counts=True)
        ms_values, mbin_idx, ms_counts = numpy.unique(source, return_inverse=True,
                                                      return_counts=True)
        mt_values, mt_counts = numpy.unique(template, return_counts=True)

        s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64)
        s_quantiles /= s_quantiles[-1]
        t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64)
        t_quantiles /= t_quantiles[-1]
        interp_t_values = numpy.interp(s_quantiles, t_quantiles, t_values)

        return interp_t_values[bin_idx].reshape(oldshape)

    def color_hist_match(self, src_im, tar_im, mask):
        matched_R = self.hist_match(src_im[:, :, 0], tar_im[:, :, 0], mask)
        matched_G = self.hist_match(src_im[:, :, 1], tar_im[:, :, 1], mask)
        matched_B = self.hist_match(src_im[:, :, 2], tar_im[:, :, 2], mask)
        matched = numpy.stack((matched_R, matched_G, matched_B), axis=2).astype(src_im.dtype)
        return matched

    def get_new_face(self, image, mat, size):
        face = cv2.warpAffine(image, mat, (size, size))
        face = numpy.expand_dims(face, 0)
        face_clipped = numpy.clip(face[0], 0, 255).astype(image.dtype)
        new_face = None
        mask = None

        if "GAN" not in self.trainer:
            normalized_face = face / 255.0
            new_face = self.encoder(normalized_face)[0]
            new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype)
        else:
            normalized_face = face / 255.0 * 2 - 1
            fake_output = self.encoder(normalized_face)
            if "128" in self.trainer:  # TODO: Another hack to switch between 64 and 128
                fake_output = fake_output[0]
            mask = fake_output[:, :, :, :1]
            new_face = fake_output[:, :, :, 1:]
            new_face = mask * new_face + (1 - mask) * normalized_face
            new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype(image.dtype)

        if self.match_histogram:
            new_face = self.color_hist_match(new_face, face_clipped, mask)

        return new_face

    def get_image_mask(self, image, new_face, landmarks, mat, image_size):
        face_mask = numpy.zeros(image.shape, dtype=float)
        if 'rect' in self.mask_type:
            face_src = numpy.ones(new_face.shape, dtype=float)
            cv2.warpAffine(face_src,
                           mat,
                           image_size,
                           face_mask,
                           cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC,
                           cv2.BORDER_TRANSPARENT)

        hull_mask = numpy.zeros(image.shape, dtype=float)
        if 'hull' in self.mask_type:
            hull = cv2.convexHull(
                numpy.array(landmarks).reshape((-1, 2)).astype(int)).flatten().reshape((-1, 2))
            cv2.fillConvexPoly(hull_mask, hull, (1, 1, 1))

        if self.mask_type == 'rect':
            image_mask = face_mask
        elif self.mask_type == 'facehull':
            image_mask = hull_mask
        else:
            image_mask = face_mask * hull_mask

        if self.erosion_kernel is not None:
            if self.erosion_kernel_size > 0:
                image_mask = cv2.erode(image_mask, self.erosion_kernel, iterations=1)
            elif self.erosion_kernel_size < 0:
                dilation_kernel = abs(self.erosion_kernel)
                image_mask = cv2.dilate(image_mask, dilation_kernel, iterations=1)

        if self.blur_size != 0:
            image_mask = cv2.blur(image_mask, (self.blur_size, self.blur_size))

        return image_mask

plugins/convert/masked.py (new file, 358 lines)
@@ -0,0 +1,358 @@
#!/usr/bin/env python3
""" Masked converter for faceswap.py
    Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955
    found on https://www.reddit.com/r/deepfakes/ """

import logging
import cv2
import numpy as np
from lib.model.masks import dfl_full

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Convert():
    """ Swap a source face with a target """
    def __init__(self, encoder, model, arguments):
        logger.debug("Initializing %s: (encoder: '%s', model: %s, arguments: %s)",
                     self.__class__.__name__, encoder, model, arguments)
        self.encoder = encoder
        self.args = arguments
        self.input_size = model.input_shape[0]
        self.training_size = model.state.training_size
        self.training_coverage_ratio = model.training_opts["coverage_ratio"]
        self.input_mask_shape = model.state.mask_shapes[0] if model.state.mask_shapes else None

        self.crop = None
        self.mask = None
        logger.debug("Initialized %s", self.__class__.__name__)

    def patch_image(self, image, detected_face):
        """ Patch the swapped face onto the original frame """
        logger.trace("Patching image")
        image = image.astype('float32')
        image_size = (image.shape[1], image.shape[0])
        coverage = int(self.training_coverage_ratio * self.training_size)
        padding = (self.training_size - coverage) // 2
        logger.trace("coverage: %s, padding: %s", coverage, padding)

        self.crop = slice(padding, self.training_size - padding)
        if not self.mask:  # Init the mask on first image
            self.mask = Mask(self.args.mask_type, self.training_size, padding, self.crop)

        detected_face.load_aligned(image, size=self.training_size, align_eyes=False)
        new_image = self.get_new_image(image, detected_face, coverage, image_size)
        image_mask = self.get_image_mask(detected_face, image_size)
        patched_face = self.apply_fixes(image,
                                        new_image,
                                        image_mask,
                                        image_size)

        logger.trace("Patched image")
        return patched_face
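
    # Worked example of the coverage arithmetic above, assuming the defaults
    # training_size=256 and coverage_ratio=0.625 (values hypothetical here):
    #
    #   coverage = int(0.625 * 256)     # 160
    #   padding = (256 - 160) // 2      # 48
    #   crop = slice(48, 256 - 48)      # the central 160px square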
    def get_new_image(self, image, detected_face, coverage, image_size):
        """ Get the new face from the predictor """
        logger.trace("coverage: %s", coverage)
        src_face = detected_face.aligned_face
        coverage_face = src_face[self.crop, self.crop]
        coverage_face = cv2.resize(coverage_face,  # pylint: disable=no-member
                                   (self.input_size, self.input_size),
                                   interpolation=cv2.INTER_AREA)  # pylint: disable=no-member
        coverage_face = np.expand_dims(coverage_face, 0)
        np.clip(coverage_face / 255.0, 0.0, 1.0, out=coverage_face)

        if self.input_mask_shape:
            mask = np.zeros(self.input_mask_shape, np.float32)
            mask = np.expand_dims(mask, 0)
            feed = [coverage_face, mask]
        else:
            feed = [coverage_face]
        logger.trace("Input shapes: %s", [item.shape for item in feed])
        new_face = self.encoder(feed)[0]
        new_face = new_face.squeeze()
        logger.trace("Output shape: %s", new_face.shape)

        new_face = cv2.resize(new_face,  # pylint: disable=no-member
                              (coverage, coverage),
                              interpolation=cv2.INTER_CUBIC)  # pylint: disable=no-member
        np.clip(new_face * 255.0, 0.0, 255.0, out=new_face)
        src_face[self.crop, self.crop] = new_face
        background = image.copy()
        interpolator = detected_face.adjusted_interpolators[1]
        new_image = cv2.warpAffine(  # pylint: disable=no-member
            src_face,
            detected_face.adjusted_matrix,
            image_size,
            background,
            flags=cv2.WARP_INVERSE_MAP | interpolator,  # pylint: disable=no-member
            borderMode=cv2.BORDER_TRANSPARENT)  # pylint: disable=no-member
        return new_image
    def get_image_mask(self, detected_face, image_size):
        """ Get the image mask """
        mask = self.mask.get_mask(detected_face, image_size)
        if self.args.erosion_size != 0:
            kwargs = {'src': mask,
                      'kernel': self.set_erosion_kernel(mask),
                      'iterations': 1}
            if self.args.erosion_size > 0:
                mask = cv2.erode(**kwargs)  # pylint: disable=no-member
            else:
                mask = cv2.dilate(**kwargs)  # pylint: disable=no-member

        if self.args.blur_size != 0:
            blur_size = self.set_blur_size(mask)
            mask = cv2.blur(mask, (blur_size, blur_size))  # pylint: disable=no-member

        return np.clip(mask, 0.0, 1.0, out=mask)

    def set_erosion_kernel(self, mask):
        """ Set the erosion kernel """
        erosion_ratio = self.args.erosion_size / 100
        mask_radius = np.sqrt(np.sum(mask)) / 2
        percent_erode = max(1, int(abs(erosion_ratio * mask_radius)))
        erosion_kernel = cv2.getStructuringElement(  # pylint: disable=no-member
            cv2.MORPH_ELLIPSE,  # pylint: disable=no-member
            (percent_erode, percent_erode))
        logger.trace("erosion_kernel shape: %s", erosion_kernel.shape)
        return erosion_kernel

    def set_blur_size(self, mask):
        """ Set the blur size to absolute or percentage """
        blur_ratio = self.args.blur_size / 100
        mask_radius = np.sqrt(np.sum(mask)) / 2
        blur_size = int(max(1, blur_ratio * mask_radius))
        logger.trace("blur_size: %s", blur_size)
        return blur_size
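
    # Worked numbers for the percentage-based kernel sizing above (assumed
    # values): a 3-channel mask with 30,000 non-zero entries (10,000 pixels)
    # gives mask_radius = sqrt(30000) / 2 ~= 86.6, so erosion_size=5 yields a
    # 4px elliptical kernel and blur_size=10 an 8px box blur.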
    def apply_fixes(self, frame, new_image, image_mask, image_size):
        """ Apply fixes """
        masked = new_image  # * image_mask

        if self.args.draw_transparent:
            alpha = np.full((image_size[1], image_size[0], 1), 255.0, dtype='float32')
            new_image = np.concatenate((new_image, alpha), axis=2)
            image_mask = np.concatenate((image_mask, alpha), axis=2)
            frame = np.concatenate((frame, alpha), axis=2)

        if self.args.sharpen_image is not None:
            np.clip(masked, 0.0, 255.0, out=masked)
            if self.args.sharpen_image == "box_filter":
                kernel = np.ones((3, 3)) * (-1)
                kernel[1, 1] = 9
                masked = cv2.filter2D(masked, -1, kernel)  # pylint: disable=no-member
            elif self.args.sharpen_image == "gaussian_filter":
                blur = cv2.GaussianBlur(masked, (0, 0), 3.0)  # pylint: disable=no-member
                masked = cv2.addWeighted(masked,  # pylint: disable=no-member
                                         1.5,
                                         blur,
                                         -0.5,
                                         0,
                                         masked)

        if self.args.avg_color_adjust:
            for _ in [0, 1]:
                np.clip(masked, 0.0, 255.0, out=masked)
                diff = frame - masked
                avg_diff = np.sum(diff * image_mask, axis=(0, 1))
                adjustment = avg_diff / np.sum(image_mask, axis=(0, 1))
                masked = masked + adjustment

        if self.args.match_histogram:
            np.clip(masked, 0.0, 255.0, out=masked)
            masked = self.color_hist_match(masked, frame, image_mask)

        if self.args.seamless_clone and not self.args.draw_transparent:
            h, w, _ = frame.shape
            h = h // 2
            w = w // 2

            y_indices, x_indices, _ = np.nonzero(image_mask)
            y_crop = slice(np.min(y_indices), np.max(y_indices))
            x_crop = slice(np.min(x_indices), np.max(x_indices))
            y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2) + h)
            x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2) + w)

            '''
            # test with average of centroid rather than the h/2, w/2 center
            y_center = int(np.rint(np.average(y_indices)) + h)
            x_center = int(np.rint(np.average(x_indices)) + w)
            '''

            insertion = np.rint(masked[y_crop, x_crop, :]).astype('uint8')
            insertion_mask = image_mask[y_crop, x_crop, :]
            insertion_mask[insertion_mask != 0] = 255
            insertion_mask = insertion_mask.astype('uint8')

            prior = np.pad(frame, ((h, h), (w, w), (0, 0)), 'constant').astype('uint8')

            blended = cv2.seamlessClone(insertion,  # pylint: disable=no-member
                                        prior,
                                        insertion_mask,
                                        (x_center, y_center),
                                        cv2.NORMAL_CLONE)  # pylint: disable=no-member
            blended = blended[h:-h, w:-w, :]

        else:
            foreground = masked * image_mask
            background = frame * (1.0 - image_mask)
            blended = foreground + background

        np.clip(blended, 0.0, 255.0, out=blended)

        return np.rint(blended).astype('uint8')
    def color_hist_match(self, new, frame, image_mask):
        """ Match the histogram of the new face to the frame, channel by channel """
        for channel in [0, 1, 2]:
            new[:, :, channel] = self.hist_match(new[:, :, channel],
                                                 frame[:, :, channel],
                                                 image_mask[:, :, channel])
        # source = np.stack([self.hist_match(source[:,:,c], target[:,:,c], image_mask[:,:,c])
        #                    for c in [0,1,2]],
        #                   axis=2)
        return new

    def hist_match(self, new, frame, image_mask):
        """ Histogram match a single channel within the masked region """
        mask_indices = np.nonzero(image_mask)
        if len(mask_indices[0]) == 0:
            return new

        m_new = new[mask_indices].ravel()
        m_frame = frame[mask_indices].ravel()
        s_values, bin_idx, s_counts = np.unique(m_new, return_inverse=True, return_counts=True)
        t_values, t_counts = np.unique(m_frame, return_counts=True)
        s_quants = np.cumsum(s_counts, dtype='float32')
        t_quants = np.cumsum(t_counts, dtype='float32')
        s_quants /= s_quants[-1]  # cdf
        t_quants /= t_quants[-1]  # cdf
        interp_s_values = np.interp(s_quants, t_quants, t_values)
        new[mask_indices] = interp_s_values[bin_idx]

        '''
        bins = np.arange(256)
        template_CDF, _ = np.histogram(m_frame, bins=bins, density=True)
        flat_new_image = np.interp(m_source.ravel(), bins[:-1], template_CDF) * 255.0
        return flat_new_image.reshape(m_source.shape) * 255.0
        '''

        return new

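# A toy check of the CDF-matching step in hist_match above (hypothetical
# arrays, runs standalone): a flat source channel is remapped onto a brighter
# target distribution.
#
#   import numpy as np
#   src = np.array([0, 0, 1, 1, 2, 2], dtype='float32')
#   tgt = np.array([10, 10, 20, 20, 30, 30], dtype='float32')
#   s_vals, bin_idx, s_counts = np.unique(src, return_inverse=True, return_counts=True)
#   t_vals, t_counts = np.unique(tgt, return_counts=True)
#   s_cdf = np.cumsum(s_counts, dtype='float32') / len(src)
#   t_cdf = np.cumsum(t_counts, dtype='float32') / len(tgt)
#   print(np.interp(s_cdf, t_cdf, t_vals)[bin_idx])   # [10. 10. 20. 20. 30. 30.]
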
class Mask():
    """ Return the requested mask """

    def __init__(self, mask_type, training_size, padding, crop):
        """ Set requested mask """
        logger.debug("Initializing %s: (mask_type: '%s', training_size: %s, padding: %s)",
                     self.__class__.__name__, mask_type, training_size, padding)

        self.training_size = training_size
        self.padding = padding
        self.mask_type = mask_type
        self.crop = crop

        logger.debug("Initialized %s", self.__class__.__name__)

    def get_mask(self, detected_face, image_size):
        """ Return a face mask """
        kwargs = {"matrix": detected_face.adjusted_matrix,
                  "interpolators": detected_face.adjusted_interpolators,
                  "landmarks": detected_face.landmarks_as_xy,
                  "image_size": image_size}
        logger.trace("kwargs: %s", kwargs)
        mask = getattr(self, self.mask_type)(**kwargs)
        mask = self.finalize_mask(mask)
        logger.trace("mask shape: %s", mask.shape)
        return mask
    def cnn(self, **kwargs):
        """ CNN Mask """
        # Insert FCN-VGG16 segmentation mask model here
        logger.info("cnn not yet implemented, using facehull instead")
        return self.facehull(**kwargs)

    def smoothed(self, **kwargs):
        """ Smoothed Mask """
        logger.trace("Getting mask")
        interpolator = kwargs["interpolators"][1]
        ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32')
        # area = self.padding + (self.training_size - 2 * self.padding) // 15
        # central_core = slice(area, -area)
        ones[self.crop, self.crop] = 1.0
        ones = cv2.GaussianBlur(ones, (25, 25), 10)  # pylint: disable=no-member

        mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
        cv2.warpAffine(ones,  # pylint: disable=no-member
                       kwargs["matrix"],
                       kwargs["image_size"],
                       mask,
                       flags=cv2.WARP_INVERSE_MAP | interpolator,  # pylint: disable=no-member
                       borderMode=cv2.BORDER_CONSTANT,  # pylint: disable=no-member
                       borderValue=0.0)
        return mask

    def rect(self, **kwargs):
        """ Rect Mask """
        logger.trace("Getting mask")
        interpolator = kwargs["interpolators"][1]
        ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32')
        mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
        # central_core = slice(self.padding, -self.padding)
        ones[self.crop, self.crop] = 1.0
        cv2.warpAffine(ones,  # pylint: disable=no-member
                       kwargs["matrix"],
                       kwargs["image_size"],
                       mask,
                       flags=cv2.WARP_INVERSE_MAP | interpolator,  # pylint: disable=no-member
                       borderMode=cv2.BORDER_CONSTANT,  # pylint: disable=no-member
                       borderValue=0.0)
        return mask

    def dfl(self, **kwargs):
        """ DFL facial mask """
        logger.trace("Getting mask")
        dummy = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
        mask = dfl_full(kwargs["landmarks"], dummy, channels=3)
        return mask

    def facehull(self, **kwargs):
        """ Facehull Mask """
        logger.trace("Getting mask")
        mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
        hull = cv2.convexHull(  # pylint: disable=no-member
            np.array(kwargs["landmarks"]).reshape((-1, 2)))
        cv2.fillConvexPoly(mask,  # pylint: disable=no-member
                           hull,
                           (1.0, 1.0, 1.0),
                           lineType=cv2.LINE_AA)  # pylint: disable=no-member
        return mask

    def facehull_rect(self, **kwargs):
        """ Facehull Rect Mask """
        logger.trace("Getting mask")
        mask = self.rect(**kwargs)
        hull_mask = self.facehull(**kwargs)
        mask *= hull_mask
        return mask

    def ellipse(self, **kwargs):
        """ Ellipse Mask """
        logger.trace("Getting mask")
        mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32')
        ell = cv2.fitEllipse(  # pylint: disable=no-member
            np.array(kwargs["landmarks"]).reshape((-1, 2)))
        cv2.ellipse(mask,  # pylint: disable=no-member
                    box=ell,
                    color=(1.0, 1.0, 1.0),
                    thickness=-1)
        return mask

    @staticmethod
    def finalize_mask(mask):
        """ Finalize the mask """
        logger.trace("Finalizing mask")
        np.nan_to_num(mask, copy=False)
        np.clip(mask, 0.0, 1.0, out=mask)
        return mask
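
# Usage sketch (assumed values): build the facehull mask for a model trained
# at 256px with 48px padding; `face` stands in for an aligned DetectedFace.
#
#   mask = Mask("facehull", training_size=256, padding=48, crop=slice(48, 208))
#   image_mask = mask.get_mask(face, (1920, 1080))   # float32, 3-channel, 0..1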

plugins/extract/_config.py (new file, 48 lines)
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
""" Default configurations for extract """

import logging

from lib.config import FaceswapConfig

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Config(FaceswapConfig):
    """ Config File for Models """

    def set_defaults(self):
        """ Set the default values for config """
        logger.debug("Setting defaults")

        # << GLOBAL OPTIONS >> #
        # section = "global"
        # self.add_section(title=section,
        #                  info="Options that apply to all models")

        # << MTCNN DETECTOR OPTIONS >> #
        section = "detect.mtcnn"
        self.add_section(title=section,
                         info="MTCNN Detector options")
        self.add_item(
            section=section, title="minsize", datatype=int, default=20, rounding=10,
            min_max=(20, 1000),
            info="The minimum size of a face (in pixels) to be accepted as a positive match.\n"
                 "Lower values use significantly more VRAM and will detect more false positives")
        self.add_item(
            section=section, title="threshold_1", datatype=float, default=0.6, rounding=2,
            min_max=(0.1, 0.9),
            info="First stage threshold for face detection. This stage obtains face candidates")
        self.add_item(
            section=section, title="threshold_2", datatype=float, default=0.7, rounding=2,
            min_max=(0.1, 0.9),
            info="Second stage threshold for face detection. This stage refines face candidates")
        self.add_item(
            section=section, title="threshold_3", datatype=float, default=0.7, rounding=2,
            min_max=(0.1, 0.9),
            info="Third stage threshold for face detection. This stage further refines face "
                 "candidates")
        self.add_item(
            section=section, title="scalefactor", datatype=float, default=0.709, rounding=3,
            min_max=(0.1, 0.9),
            info="The scale factor for the image pyramid")
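
# The section above should serialize to an ini block along these lines (file
# name and exact layout assumed, values are the stated defaults):
#
#   [detect.mtcnn]
#   minsize = 20
#   threshold_1 = 0.6
#   threshold_2 = 0.7
#   threshold_3 = 0.7
#   scalefactor = 0.709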

plugins/extract/detect/_base.py
@@ -22,14 +22,21 @@ from math import sqrt

 from lib.gpu_stats import GPUStats
 from lib.utils import rotate_landmarks
+from plugins.extract._config import Config

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


+def get_config(plugin_name):
+    """ Return the config for the requested model """
+    return Config(plugin_name).config_dict
+
+
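# Usage sketch: the plugin name is derived from the module path, so a module
# such as plugins/extract/detect/mtcnn.py resolves to "detect.mtcnn"
# (illustrative call, not part of the diff):
#
#   config = get_config("detect.mtcnn")
#   config["minsize"]   # -> 20 with the defaults above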
 class Detector():
     """ Detector object """
     def __init__(self, loglevel, rotation=None):
         logger.debug("Initializing %s: (rotation: %s)", self.__class__.__name__, rotation)
+        self.config = get_config(".".join(self.__module__.split(".")[-2:]))
         self.loglevel = loglevel
         self.cachepath = os.path.join(os.path.dirname(__file__), ".cache")
         self.rotation = self.get_rotation_angles(rotation)
@@ -107,6 +114,7 @@ class Detector():
             logger.exception("Traceback:")
             tb_buffer = StringIO()
             traceback.print_exc(file=tb_buffer)
+            logger.trace(tb_buffer.getvalue())
             exception = {"exception": (os.getpid(), tb_buffer)}
             self.queues["out"].put(exception)
             exit(1)

@@ -111,7 +111,7 @@ class Detect(Detector):

     def detect_batch(self, detect_images, disable_message=False):
         """ Pass the batch through detector for consistently sized images
-            or each image seperately for inconsitently sized images """
+            or each image separately for inconsistently sized images """
         logger.trace("Detecting Batch")
         can_batch = self.check_batch_dims(detect_images)
         if can_batch:

plugins/extract/detect/mtcnn.py
@@ -30,21 +30,24 @@ class Detect(Detector):
     """ MTCNN detector for face recognition """
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.kwargs = None
+        self.kwargs = self.validate_kwargs()
         self.name = "mtcnn"
         self.target = 2073600  # Uses approx 1.30 GB of VRAM
         self.vram = 1408

-    @staticmethod
-    def validate_kwargs(kwargs):
-        """ Validate that cli kwargs are correct. If not reset to default """
+    def validate_kwargs(self):
+        """ Validate that config options are correct. If not reset to default """
         valid = True
-        if kwargs['minsize'] < 10:
+        threshold = [self.config["threshold_1"],
+                     self.config["threshold_2"],
+                     self.config["threshold_3"]]
+        kwargs = {"minsize": self.config["minsize"],
+                  "threshold": threshold,
+                  "factor": self.config["scalefactor"]}
+
+        if kwargs["minsize"] < 10:
             valid = False
-        elif len(kwargs['threshold']) != 3:
-            valid = False
-        elif not all(0.0 < threshold < 1.0
-                     for threshold in kwargs['threshold']):
+        elif not all(0.0 < threshold <= 1.0 for threshold in kwargs['threshold']):
             valid = False
         elif not 0.0 < kwargs['factor'] < 1.0:
             valid = False
@@ -53,7 +56,7 @@ class Detect(Detector):
             kwargs = {"minsize": 20,  # minimum size of face
                       "threshold": [0.6, 0.7, 0.7],  # three steps threshold
                       "factor": 0.709}  # scale factor
-            logger.warning("Invalid MTCNN arguments received. Running with defaults")
+            logger.warning("Invalid MTCNN options in config. Running with defaults")
         logger.debug("Using mtcnn kwargs: %s", kwargs)
         return kwargs

@@ -72,7 +75,6 @@ class Detect(Detector):
         super().initialize(*args, **kwargs)
         logger.info("Initializing MTCNN Detector...")
         is_gpu = False
-        self.kwargs = kwargs["mtcnn_kwargs"]

         # Must import tensorflow inside the spawned process
         # for Windows machines

(deleted GAN model plugin, 187 lines)
@@ -1,187 +0,0 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo (master/temp/faceswap_GAN_keras.ipynb)

import logging

from keras.models import Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
from keras.optimizers import Adam

from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
from lib.utils import backup_file

from keras.utils import multi_gpu_model

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


hdf = {'netGAH5': 'netGA_GAN.h5',
       'netGBH5': 'netGB_GAN.h5',
       'netDAH5': 'netDA_GAN.h5',
       'netDBH5': 'netDB_GAN.h5'}


def __conv_init(a):
    logger.info("conv_init %s", a)
    k = RandomNormal(0, 0.02)(a)  # for convolution kernel
    k.conv_weight = True
    return k


# def batchnorm():
#     return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5,
#                               gamma_initializer=gamma_init)


def inst_norm():
    return InstanceNormalization()


conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02)  # for batch normalization


class GANModel():
    img_size = 64
    channels = 3
    img_shape = (img_size, img_size, channels)
    encoded_dim = 1024
    nc_in = 3  # number of input channels of generators
    nc_D_inp = 6  # number of input channels of discriminators

    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        optimizer = Adam(1e-4, 0.5)

        # Build and compile the discriminator
        self.netDA, self.netDB = self.build_discriminator()

        # Build and compile the generator
        self.netGA, self.netGB = self.build_generator()

    def converter(self, swap):
        predictor = self.netGB if not swap else self.netGA
        return lambda img: predictor.predict(img)

    def build_generator(self):

        def conv_block(input_tensor, f):
            x = input_tensor
            x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init,
                       use_bias=False, padding="same")(x)
            x = Activation("relu")(x)
            return x

        def res_block(input_tensor, f):
            x = input_tensor
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init,
                       use_bias=False, padding="same")(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init,
                       use_bias=False, padding="same")(x)
            x = add([x, input_tensor])
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def upscale_ps(filters, use_instance_norm=True):
            def block(x):
                x = Conv2D(filters * 4, kernel_size=3, use_bias=False,
                           kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
                x = LeakyReLU(0.1)(x)
                x = PixelShuffler()(x)
                return x
            return block
def Encoder(nc_in=3, input_size=64):
|
|
||||||
inp = Input(shape=(input_size, input_size, nc_in))
|
|
||||||
x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
|
|
||||||
x = conv_block(x,128)
|
|
||||||
x = conv_block(x,256)
|
|
||||||
x = conv_block(x,512)
|
|
||||||
x = conv_block(x,1024)
|
|
||||||
x = Dense(1024)(Flatten()(x))
|
|
||||||
x = Dense(4*4*1024)(x)
|
|
||||||
x = Reshape((4, 4, 1024))(x)
|
|
||||||
out = upscale_ps(512)(x)
|
|
||||||
return Model(inputs=inp, outputs=out)
|
|
||||||
|
|
||||||
def Decoder_ps(nc_in=512, input_size=8):
|
|
||||||
input_ = Input(shape=(input_size, input_size, nc_in))
|
|
||||||
x = input_
|
|
||||||
x = upscale_ps(256)(x)
|
|
||||||
x = upscale_ps(128)(x)
|
|
||||||
x = upscale_ps(64)(x)
|
|
||||||
x = res_block(x, 64)
|
|
||||||
x = res_block(x, 64)
|
|
||||||
#x = Conv2D(4, kernel_size=5, padding='same')(x)
|
|
||||||
alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
|
|
||||||
rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
|
|
||||||
out = concatenate([alpha, rgb])
|
|
||||||
return Model(input_, out )
|
|
||||||
|
|
||||||
encoder = Encoder()
|
|
||||||
decoder_A = Decoder_ps()
|
|
||||||
decoder_B = Decoder_ps()
|
|
||||||
x = Input(shape=self.img_shape)
|
|
||||||
netGA = Model(x, decoder_A(encoder(x)))
|
|
||||||
netGB = Model(x, decoder_B(encoder(x)))
|
|
||||||
|
|
||||||
self.netGA_sm = netGA
|
|
||||||
self.netGB_sm = netGB
|
|
||||||
|
|
||||||
try:
|
|
||||||
netGA.load_weights(str(self.model_dir / hdf['netGAH5']))
|
|
||||||
netGB.load_weights(str(self.model_dir / hdf['netGBH5']))
|
|
||||||
logger.info("Generator models loaded.")
|
|
||||||
except:
|
|
||||||
logger.info("Generator weights files not found.")
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.gpus > 1:
|
|
||||||
netGA = multi_gpu_model( self.netGA_sm , self.gpus)
|
|
||||||
netGB = multi_gpu_model( self.netGB_sm , self.gpus)
|
|
||||||
|
|
||||||
return netGA, netGB
|
|
||||||
|
|
||||||
def build_discriminator(self):
|
|
||||||
def conv_block_d(input_tensor, f, use_instance_norm=True):
|
|
||||||
x = input_tensor
|
|
||||||
x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
|
|
||||||
x = LeakyReLU(alpha=0.2)(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def Discriminator(nc_in, input_size=64):
|
|
||||||
inp = Input(shape=(input_size, input_size, nc_in))
|
|
||||||
#x = GaussianNoise(0.05)(inp)
|
|
||||||
x = conv_block_d(inp, 64, False)
|
|
||||||
x = conv_block_d(x, 128, False)
|
|
||||||
x = conv_block_d(x, 256, False)
|
|
||||||
out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
|
|
||||||
return Model(inputs=[inp], outputs=out)
|
|
||||||
|
|
||||||
netDA = Discriminator(self.nc_D_inp)
|
|
||||||
netDB = Discriminator(self.nc_D_inp)
|
|
||||||
try:
|
|
||||||
netDA.load_weights(str(self.model_dir / hdf['netDAH5']))
|
|
||||||
netDB.load_weights(str(self.model_dir / hdf['netDBH5']))
|
|
||||||
logger.info("Discriminator models loaded.")
|
|
||||||
except:
|
|
||||||
logger.info("Discriminator weights files not found.")
|
|
||||||
pass
|
|
||||||
return netDA, netDB
|
|
||||||
|
|
||||||
def load(self, swapped):
|
|
||||||
if swapped:
|
|
||||||
logger.warning("swapping not supported on GAN")
|
|
||||||
# TODO load is done in __init__ => look how to swap if possible
|
|
||||||
return True
|
|
||||||
|
|
||||||
def save_weights(self):
|
|
||||||
model_dir = str(self.model_dir)
|
|
||||||
for model in hdf.values():
|
|
||||||
backup_file(model_dir, model)
|
|
||||||
if self.gpus > 1:
|
|
||||||
self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5']))
|
|
||||||
self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5']))
|
|
||||||
else:
|
|
||||||
self.netGA.save_weights(str(self.model_dir / hdf['netGAH5']))
|
|
||||||
self.netGB.save_weights(str(self.model_dir / hdf['netGBH5']))
|
|
||||||
self.netDA.save_weights(str(self.model_dir / hdf['netDAH5']))
|
|
||||||
self.netDB.save_weights(str(self.model_dir / hdf['netDBH5']))
|
|
||||||
logger.info("Models saved.")
|
|
|
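The generator above upscales with a Conv2D that multiplies the channel count by four, after which PixelShuffler rearranges those channels into a 2x larger spatial grid. A minimal numpy sketch of that depth-to-space step follows; the sub-pixel ordering shown is one common convention and is an assumption here, since lib.PixelShuffler may order them differently:

import numpy as np

def depth_to_space(x, r=2):
    # Rearrange an (H, W, C*r*r) feature map into (H*r, W*r, C).
    h, w, c = x.shape
    assert c % (r * r) == 0, "channels must be divisible by r*r"
    out_c = c // (r * r)
    x = x.reshape(h, w, r, r, out_c)        # split channels into the r*r sub-pixel grid
    x = x.transpose(0, 2, 1, 3, 4)          # interleave: (h, r, w, r, out_c)
    return x.reshape(h * r, w * r, out_c)   # merge into the upscaled plane

feat = np.random.rand(8, 8, 256)            # e.g. output of Conv2D(filters*4) with filters=64
up = depth_to_space(feat, r=2)
print(up.shape)                             # (16, 16, 64)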
@@ -1,260 +0,0 @@
import time

import cv2
import numpy as np

from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K

from lib.training_data import TrainingDataGenerator, stack_images


class GANTrainingDataGenerator(TrainingDataGenerator):
    def __init__(self, random_transform_args, coverage, scale, zoom):
        super().__init__(random_transform_args, coverage, scale, zoom)

    def color_adjust(self, img):
        return img / 255.0 * 2 - 1


class Trainer():
    random_transform_args = {
        'rotation_range': 20,
        'zoom_range': 0.1,
        'shift_range': 0.05,
        'random_flip': 0.5,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
        K.set_learning_phase(1)

        assert batch_size % 2 == 0, "batch_size must be an even number"
        self.batch_size = batch_size
        self.model = model

        self.use_lsgan = True
        self.use_mixup = True
        self.mixup_alpha = 0.2
        self.use_perceptual_loss = perceptual_loss
        self.use_instancenorm = False

        self.lrD = 1e-4  # Discriminator learning rate
        self.lrG = 1e-4  # Generator learning rate

        generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 1)
        self.train_batchA = generator.minibatchAB(fn_A, batch_size)
        self.train_batchB = generator.minibatchAB(fn_B, batch_size)

        self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0

        self.setup()

    def setup(self):
        distorted_A, fake_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
        distorted_B, fake_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
        real_A = Input(shape=self.model.img_shape)
        real_B = Input(shape=self.model.img_shape)

        if self.use_lsgan:
            self.loss_fn = lambda output, target: K.mean(K.abs(K.square(output - target)))
        else:
            self.loss_fn = lambda output, target: -K.mean(K.log(output + 1e-12) * target + K.log(1 - output + 1e-12) * (1 - target))

        # ========== Define Perceptual Loss Model ==========
        if self.use_perceptual_loss:
            from keras.models import Model
            from keras_vggface.vggface import VGGFace
            vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
            vggface.trainable = False
            out_size55 = vggface.layers[36].output
            out_size28 = vggface.layers[78].output
            out_size7 = vggface.layers[-2].output
            vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
            vggface_feat.trainable = False
        else:
            vggface_feat = None

        # TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
        loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, distorted_A, vggface_feat)
        loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, distorted_B, vggface_feat)

        loss_GA += 1e-3 * K.mean(K.abs(mask_A))
        loss_GB += 1e-3 * K.mean(K.abs(mask_B))

        w_fo = 0.01
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))

        weightsDA = self.model.netDA.trainable_weights
        weightsGA = self.model.netGA.trainable_weights
        weightsDB = self.model.netDB.trainable_weights
        weightsGB = self.model.netGB.trainable_weights

        # Adam(..).get_updates(...)
        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA, [], loss_DA)
        self.netDA_train = K.function([distorted_A, real_A], [loss_DA], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA, [], loss_GA)
        self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)

        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB, [], loss_DB)
        self.netDB_train = K.function([distorted_B, real_B], [loss_DB], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB, [], loss_GB)
        self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)

    def first_order(self, x, axis=1):
        img_nrows = x.shape[1]
        img_ncols = x.shape[2]
        if axis == 1:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        elif axis == 2:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
        else:
            return None

    def train_one_step(self, iter, viewer):
        # ---------------------
        #  Train Discriminators
        # ---------------------

        # Select a random half batch of images
        epoch, warped_A, target_A = next(self.train_batchA)
        epoch, warped_B, target_B = next(self.train_batchB)

        # Train discriminators for one batch
        errDA = self.netDA_train([warped_A, target_A])
        errDB = self.netDB_train([warped_B, target_B])

        # Train generators for one batch
        errGA = self.netGA_train([warped_A, target_A])
        errGB = self.netGB_train([warped_B, target_B])

        # For calculating average losses
        self.errDA_sum += errDA[0]
        self.errDB_sum += errDB[0]
        self.errGA_sum += errGA[0]
        self.errGB_sum += errGB[0]
        self.avg_counter += 1

        print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter,
                 self.errDA_sum / self.avg_counter, self.errDB_sum / self.avg_counter,
                 self.errGA_sum / self.avg_counter, self.errGB_sum / self.avg_counter),
              end='\r')

        if viewer is not None:
            self.show_sample(viewer)

    def cycle_variables(self, netG):
        distorted_input = netG.inputs[0]
        fake_output = netG.outputs[0]
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_output)
        rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_output)

        masked_fake_output = alpha * rgb + (1 - alpha) * distorted_input

        fn_generate = K.function([distorted_input], [masked_fake_output])
        fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
        fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
        fn_bgr = K.function([distorted_input], [rgb])
        return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr

    def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None):
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_argb)
        fake = alpha * fake_rgb + (1 - alpha) * distorted

        if self.use_mixup:
            dist = Beta(self.mixup_alpha, self.mixup_alpha)
            lam = dist.sample()
            # ==========
            mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
            # ==========
            output_mixup = netD(mixup)
            loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
            output_fake = netD(concatenate([fake, distorted]))  # dummy
            loss_G = .5 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
        else:
            output_real = netD(concatenate([real, distorted]))  # positive sample
            output_fake = netD(concatenate([fake, distorted]))  # negative sample
            loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
            loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
            loss_D = loss_D_real + loss_D_fake
            loss_G = .5 * self.loss_fn(output_fake, K.ones_like(output_fake))
        # ==========
        loss_G += K.mean(K.abs(fake_rgb - real))
        # ==========

        # Edge loss (similar to total variation loss)
        loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=1) - self.first_order(real, axis=1)))
        loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=2) - self.first_order(real, axis=2)))

        # Perceptual Loss
        if vggface_feat is not None:
            def preprocess_vggface(x):
                x = (x + 1) / 2 * 255  # channel order: BGR
                #x[..., 0] -= 93.5940
                #x[..., 1] -= 104.7624
                #x[..., 2] -= 129.
                x -= [91.4953, 103.8827, 131.0912]
                return x
            pl_params = (0.011, 0.11, 0.1919)
            real_sz224 = tf.image.resize_images(real, [224, 224])
            real_sz224 = Lambda(preprocess_vggface)(real_sz224)
            # ==========
            fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
            fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
            # ==========
            real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
            fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
            loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
            loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
            loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))

        return loss_D, loss_G

    def show_sample(self, display_fn):
        _, wA, tA = next(self.train_batchA)
        _, wB, tB = next(self.train_batchB)
        display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked")
        display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw")
        display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
        # Reset the averages
        self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
        self.avg_counter = 0

    def showG(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
            np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        ], axis=1)
        figure_B = np.stack([
            test_B,
            np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
            np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure

    def showG_mask(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
        ], axis=1)
        figure_B = np.stack([
            test_B,
            (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure
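In define_loss above, the discriminator's real and fake input pairs are blended with a Beta-sampled mixup weight and trained against soft LSGAN targets. A minimal numpy sketch of that blending, using a random array as a stand-in for the discriminator output:

import numpy as np

rng = np.random.default_rng(0)
lam = rng.beta(0.2, 0.2)                    # mixup weight, Beta(alpha, alpha) with alpha=0.2

real_pair = rng.random((4, 64, 64, 6))      # [real, distorted] stacked on channels
fake_pair = rng.random((4, 64, 64, 6))      # [fake, distorted] stacked on channels

mixup = lam * real_pair + (1 - lam) * fake_pair  # single blended batch for the discriminator

# LSGAN-style targets: D should output lam on the blend, G pushes it toward (1 - lam)
d_out = rng.random((4, 8, 8, 1))            # stand-in for netD(mixup)
loss_D = np.mean(np.square(d_out - lam))
loss_G = 0.5 * np.mean(np.square(d_out - (1 - lam)))
print(lam, loss_D, loss_G)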
@@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-

__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'

from .Model import GANModel as Model
from .Trainer import Trainer
@@ -1,145 +0,0 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
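For reference, the layer above normalizes each sample over its spatial axes. A minimal numpy sketch of the per-channel variant (axis set to the features axis), with scalar gamma/beta for brevity; this is an illustrative stand-in, not the Keras layer itself:

import numpy as np

def instance_norm(x, gamma=1.0, beta=0.0, eps=1e-3):
    # Normalize each sample and channel over its spatial dims (NHWC layout).
    mean = x.mean(axis=(1, 2), keepdims=True)      # per-sample, per-channel mean
    std = x.std(axis=(1, 2), keepdims=True) + eps  # per-sample, per-channel std
    return gamma * (x - mean) / std + beta

x = np.random.rand(2, 4, 4, 3)
y = instance_norm(x)
print(np.allclose(y.mean(axis=(1, 2)), 0, atol=1e-6))  # ~zero mean per instance/channel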
@@ -1,204 +0,0 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynb (temp/faceswap_GAN_keras.ipynb)
import logging

from keras.models import Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
from keras.optimizers import Adam

from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
from lib.utils import backup_file

from keras.utils import multi_gpu_model

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


hdf = {'netGAH5': 'netGA_GAN128.h5',
       'netGBH5': 'netGB_GAN128.h5',
       'netDAH5': 'netDA_GAN128.h5',
       'netDBH5': 'netDB_GAN128.h5'}


def __conv_init(a):
    logger.info("conv_init %s", a)
    k = RandomNormal(0, 0.02)(a)  # for convolution kernel
    k.conv_weight = True
    return k


#def batchnorm():
#    return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer=gamma_init)


def inst_norm():
    return InstanceNormalization()


conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02)  # for batch normalization


class GANModel():
    img_size = 128
    channels = 3
    img_shape = (img_size, img_size, channels)
    encoded_dim = 1024
    nc_in = 3     # number of input channels of generators
    nc_D_inp = 6  # number of input channels of discriminators

    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        optimizer = Adam(1e-4, 0.5)

        # Build and compile the discriminator
        self.netDA, self.netDB = self.build_discriminator()

        # Build and compile the generator
        self.netGA, self.netGB = self.build_generator()

    def converter(self, swap):
        predictor = self.netGB if not swap else self.netGA
        return lambda img: predictor.predict(img)

    def build_generator(self):

        def conv_block(input_tensor, f, use_instance_norm=True):
            x = input_tensor
            x = SeparableConv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = Activation("relu")(x)
            return x

        def res_block(input_tensor, f, dilation=1):
            x = input_tensor
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
            x = add([x, input_tensor])
            #x = LeakyReLU(alpha=0.2)(x)
            return x

        def upscale_ps(filters, use_instance_norm=True):
            def block(x, use_instance_norm=use_instance_norm):
                x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
                if use_instance_norm:
                    x = inst_norm()(x)
                x = LeakyReLU(0.1)(x)
                x = PixelShuffler()(x)
                return x
            return block

        def Encoder(nc_in=3, input_size=128):
            inp = Input(shape=(input_size, input_size, nc_in))
            x = Conv2D(32, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
            x = conv_block(x, 64, use_instance_norm=False)
            x = conv_block(x, 128)
            x = conv_block(x, 256)
            x = conv_block(x, 512)
            x = conv_block(x, 1024)
            x = Dense(1024)(Flatten()(x))
            x = Dense(4*4*1024)(x)
            x = Reshape((4, 4, 1024))(x)
            out = upscale_ps(512)(x)
            return Model(inputs=inp, outputs=out)

        def Decoder_ps(nc_in=512, input_size=8):
            input_ = Input(shape=(input_size, input_size, nc_in))
            x = input_
            x = upscale_ps(256)(x)
            x = upscale_ps(128)(x)
            x = upscale_ps(64)(x)
            x = res_block(x, 64, dilation=2)

            out64 = Conv2D(64, kernel_size=3, padding='same')(x)
            out64 = LeakyReLU(alpha=0.1)(out64)
            out64 = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(out64)

            x = upscale_ps(32)(x)
            x = res_block(x, 32)
            x = res_block(x, 32)
            alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
            rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
            out = concatenate([alpha, rgb])
            return Model(input_, [out, out64])

        encoder = Encoder()
        decoder_A = Decoder_ps()
        decoder_B = Decoder_ps()
        x = Input(shape=self.img_shape)
        netGA = Model(x, decoder_A(encoder(x)))
        netGB = Model(x, decoder_B(encoder(x)))
        # Workaround until https://github.com/keras-team/keras/issues/8962 is fixed.
        netGA.output_names = ["netGA_out_1", "netGA_out_2"]
        netGB.output_names = ["netGB_out_1", "netGB_out_2"]

        self.netGA_sm = netGA
        self.netGB_sm = netGB

        try:
            netGA.load_weights(str(self.model_dir / hdf['netGAH5']))
            netGB.load_weights(str(self.model_dir / hdf['netGBH5']))
            logger.info("Generator models loaded.")
        except Exception:
            logger.info("Generator weights files not found.")

        if self.gpus > 1:
            netGA = multi_gpu_model(self.netGA_sm, self.gpus)
            netGB = multi_gpu_model(self.netGB_sm, self.gpus)

        return netGA, netGB

    def build_discriminator(self):
        def conv_block_d(input_tensor, f, use_instance_norm=True):
            x = input_tensor
            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def Discriminator(nc_in, input_size=128):
            inp = Input(shape=(input_size, input_size, nc_in))
            #x = GaussianNoise(0.05)(inp)
            x = conv_block_d(inp, 64, False)
            x = conv_block_d(x, 128, True)
            x = conv_block_d(x, 256, True)
            x = conv_block_d(x, 512, True)
            out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
            return Model(inputs=[inp], outputs=out)

        netDA = Discriminator(self.nc_D_inp)
        netDB = Discriminator(self.nc_D_inp)

        try:
            netDA.load_weights(str(self.model_dir / hdf['netDAH5']))
            netDB.load_weights(str(self.model_dir / hdf['netDBH5']))
            logger.info("Discriminator models loaded.")
        except Exception:
            logger.info("Discriminator weights files not found.")
        return netDA, netDB

    def load(self, swapped):
        if swapped:
            logger.warning("swapping not supported on GAN")
            # TODO load is done in __init__ => look how to swap if possible
        return True

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        if self.gpus > 1:
            self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5']))
            self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5']))
        else:
            self.netGA.save_weights(str(self.model_dir / hdf['netGAH5']))
            self.netGB.save_weights(str(self.model_dir / hdf['netGBH5']))
        self.netDA.save_weights(str(self.model_dir / hdf['netDAH5']))
        self.netDB.save_weights(str(self.model_dir / hdf['netDBH5']))
        logger.info("Models saved.")
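The decoder's 4-channel output is an alpha mask plus an RGB image; the visible result is the mask-weighted blend with the undistorted input, exactly as the trainer's cycle_variables composes it. A minimal numpy sketch with random stand-in tensors:

import numpy as np

def composite(argb, distorted):
    # Blend generator output onto the input: alpha*rgb + (1 - alpha)*input.
    alpha = argb[..., :1]  # sigmoid mask in [0, 1]
    rgb = argb[..., 1:]    # tanh image in [-1, 1]
    return alpha * rgb + (1 - alpha) * distorted

argb = np.random.rand(1, 128, 128, 4)            # stand-in for the 4-channel generator output
distorted = np.random.rand(1, 128, 128, 3) * 2 - 1
out = composite(argb, distorted)
print(out.shape)                                 # (1, 128, 128, 3)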
@@ -1,263 +0,0 @@
import time

import cv2
import numpy as np

from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K

from lib.training_data import TrainingDataGenerator, stack_images


class GANTrainingDataGenerator(TrainingDataGenerator):
    def __init__(self, random_transform_args, coverage, scale, zoom):
        super().__init__(random_transform_args, coverage, scale, zoom)

    def color_adjust(self, img):
        return img / 255.0 * 2 - 1


class Trainer():
    random_transform_args = {
        'rotation_range': 20,
        'zoom_range': 0.1,
        'shift_range': 0.05,
        'random_flip': 0.5,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
        K.set_learning_phase(1)

        assert batch_size % 2 == 0, "batch_size must be an even number"
        self.batch_size = batch_size
        self.model = model

        self.use_lsgan = True
        self.use_mixup = True
        self.mixup_alpha = 0.2
        self.use_perceptual_loss = perceptual_loss
        self.use_mask_refinement = False  # OPTIONAL: after ~15k iterations

        self.lrD = 1e-4  # Discriminator learning rate
        self.lrG = 1e-4  # Generator learning rate

        generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 2)
        self.train_batchA = generator.minibatchAB(fn_A, batch_size)
        self.train_batchB = generator.minibatchAB(fn_B, batch_size)

        self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0

        self.setup()

    def setup(self):
        distorted_A, fake_A, fake_sz64_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
        distorted_B, fake_B, fake_sz64_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
        real_A = Input(shape=self.model.img_shape)
        real_B = Input(shape=self.model.img_shape)

        if self.use_lsgan:
            self.loss_fn = lambda output, target: K.mean(K.abs(K.square(output - target)))
        else:
            self.loss_fn = lambda output, target: -K.mean(K.log(output + 1e-12) * target + K.log(1 - output + 1e-12) * (1 - target))

        # ========== Define Perceptual Loss Model ==========
        if self.use_perceptual_loss:
            from keras.models import Model
            from keras_vggface.vggface import VGGFace
            vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
            vggface.trainable = False
            out_size55 = vggface.layers[36].output
            out_size28 = vggface.layers[78].output
            out_size7 = vggface.layers[-2].output
            vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
            vggface_feat.trainable = False
        else:
            vggface_feat = None

        loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, fake_sz64_A, distorted_A, vggface_feat)
        loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, fake_sz64_B, distorted_B, vggface_feat)

        if self.use_mask_refinement:
            loss_GA += 1e-3 * K.mean(K.square(mask_A))
            loss_GB += 1e-3 * K.mean(K.square(mask_B))
        else:
            loss_GA += 3e-3 * K.mean(K.abs(mask_A))
            loss_GB += 3e-3 * K.mean(K.abs(mask_B))

        w_fo = 0.01
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))

        weightsDA = self.model.netDA.trainable_weights
        weightsGA = self.model.netGA.trainable_weights
        weightsDB = self.model.netDB.trainable_weights
        weightsGB = self.model.netGB.trainable_weights

        # Adam(..).get_updates(...)
        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA, [], loss_DA)
        self.netDA_train = K.function([distorted_A, real_A], [loss_DA], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA, [], loss_GA)
        self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)

        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB, [], loss_DB)
        self.netDB_train = K.function([distorted_B, real_B], [loss_DB], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB, [], loss_GB)
        self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)

    def first_order(self, x, axis=1):
        img_nrows = x.shape[1]
        img_ncols = x.shape[2]
        if axis == 1:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        elif axis == 2:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
        else:
            return None

    def train_one_step(self, iter, viewer):
        # ---------------------
        #  Train Discriminators
        # ---------------------

        # Select a random half batch of images
        epoch, warped_A, target_A = next(self.train_batchA)
        epoch, warped_B, target_B = next(self.train_batchB)

        # Train discriminators for one batch
        errDA = self.netDA_train([warped_A, target_A])
        errDB = self.netDB_train([warped_B, target_B])

        # Train generators for one batch
        errGA = self.netGA_train([warped_A, target_A])
        errGB = self.netGB_train([warped_B, target_B])

        # For calculating average losses
        self.errDA_sum += errDA[0]
        self.errDB_sum += errDB[0]
        self.errGA_sum += errGA[0]
        self.errGB_sum += errGB[0]
        self.avg_counter += 1

        print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter,
                 self.errDA_sum / self.avg_counter, self.errDB_sum / self.avg_counter,
                 self.errGA_sum / self.avg_counter, self.errGB_sum / self.avg_counter),
              end='\r')

        if viewer is not None:
            self.show_sample(viewer)

    def cycle_variables(self, netG):
        distorted_input = netG.inputs[0]
        fake_output = netG.outputs[0]
        fake_output64 = netG.outputs[1]
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_output)
        rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_output)

        masked_fake_output = alpha * rgb + (1 - alpha) * distorted_input

        fn_generate = K.function([distorted_input], [masked_fake_output])
        fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
        fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
        fn_bgr = K.function([distorted_input], [rgb])
        return distorted_input, fake_output, fake_output64, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr

    def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None):
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_argb)
        fake = alpha * fake_rgb + (1 - alpha) * distorted

        if self.use_mixup:
            dist = Beta(self.mixup_alpha, self.mixup_alpha)
            lam = dist.sample()
            # ==========
            mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
            # ==========
            output_mixup = netD(mixup)
            loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
            #output_fake = netD(concatenate([fake, distorted]))  # dummy
            loss_G = 1 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
        else:
            output_real = netD(concatenate([real, distorted]))  # positive sample
            output_fake = netD(concatenate([fake, distorted]))  # negative sample
            loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
            loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
            loss_D = loss_D_real + loss_D_fake
            loss_G = 1 * self.loss_fn(output_fake, K.ones_like(output_fake))
        # ==========
        if self.use_mask_refinement:
            loss_G += K.mean(K.abs(fake - real))
        else:
            loss_G += K.mean(K.abs(fake_rgb - real))
        loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))
        # ==========

        # Perceptual Loss
        if vggface_feat is not None:
            def preprocess_vggface(x):
                x = (x + 1) / 2 * 255  # channel order: BGR
                x -= [93.5940, 104.7624, 129.]
                return x
            pl_params = (0.02, 0.3, 0.5)
            real_sz224 = tf.image.resize_images(real, [224, 224])
            real_sz224 = Lambda(preprocess_vggface)(real_sz224)
            # ==========
            if self.use_mask_refinement:
                fake_sz224 = tf.image.resize_images(fake, [224, 224])
            else:
                fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
            fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
            # ==========
            real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
            fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
            loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
            loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
            loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))

        return loss_D, loss_G

    def show_sample(self, display_fn):
        _, wA, tA = next(self.train_batchA)
        _, wB, tB = next(self.train_batchB)
        display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked")
        display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw")
        display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
        # Reset the averages
        self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
        self.avg_counter = 0

    def showG(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
            np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        ], axis=1)
        figure_B = np.stack([
            test_B,
            np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
            np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure

    def showG_mask(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
        ], axis=1)
        figure_B = np.stack([
            test_B,
            (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure
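The edge loss above penalises differences between the forward differences (image gradients) of the generated and target images, cropped to a common size exactly as in first_order. A minimal numpy sketch of that term:

import numpy as np

def first_order(x, axis):
    # Forward differences over a common (H-1, W-1) crop, mirroring Trainer.first_order.
    h, w = x.shape[1], x.shape[2]
    if axis == 1:
        return np.abs(x[:, :h - 1, :w - 1, :] - x[:, 1:, :w - 1, :])
    return np.abs(x[:, :h - 1, :w - 1, :] - x[:, :h - 1, 1:, :])

fake = np.random.rand(2, 64, 64, 3)
real = np.random.rand(2, 64, 64, 3)
# Edge loss: match the gradients of the generated image to the target's
edge_loss = (np.mean(np.abs(first_order(fake, 1) - first_order(real, 1)))
             + np.mean(np.abs(first_order(fake, 2) - first_order(real, 2))))
print(edge_loss)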
@@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-

__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'

from .Model import GANModel as Model
from .Trainer import Trainer
@@ -1,145 +0,0 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
@@ -1,53 +0,0 @@
# Improved-AutoEncoder base classes

import logging

from lib.utils import backup_file

hdf = {'encoderH5': 'IAE_encoder.h5',
       'decoderH5': 'IAE_decoder.h5',
       'inter_AH5': 'IAE_inter_A.h5',
       'inter_BH5': 'IAE_inter_B.h5',
       'inter_bothH5': 'IAE_inter_both.h5'}

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class AutoEncoder:
    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        self.encoder = self.Encoder()
        self.decoder = self.Decoder()
        self.inter_A = self.Intermidiate()
        self.inter_B = self.Intermidiate()
        self.inter_both = self.Intermidiate()

        self.initModel()

    def load(self, swapped):
        (face_A, face_B) = (hdf['inter_AH5'], hdf['inter_BH5']) if not swapped else (hdf['inter_BH5'], hdf['inter_AH5'])

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder.load_weights(str(self.model_dir / hdf['decoderH5']))
            self.inter_both.load_weights(str(self.model_dir / hdf['inter_bothH5']))
            self.inter_A.load_weights(str(self.model_dir / face_A))
            self.inter_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception:
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            return False

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder.save_weights(str(self.model_dir / hdf['decoderH5']))
        self.inter_both.save_weights(str(self.model_dir / hdf['inter_bothH5']))
        self.inter_A.save_weights(str(self.model_dir / hdf['inter_AH5']))
        self.inter_B.save_weights(str(self.model_dir / hdf['inter_BH5']))
        logger.info('saved model weights')
@@ -1,77 +0,0 @@
# deleted file: plugins/Model_IAE/Model.py (path inferred from contents)
# Improved autoencoder for faceswap.

from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam

from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler

from keras.utils import multi_gpu_model

IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024


class Model(AutoEncoder):
    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
        x = Input(shape=IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder(Concatenate()([self.inter_A(self.encoder(x)), self.inter_both(self.encoder(x))])))
        self.autoencoder_B = KerasModel(x, self.decoder(Concatenate()([self.inter_B(self.encoder(x)), self.inter_both(self.encoder(x))])))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model(self.autoencoder_A, self.gpus)
            self.autoencoder_B = multi_gpu_model(self.autoencoder_B, self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return lambda img: autoencoder.predict(img)

    def conv(self, filters):
        def block(x):
            x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def upscale(self, filters):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder(self):
        input_ = Input(shape=IMAGE_SHAPE)
        x = input_
        x = self.conv(128)(x)
        x = self.conv(256)(x)
        x = self.conv(512)(x)
        x = self.conv(1024)(x)
        x = Flatten()(x)
        return KerasModel(input_, x)

    def Intermidiate(self):
        input_ = Input(shape=(None, 4 * 4 * 1024))
        x = input_
        x = Dense(ENCODER_DIM)(x)
        x = Dense(4 * 4 * int(ENCODER_DIM / 2))(x)
        x = Reshape((4, 4, int(ENCODER_DIM / 2)))(x)
        return KerasModel(input_, x)

    def Decoder(self):
        input_ = Input(shape=(4, 4, ENCODER_DIM))
        x = input_
        x = self.upscale(512)(x)
        x = self.upscale(256)(x)
        x = self.upscale(128)(x)
        x = self.upscale(64)(x)
        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
        return KerasModel(input_, x)
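For orientation, a minimal shape walk-through of the IAE wiring above (illustrative only; the numbers follow IMAGE_SHAPE = (64, 64, 3) and ENCODER_DIM = 1024 in this file). Each identity gets its own intermediate plus a shared one, and the two concatenate to exactly the depth the shared Decoder expects:

# Illustrative check, not part of the commit: each Intermidiate reshapes to
# 4x4x512, and inter_A|inter_B concatenated with inter_both gives the
# 4x4x1024 tensor that matches Decoder's Input(shape=(4, 4, ENCODER_DIM)).
ENCODER_DIM = 1024
inter_out = (4, 4, ENCODER_DIM // 2)                # one Intermidiate output
concat_out = (4, 4, inter_out[2] + inter_out[2])    # identity half + shared half
decoder_in = (4, 4, ENCODER_DIM)                    # the Decoder's declared input
assert concat_out == decoder_in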
@@ -1,51 +0,0 @@
# deleted file: plugins/Model_IAE/Trainer.py (path inferred from contents)

import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images


class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, 7) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
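A quick note on the hard-coded reshape((4, 7) + ...) in show_sample() above: each side contributes 14 faces in 3 columns (target, same-identity reconstruction, swap), so the concatenated figure always holds 28 rows that tile exactly into a 4x7 preview grid. A standalone check:

# Illustrative only; the shapes mirror the 64px models in this commit.
import numpy
figure = numpy.zeros((14 + 14, 3, 64, 64, 3))     # (rows, columns, H, W, C)
grid = figure.reshape((4, 7) + figure.shape[1:])
assert grid.shape == (4, 7, 3, 64, 64, 3)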
@@ -1,8 +0,0 @@
# deleted file: plugins/Model_IAE/__init__.py (path inferred from contents)
# -*- coding: utf-8 -*-

__author__ = """acsaga"""
__version__ = '0.1.0'

from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder
@@ -1,61 +0,0 @@
# deleted file: plugins/Model_LowMem/AutoEncoder.py (path inferred from the 'lowmem_*' weight names)
# AutoEncoder base classes
import logging

from lib.utils import backup_file

hdf = {'encoderH5': 'lowmem_encoder.h5',
       'decoder_AH5': 'lowmem_decoder_A.h5',
       'decoder_BH5': 'lowmem_decoder_B.h5'}

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Part of filename migration; should be removed some reasonable time after first added
import os.path
old_encoderH5 = 'encoder.h5'
old_decoder_AH5 = 'decoder_A.h5'
old_decoder_BH5 = 'decoder_B.h5'
# End filename migration


class AutoEncoder:
    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        (face_A, face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])

        try:
            # Part of filename migration; should be removed some reasonable time after first added
            if os.path.isfile(str(self.model_dir / old_encoderH5)):
                logger.info('Migrating to new filenames:')
                if os.path.isfile(str(self.model_dir / hdf['encoderH5'])) is not True:
                    os.rename(str(self.model_dir / old_decoder_AH5), str(self.model_dir / hdf['decoder_AH5']))
                    os.rename(str(self.model_dir / old_decoder_BH5), str(self.model_dir / hdf['decoder_BH5']))
                    os.rename(str(self.model_dir / old_encoderH5), str(self.model_dir / hdf['encoderH5']))
                    logger.info('Complete')
                else:
                    logger.warning('Failed due to existing files in folder. Loading already migrated files')
            # End filename migration
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception as e:
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            return False

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
        self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))
        logger.info('saved model weights')
@@ -1,70 +0,0 @@
# deleted file: plugins/Model_LowMem/Model.py (path inferred from contents)
# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs

from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam

from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler

from keras.utils import multi_gpu_model

IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 512


class Model(AutoEncoder):
    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
        x = Input(shape=IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
        self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model(self.autoencoder_A, self.gpus)
            self.autoencoder_B = multi_gpu_model(self.autoencoder_B, self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return lambda img: autoencoder.predict(img)

    def conv(self, filters):
        def block(x):
            x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def upscale(self, filters):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder(self):
        input_ = Input(shape=IMAGE_SHAPE)
        x = input_
        x = self.conv(128)(x)
        x = self.conv(256)(x)
        x = self.conv(512)(x)
        x = Dense(ENCODER_DIM)(Flatten()(x))
        x = Dense(4 * 4 * 1024)(x)
        x = Reshape((4, 4, 1024))(x)
        x = self.upscale(512)(x)
        return KerasModel(input_, x)

    def Decoder(self):
        input_ = Input(shape=(8, 8, 512))
        x = input_
        x = self.upscale(256)(x)
        x = self.upscale(128)(x)
        x = self.upscale(64)(x)
        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
        return KerasModel(input_, x)
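The "low memory" saving is visible from the layer sizes above: the encoder stops after three stride-2 convolutions (so Flatten() sees 8x8x512) and ENCODER_DIM drops to 512. Rough dense-weight arithmetic, for illustration only:

# Illustrative only; counts ignore biases and the conv layers.
flat = 8 * 8 * 512                                         # Flatten() after three stride-2 convs
lowmem = flat * 512 + 512 * (4 * 4 * 1024)                 # Dense(ENCODER_DIM) + Dense(4*4*1024)
original = (4 * 4 * 1024) * 1024 + 1024 * (4 * 4 * 1024)   # the same chain in the Original model
print(f"lowmem: {lowmem / 1e6:.1f}M vs original: {original / 1e6:.1f}M dense weights")
# -> lowmem: 25.2M vs original: 33.6M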
@@ -1,56 +0,0 @@
# deleted file: plugins/Model_LowMem/Trainer.py (path inferred from contents)

import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images


class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        if test_A.shape[0] % 2 == 1:
            figure_A = numpy.concatenate([figure_A, numpy.expand_dims(figure_A[0], 0)])
            figure_B = numpy.concatenate([figure_B, numpy.expand_dims(figure_B[0], 0)])

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        w = 4
        h = int(figure.shape[0] / w)
        figure = figure.reshape((w, h) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
@@ -1,8 +0,0 @@
# deleted file: plugins/Model_LowMem/__init__.py (path inferred from contents)
# -*- coding: utf-8 -*-

__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'

from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder
@@ -1,77 +0,0 @@
# deleted file: plugins/Model_Original/AutoEncoder.py (path inferred from contents)
# AutoEncoder base classes
import logging

from lib.utils import backup_file
from lib import Serializer
from json import JSONDecodeError

hdf = {'encoderH5': 'encoder.h5',
       'decoder_AH5': 'decoder_A.h5',
       'decoder_BH5': 'decoder_B.h5',
       'state': 'state'}
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class AutoEncoder:
    def __init__(self, model_dir, gpus):
        self.model_dir = model_dir
        self.gpus = gpus

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        try:
            with open(str(self.model_dir / state_fn), 'rb') as fp:
                state = serializer.unmarshal(fp.read().decode('utf-8'))
                self._epoch_no = state['epoch_no']
        except IOError as e:
            logger.warning('Error loading training info: %s', str(e.strerror))
            self._epoch_no = 0
        except JSONDecodeError as e:
            logger.warning('Error loading training info: %s', str(e.msg))
            self._epoch_no = 0

        (face_A, face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception as e:
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            return False

    def save_weights(self):
        model_dir = str(self.model_dir)
        for model in hdf.values():
            backup_file(model_dir, model)
        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
        self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))

        logger.info('saved model weights')

        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        state_dir = str(self.model_dir / state_fn)
        try:
            with open(state_dir, 'wb') as fp:
                state_json = serializer.marshal({
                    'epoch_no': self.epoch_no
                })
                fp.write(state_json.encode('utf-8'))
        except IOError as e:
            logger.error(e.strerror)

    @property
    def epoch_no(self):
        "Get current training epoch number"
        return self._epoch_no
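The state file introduced above is a small JSON document holding the epoch counter. A minimal sketch of the round-trip using only the standard library (lib.Serializer's 'json' backend is assumed to marshal and unmarshal equivalently):

# Illustrative only; the real filename comes from hdf['state'] plus serializer.ext.
import json
import os
import tempfile

state_path = os.path.join(tempfile.mkdtemp(), "state.json")  # stand-in for model_dir / state_fn
with open(state_path, "w", encoding="utf-8") as fp:
    json.dump({"epoch_no": 42}, fp)                          # what save_weights() persists
with open(state_path, encoding="utf-8") as fp:
    assert json.load(fp)["epoch_no"] == 42                   # what load() restores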
@@ -1,71 +0,0 @@
# deleted file: plugins/Model_Original/Model.py (path inferred from contents)
# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs

from keras.models import Model as KerasModel
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam

from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler

from keras.utils import multi_gpu_model

IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024


class Model(AutoEncoder):
    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
        x = Input(shape=IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
        self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model(self.autoencoder_A, self.gpus)
            self.autoencoder_B = multi_gpu_model(self.autoencoder_B, self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return lambda img: autoencoder.predict(img)

    def conv(self, filters):
        def block(x):
            x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def upscale(self, filters):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder(self):
        input_ = Input(shape=IMAGE_SHAPE)
        x = input_
        x = self.conv(128)(x)
        x = self.conv(256)(x)
        x = self.conv(512)(x)
        x = self.conv(1024)(x)
        x = Dense(ENCODER_DIM)(Flatten()(x))
        x = Dense(4 * 4 * 1024)(x)
        x = Reshape((4, 4, 1024))(x)
        x = self.upscale(512)(x)
        return KerasModel(input_, x)

    def Decoder(self):
        input_ = Input(shape=(8, 8, 512))
        x = input_
        x = self.upscale(256)(x)
        x = self.upscale(128)(x)
        x = self.upscale(64)(x)
        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
        return KerasModel(input_, x)
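The upscale() blocks above rely on lib.PixelShuffler for the 2x upsample: a convolution widens the channels by 4, then the shuffler trades that depth for spatial resolution. A NumPy sketch of the depth-to-space rearrangement it is assumed to perform (the library's exact channel ordering may differ):

import numpy as np

def depth_to_space(x, r=2):
    """Rearrange (H, W, C*r*r) -> (H*r, W*r, C) by interleaving r x r sub-pixels."""
    h, w, c = x.shape
    out_c = c // (r * r)
    x = x.reshape(h, w, r, r, out_c)
    x = x.transpose(0, 2, 1, 3, 4)      # (h, r, w, r, c): put sub-pixels beside their cell
    return x.reshape(h * r, w * r, out_c)

x = np.random.rand(8, 8, 256)           # e.g. the output of Conv2D(64 * 4, ...) above
assert depth_to_space(x).shape == (16, 16, 64)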
@@ -1,59 +0,0 @@
# deleted file: plugins/Model_Original/Trainer.py (path inferred from contents)

import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images


class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)

        self.model._epoch_no += 1

        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), self.model.epoch_no, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        if test_A.shape[0] % 2 == 1:
            figure_A = numpy.concatenate([figure_A, numpy.expand_dims(figure_A[0], 0)])
            figure_B = numpy.concatenate([figure_B, numpy.expand_dims(figure_B[0], 0)])

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        w = 4
        h = int(figure.shape[0] / w)
        figure = figure.reshape((w, h) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
@@ -1,8 +0,0 @@
# deleted file: plugins/Model_Original/__init__.py (path inferred from contents)
# -*- coding: utf-8 -*-

__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'

from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder
@@ -1,312 +0,0 @@
# deleted file: plugins/Model_OriginalHighRes/Model.py (path inferred from contents)
#!/usr/bin/python3

# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
# Based on https://github.com/iperov/OpenDeepFaceSwap for Decoder multiple res block chain
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb


import enum
import logging
import os
import sys
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from keras.initializers import RandomNormal
from keras.layers import Input, Dense, Flatten, Reshape
from keras.layers import SeparableConv2D, add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation
from keras.models import Model as KerasModel
from keras.optimizers import Adam
from keras.utils import multi_gpu_model

from lib.PixelShuffler import PixelShuffler
import lib.Serializer
from lib.utils import backup_file

from . import __version__
from .instance_normalization import InstanceNormalization


if isinstance(__version__, (list, tuple)):
    version_str = ".".join([str(n) for n in __version__[1:]])
else:
    version_str = __version__

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
mswindows = sys.platform == "win32"


class EncoderType(enum.Enum):
    ORIGINAL = "original"
    SHAOANLU = "shaoanlu"


_kern_init = RandomNormal(0, 0.02)


def inst_norm():
    return InstanceNormalization()


ENCODER = EncoderType.ORIGINAL

hdf = {'encoderH5': 'encoder_{version_str}{ENCODER.value}.h5'.format(**vars()),
       'decoder_AH5': 'decoder_A_{version_str}{ENCODER.value}.h5'.format(**vars()),
       'decoder_BH5': 'decoder_B_{version_str}{ENCODER.value}.h5'.format(**vars())}


class Model():

    ENCODER_DIM = 1024  # dense layer size
    IMAGE_SHAPE = 128, 128  # image shape

    assert [n for n in IMAGE_SHAPE if n >= 16]

    IMAGE_WIDTH = max(IMAGE_SHAPE)
    IMAGE_WIDTH = (IMAGE_WIDTH // 16 + (1 if (IMAGE_WIDTH % 16) >= 8 else 0)) * 16
    IMAGE_SHAPE = IMAGE_WIDTH, IMAGE_WIDTH, len('BRG')  # good to let ppl know what these are...

    def __init__(self, model_dir, gpus, encoder_type=ENCODER):

        if mswindows:
            from ctypes import cdll
            mydll = cdll.LoadLibrary("user32.dll")
            mydll.SetProcessDPIAware(True)

        self._encoder_type = encoder_type

        self.model_dir = model_dir

        # can't change gpus once the model is initialized, so no point in making it r/w
        self._gpus = gpus

        Encoder = getattr(self, "Encoder_{}".format(self._encoder_type.value))
        Decoder = getattr(self, "Decoder_{}".format(self._encoder_type.value))

        self.encoder = Encoder()
        self.decoder_A = Decoder()
        self.decoder_B = Decoder()

        self.initModel()

    def initModel(self):
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

        x = Input(shape=self.IMAGE_SHAPE)

        self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
        self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))

        if self.gpus > 1:
            self.autoencoder_A = multi_gpu_model(self.autoencoder_A, self.gpus)
            self.autoencoder_B = multi_gpu_model(self.autoencoder_B, self.gpus)

        self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error')

    def load(self, swapped):
        model_dir = str(self.model_dir)

        from json import JSONDecodeError
        face_A, face_B = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5'])

        state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals()))
        ser = lib.Serializer.get_serializer('json')

        try:
            with open(state_dir, 'rb') as fp:
                state = ser.unmarshal(fp.read().decode('utf-8'))
                self._epoch_no = state['epoch_no']
        except IOError as e:
            logger.warning('Error loading training info: %s', str(e.strerror))
            self._epoch_no = 0
        except JSONDecodeError as e:
            logger.warning('Error loading training info: %s', str(e.msg))
            self._epoch_no = 0

        try:
            self.encoder.load_weights(os.path.join(model_dir, hdf['encoderH5']))
            self.decoder_A.load_weights(os.path.join(model_dir, face_A))
            self.decoder_B.load_weights(os.path.join(model_dir, face_B))
            logger.info('loaded model weights')
            return True
        except IOError as e:
            logger.warning('Error loading training info: %s', str(e.strerror))
        except Exception as e:
            logger.warning('Error loading training info: %s', str(e))

        return False

    def converter(self, swap):
        autoencoder = self.autoencoder_B if not swap else self.autoencoder_A
        return autoencoder.predict

    def conv(self, filters, kernel_size=5, strides=2, **kwargs):
        def block(x):
            x = Conv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
            x = LeakyReLU(0.1)(x)
            return x
        return block

    def conv_sep(self, filters, kernel_size=5, strides=2, use_instance_norm=True, **kwargs):
        def block(x):
            x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
            x = Activation("relu")(x)
            return x
        return block

    def conv_inst_norm(self, filters, kernel_size=3, strides=2, use_instance_norm=True, **kwargs):
        def block(x):
            x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = Activation("relu")(x)
            return x
        return block

    def upscale(self, filters, **kwargs):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, padding='same',
                       kernel_initializer=_kern_init)(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def upscale_inst_norm(self, filters, use_instance_norm=True, **kwargs):
        def block(x):
            x = Conv2D(filters * 4, kernel_size=3, use_bias=False,
                       kernel_initializer=_kern_init, padding='same', **kwargs)(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = LeakyReLU(0.1)(x)
            x = PixelShuffler()(x)
            return x
        return block

    def Encoder_original(self, **kwargs):
        impt = Input(shape=self.IMAGE_SHAPE)

        in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0] - 128) // 4

        x = self.conv(in_conv_filters)(impt)
        x = self.conv_sep(256)(x)
        x = self.conv(512)(x)
        x = self.conv_sep(1024)(x)

        dense_shape = self.IMAGE_SHAPE[0] // 16
        x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x))
        x = Dense(dense_shape * dense_shape * 512, kernel_initializer=_kern_init)(x)
        x = Reshape((dense_shape, dense_shape, 512))(x)
        x = self.upscale(512)(x)

        return KerasModel(impt, x, **kwargs)

    def Encoder_shaoanlu(self, **kwargs):
        impt = Input(shape=self.IMAGE_SHAPE)

        in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0] - 128) // 4

        x = Conv2D(in_conv_filters, kernel_size=5, use_bias=False, padding="same")(impt)
        x = self.conv_inst_norm(in_conv_filters + 32, use_instance_norm=False)(x)
        x = self.conv_inst_norm(256)(x)
        x = self.conv_inst_norm(512)(x)
        x = self.conv_inst_norm(1024)(x)

        dense_shape = self.IMAGE_SHAPE[0] // 16
        x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x))
        x = Dense(dense_shape * dense_shape * 768, kernel_initializer=_kern_init)(x)
        x = Reshape((dense_shape, dense_shape, 768))(x)
        x = self.upscale(512)(x)

        return KerasModel(impt, x, **kwargs)

    def Decoder_original(self):
        decoder_shape = self.IMAGE_SHAPE[0] // 8
        inpt = Input(shape=(decoder_shape, decoder_shape, 512))

        x = self.upscale(384)(inpt)
        x = self.upscale(256 - 32)(x)
        x = self.upscale(self.IMAGE_SHAPE[0])(x)

        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)

        return KerasModel(inpt, x)

    def Decoder_shaoanlu(self):
        decoder_shape = self.IMAGE_SHAPE[0] // 8
        inpt = Input(shape=(decoder_shape, decoder_shape, 512))

        x = self.upscale_inst_norm(512)(inpt)
        x = self.upscale_inst_norm(256)(x)
        x = self.upscale_inst_norm(self.IMAGE_SHAPE[0])(x)

        x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)

        return KerasModel(inpt, x)

    def save_weights(self):
        model_dir = str(self.model_dir)

        try:
            for model in hdf.values():
                backup_file(model_dir, model)
        except NameError:
            logger.error('backup functionality not available\n')

        state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals()))
        ser = lib.Serializer.get_serializer('json')
        try:
            with open(state_dir, 'wb') as fp:
                state_json = ser.marshal({
                    'epoch_no': self._epoch_no
                })
                fp.write(state_json.encode('utf-8'))
        except IOError as e:
            logger.error(e.strerror)

        logger.info('saving model weights')

        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(getattr(self, mdl_name.rstrip('H5')).save_weights, str(self.model_dir / mdl_H5_fn)) for mdl_name, mdl_H5_fn in hdf.items()]
            for future in as_completed(futures):
                future.result()
                print('.', end='', flush=True)

        logger.info('done')

    @property
    def gpus(self):
        return self._gpus

    @property
    def model_name(self):
        try:
            return self._model_name
        except AttributeError:
            import inspect
            self._model_name = os.path.dirname(inspect.getmodule(self).__file__).rsplit("_", 1)[1]
            return self._model_name

    def __str__(self):
        return "<{}: ver={}, dense_dim={}, img_size={}>".format(self.model_name,
                                                                version_str,
                                                                self.ENCODER_DIM,
                                                                "x".join([str(n) for n in self.IMAGE_SHAPE[:2]]))
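One detail worth calling out in the Model class above: the IMAGE_WIDTH arithmetic rounds the requested size to the nearest multiple of 16 before IMAGE_SHAPE is rebuilt. A quick standalone check of the formula:

def round16(n):
    # the same expression as the IMAGE_WIDTH line above
    return (n // 16 + (1 if (n % 16) >= 8 else 0)) * 16

assert [round16(n) for n in (120, 128, 135, 136)] == [128, 128, 128, 144]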
@@ -1,84 +0,0 @@
# deleted file: plugins/Model_OriginalHighRes/Trainer.py (path inferred from contents)
import time
import numpy

from lib.training_data import TrainingDataGenerator, stack_images


TRANSFORM_PRC = 115.


class Trainer():

    _random_transform_args = {
        'rotation_range': 10 * (TRANSFORM_PRC * .01),
        'zoom_range': 0.05 * (TRANSFORM_PRC * .01),
        'shift_range': 0.05 * (TRANSFORM_PRC * .01),
        'random_flip': 0.4 * (TRANSFORM_PRC * .01),
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model
        from timeit import default_timer as clock
        self._clock = clock

        generator = TrainingDataGenerator(self.random_transform_args, 160, 5, zoom=self.model.IMAGE_SHAPE[0] // 64)

        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

        self.generator = generator

    def train_one_step(self, iter_no, viewer):
        when = self._clock()
        _, warped_A, target_A = next(self.images_A)
        _, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)

        self.model._epoch_no += 1

        if isinstance(loss_A, (list, tuple)):
            print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format(
                time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock() - when, loss_A[1], loss_B[1]),
                end='\r')
        else:
            print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format(
                time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock() - when, loss_A, loss_B),
                end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:8], target_B[0:8]), "training using {}, bs={}".format(self.model, self.batch_size))

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)

        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        if (test_A.shape[0] % 2) != 0:
            figure_A = numpy.concatenate([figure_A, numpy.expand_dims(figure_A[0], 0)])
            figure_B = numpy.concatenate([figure_B, numpy.expand_dims(figure_B[0], 0)])

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        w = 4
        h = int(figure.shape[0] / w)
        figure = figure.reshape((w, h) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')

    @property
    def random_transform_args(self):
        return self._random_transform_args
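The odd-count guard in show_sample() above exists so the preview always tiles: with w fixed at 4, each side's row count must be even before the A and B halves are concatenated. In miniature:

# Illustrative only; shapes follow the 128px previews in this trainer.
import numpy
figure_a = numpy.zeros((7, 3, 128, 128, 3))                      # an odd preview batch
if figure_a.shape[0] % 2 != 0:
    figure_a = numpy.concatenate([figure_a, numpy.expand_dims(figure_a[0], 0)])
figure = numpy.concatenate([figure_a, figure_a], axis=0)         # A-side plus B-side
assert figure.shape[0] % 4 == 0                                  # divides into w = 4 columns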
@@ -1,8 +0,0 @@
# deleted file: plugins/Model_OriginalHighRes/__init__.py (path inferred from contents)
# -*- coding: utf-8 -*-

__author__ = """Based on https://reddit.com/u/deepfakes/"""

from ._version import __version__
from .Model import Model
from .Trainer import Trainer
@@ -1 +0,0 @@
# deleted file: plugins/Model_OriginalHighRes/_version.py (path inferred from the import above)
__version__ = 0, 2, 7
@@ -1,145 +0,0 @@
# deleted file: plugins/Model_OriginalHighRes/instance_normalization.py (path inferred from the import above)
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
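A minimal usage sketch of the layer above (it assumes the same Keras 2.x API the rest of this commit uses): normalize per-channel statistics within each sample after a convolution.

from keras.layers import Conv2D, Input
from keras.models import Model

inp = Input(shape=(64, 64, 3))
x = Conv2D(16, 3, padding='same')(inp)
x = InstanceNormalization(axis=-1)(x)   # per-sample, per-channel mean/std
model = Model(inp, x)
assert model.output_shape == (None, 64, 64, 16)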
plugins/plugin_loader.py (modified; filename inferred from the class)
@@ -23,21 +23,22 @@ class PluginLoader():
     @staticmethod
     def get_converter(name):
         """ Return requested converter plugin """
-        return PluginLoader._import("Convert", "Convert_{0}".format(name))
+        return PluginLoader._import("convert", name)

     @staticmethod
     def get_model(name):
         """ Return requested model plugin """
-        return PluginLoader._import("Model", "Model_{0}".format(name))
+        return PluginLoader._import("train.model", name)

     @staticmethod
     def get_trainer(name):
         """ Return requested trainer plugin """
-        return PluginLoader._import("Trainer", "Model_{0}".format(name))
+        return PluginLoader._import("train.trainer", name)

     @staticmethod
     def _import(attr, name):
         """ Import the plugin's module """
+        name = name.replace("-", "_")
         ttl = attr.split(".")[-1].title()
         logger.info("Loading %s from %s plugin...", ttl, name.title())
         attr = "model" if attr == "Trainer" else attr.lower()

@@ -48,13 +49,23 @@ class PluginLoader():
     @staticmethod
     def get_available_models():
         """ Return a list of available models """
-        models = ()
-        modelpath = os.path.join(os.path.dirname(__file__), "model")
-        for modeldir in next(os.walk(modelpath))[1]:
-            if modeldir[0:6].lower() == 'model_':
-                models += (modeldir[6:],)
+        modelpath = os.path.join(os.path.dirname(__file__), "train", "model")
+        models = sorted(item.name.replace(".py", "").replace("_", "-")
+                        for item in os.scandir(modelpath)
+                        if not item.name.startswith("_")
+                        and item.name.endswith(".py"))
         return models

+    @staticmethod
+    def get_available_converters():
+        """ Return a list of available converters """
+        converter_path = os.path.join(os.path.dirname(__file__), "convert")
+        converters = sorted(item.name.replace(".py", "").replace("_", "-")
+                            for item in os.scandir(converter_path)
+                            if not item.name.startswith("_")
+                            and item.name.endswith(".py"))
+        return converters
+
     @staticmethod
     def get_available_extractors(extractor_type):
         """ Return a list of available models """

@@ -72,4 +83,4 @@ class PluginLoader():
     def get_default_model():
         """ Return the default model """
         models = PluginLoader.get_available_models()
-        return 'Original' if 'Original' in models else models[0]
+        return 'original' if 'original' in models else models[0]
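The net effect of this hunk for callers: plugin lookups move from 'Model_Original'-style module directories to lower-case, hyphenated plugin names resolved under plugins/train and plugins/convert. The name transforms themselves are simple enough to verify standalone:

# _import() maps the CLI-facing name to a module name, and get_available_models()
# maps module filenames back to CLI-facing names (both per the diff above).
name = "dfl-h128"
module_name = name.replace("-", "_")                          # loaded as dfl_h128
listed = "dfl_h128.py".replace(".py", "").replace("_", "-")   # listed as dfl-h128
assert (module_name, listed) == ("dfl_h128", "dfl-h128")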
plugins/train/_config.py (new file, +180 lines)
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
""" Default configurations for models """

import logging

from lib.config import FaceswapConfig

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

MASK_TYPES = ["none", "dfaker", "dfl_full"]
MASK_INFO = "The mask to be used for training. Select none to not use a mask"
COVERAGE_INFO = ("How much of the extracted image to train on. Generally the model is optimized\n"
                 "to the default value. Sensible values to use are:"
                 "\n\t62.5%% spans from eyebrow to eyebrow."
                 "\n\t75.0%% spans from temple to temple."
                 "\n\t87.5%% spans from ear to ear."
                 "\n\t100.0%% is a mugshot.")


class Config(FaceswapConfig):
    """ Config File for Models """

    def set_defaults(self):
        """ Set the default values for config """
        logger.debug("Setting defaults")
        # << GLOBAL OPTIONS >> #
        section = "global"
        self.add_section(title=section,
                         info="Options that apply to all models")
        self.add_item(
            section=section, title="icnr_init", datatype=bool, default=False,
            info="Use ICNR Kernel Initializer for upscaling.\nThis can help reduce the "
                 "'checkerboard effect' when upscaling the image.")
        self.add_item(
            section=section, title="subpixel_upscaling", datatype=bool, default=False,
            info="Use subpixel upscaling rather than pixel shuffler.\n"
                 "Might increase speed at cost of VRAM")
        self.add_item(
            section=section, title="reflect_padding", datatype=bool, default=False,
            info="Use reflect padding rather than zero padding.")
        self.add_item(
            section=section, title="dssim_mask_loss", datatype=bool, default=True,
            info="If using a mask, use DSSIM loss for mask training rather than Mean Absolute "
                 "Error\nMay increase overall quality.")
        self.add_item(
            section=section, title="penalized_mask_loss", datatype=bool, default=True,
            info="If using a mask, use penalized loss for mask training. Can stack with DSSIM.\n"
                 "May increase overall quality.")

        # << DFAKER OPTIONS >> #
        section = "model.dfaker"
        self.add_section(title=section,
                         info="Dfaker Model (Adapted from https://github.com/dfaker/df)")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="dfaker",
            choices=MASK_TYPES, info=MASK_INFO)
        self.add_item(
            section=section, title="coverage", datatype=float, default=100.0, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)

        # << DFL MODEL OPTIONS >> #
        section = "model.dfl_h128"
        self.add_section(title=section,
                         info="DFL H128 Model (Adapted from "
                              "https://github.com/iperov/DeepFaceLab)")
        self.add_item(
            section=section, title="lowmem", datatype=bool, default=False,
            info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
                 "with a changed lowmem mode are not compatible with each other.")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="dfl_full",
            choices=MASK_TYPES, info=MASK_INFO)
        self.add_item(
            section=section, title="coverage", datatype=float, default=62.5, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)

        # << IAE MODEL OPTIONS >> #
        section = "model.iae"
        self.add_section(title=section,
                         info="Intermediate Auto Encoder. Based on the Original model, uses "
                              "intermediate layers to try to better capture details")
        self.add_item(
            section=section, title="dssim_loss", datatype=bool, default=False,
            info="Use DSSIM for loss rather than Mean Absolute Error\n"
                 "May increase overall quality.")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="none",
            choices=MASK_TYPES, info=MASK_INFO)
        self.add_item(
            section=section, title="coverage", datatype=float, default=62.5, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)

        # << ORIGINAL MODEL OPTIONS >> #
        section = "model.original"
        self.add_section(title=section,
                         info="Original Faceswap Model")
        self.add_item(
            section=section, title="lowmem", datatype=bool, default=False,
            info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
                 "with a changed lowmem mode are not compatible with each other.")
        self.add_item(
            section=section, title="dssim_loss", datatype=bool, default=False,
            info="Use DSSIM for loss rather than Mean Absolute Error\n"
                 "May increase overall quality.")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="none",
            choices=MASK_TYPES, info=MASK_INFO)
        self.add_item(
            section=section, title="coverage", datatype=float, default=62.5, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)

        # << UNBALANCED MODEL OPTIONS >> #
        section = "model.unbalanced"
        self.add_section(title=section,
                         info="An unbalanced model with adjustable input size options.\n"
                              "This is an unbalanced model so b>a swaps may not work well")
        self.add_item(
            section=section, title="lowmem", datatype=bool, default=False,
            info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
                 "with a changed lowmem mode are not compatible with each other. NB: lowmem will "
                 "override custom nodes and complexity settings.")
        self.add_item(
            section=section, title="dssim_loss", datatype=bool, default=False,
            info="Use DSSIM for loss rather than Mean Absolute Error\n"
                 "May increase overall quality.")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="none",
            choices=MASK_TYPES, info=MASK_INFO)
        self.add_item(
            section=section, title="nodes", datatype=int, default=1024, rounding=64,
            min_max=(512, 4096),
            info="Number of nodes for decoder. Don't change this unless you "
                 "know what you are doing!")
        self.add_item(
            section=section, title="complexity_encoder", datatype=int, default=128,
            rounding=16, min_max=(64, 1024),
            info="Encoder Convolution Layer Complexity. Sensible ranges: "
                 "128 to 160")
        self.add_item(
            section=section, title="complexity_decoder_a", datatype=int, default=384,
            rounding=16, min_max=(64, 1024),
            info="Decoder A Complexity.")
        self.add_item(
            section=section, title="complexity_decoder_b", datatype=int, default=512,
            rounding=16, min_max=(64, 1024),
            info="Decoder B Complexity.")
        self.add_item(
            section=section, title="input_size", datatype=int, default=128,
            rounding=64, min_max=(64, 512),
            info="Resolution (in pixels) of the image to train on.\n"
                 "BE AWARE Larger resolution will dramatically increase "
                 "VRAM requirements.\n"
                 "Make sure your resolution is divisible by 64 (e.g. 64, 128, 256 etc.).\n"
                 "NB: Your faceset must be at least 1.6x larger than your required input size.\n"
                 " (e.g. 160 is the maximum input size for a 256x256 faceset)")
        self.add_item(
            section=section, title="coverage", datatype=float, default=62.5, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)

        # << VILLAIN MODEL OPTIONS >> #
        section = "model.villain"
        self.add_section(title=section,
                         info="A higher resolution version of the Original "
                              "model by VillainGuy.\n"
                              "Extremely VRAM heavy. Full model requires 9GB+ for batchsize 16")
        self.add_item(
            section=section, title="lowmem", datatype=bool, default=False,
            info="Lower memory mode. Set to 'True' if having issues with VRAM usage.\nNB: Models "
                 "with a changed lowmem mode are not compatible with each other.")
        self.add_item(
            section=section, title="dssim_loss", datatype=bool, default=False,
            info="Use DSSIM for loss rather than Mean Absolute Error\n"
                 "May increase overall quality.")
        self.add_item(
            section=section, title="mask_type", datatype=str, default="none",
            choices=["none", "dfaker", "dfl_full"],
            info="The mask to be used for training. Select none to not use a mask")
        self.add_item(
            section=section, title="coverage", datatype=float, default=62.5, rounding=1,
            min_max=(62.5, 100.0), info=COVERAGE_INFO)
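FaceswapConfig itself lives in lib/config.py (not shown in this diff), but the sections and items registered above are assumed to behave like an ordinary ini file on disk. An illustrative read-back with only the standard library:

import configparser
import io

# hypothetical extract of the generated config/train.ini
ini = io.StringIO("[model.original]\nlowmem = False\ncoverage = 62.5\n")
parser = configparser.ConfigParser()
parser.read_file(ini)
assert parser.getfloat("model.original", "coverage") == 62.5
assert parser.getboolean("model.original", "lowmem") is False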
plugins/train/model/__init__.py (new file, empty)
plugins/train/model/_base.py (new file, +586 lines; excerpt below)
@ -0,0 +1,586 @@
#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class.

    When inheriting, model_data should be a list of NNMeta objects.
    See the class for details.
"""

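# A minimal subclass sketch (illustrative; it mirrors the Original model further
# down this diff - the two NotImplementedError hooks below are the only methods
# a plugin must supply):
#
#     class Model(ModelBase):
#         def add_networks(self):
#             self.add_network("decoder", "a", self.decoder())
#             self.add_network("decoder", "b", self.decoder())
#             self.add_network("encoder", None, self.encoder())
#
#         def build_autoencoders(self):
#             inputs = [Input(shape=self.input_shape, name="face")]
#             for side in ("a", "b"):
#                 decoder = self.networks["decoder_{}".format(side)].network
#                 output = decoder(self.networks["encoder"].network(inputs[0]))
#                 self.add_predictor(side, KerasModel(inputs, output))
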
import logging
import os
import sys
import time

from json import JSONDecodeError

from keras import losses
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils import get_custom_objects, multi_gpu_model

from lib import Serializer
from lib.model.losses import DSSIMObjective, PenalizedLoss
from lib.model.nn_blocks import NNBlocks
from lib.multithreading import MultiThread
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
_CONFIG = None


class ModelBase():
    """ Base class that all models should inherit from """
    def __init__(self,
                 model_dir,
                 gpus,
                 no_logs=False,
                 warp_to_landmarks=False,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 predict=False):
        logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
                     "training_image_size: %s, alignments_paths: %s, preview_scale: %s, "
                     "input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus,
                     training_image_size, alignments_paths, preview_scale, input_shape,
                     encoder_dim)
        self.predict = predict
        self.model_dir = model_dir
        self.gpus = gpus
        self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"],
                               use_icnr_init=self.config["icnr_init"],
                               use_reflect_padding=self.config["reflect_padding"])
        self.input_shape = input_shape
        self.output_shape = None  # set after model is compiled
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.state = State(self.model_dir, self.name, no_logs, training_image_size)
        self.load_state_info()

        self.networks = dict()  # Networks for the model
        self.predictors = dict()  # Predictors for model
        self.history = dict()  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {"alignments": alignments_paths,
                              "preview_scaling": preview_scale / 100,
                              "warp_to_landmarks": warp_to_landmarks,
                              "no_flip": no_flip}

        self.build()
        self.set_training_data()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config(self):
        """ Return config dict for current plugin """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = ".".join(self.__module__.split(".")[-2:])
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name).config_dict
        return _CONFIG

    @property
    def name(self):
        """ Set the model name based on the subclass """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    def set_training_data(self):
        """ Override to set model specific training data.

        Call super() on this method for the defaults, otherwise be sure to
        set these keys yourself """
        logger.debug("Setting training data")
        self.training_opts["training_size"] = self.state.training_size
        self.training_opts["no_logs"] = self.state.current_session["no_logs"]
        self.training_opts["mask_type"] = self.config.get("mask_type", None)
        self.training_opts["coverage_ratio"] = self.config.get("coverage", 62.5) / 100
        self.training_opts["preview_images"] = 14
        logger.debug("Set training data: %s", self.training_opts)

    def build(self):
        """ Build the model. Override for custom build methods """
        self.add_networks()
        self.load_models(swapped=False)
        self.build_autoencoders()
        self.log_summary()
        self.compile_predictors()

    def build_autoencoders(self):
        """ Override for Model Specific autoencoder builds

        NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names
        are expected:
            face (the input for the image)
            mask (the input for the mask, if one is used)
        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError

    def load_state_info(self):
        """ Load the input shape from the state file if it exists """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning("Input shapes stored in State file, but no matches for 'face'. "
                           "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network):
        """ Add a NNMeta object """
        logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network)

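    # Naming sketch (illustrative): for the Original model, add_network("decoder", "a", net)
    # stores the network under the key "decoder_a" with the weights file
    # "original_decoder_A.h5" inside model_dir; a side-less network such as
    # add_network("encoder", None, net) becomes "encoder" / "original_encoder.h5".
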
    def add_predictor(self, side, model):
        """ Add a predictor to the predictors dictionary """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)
        if not self.output_shape:
            self.set_output_shape(model)

    def store_input_shapes(self, model):
        """ Store the input and output shapes to state """
        logger.debug("Adding input shapes to state for model")
        inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs}
        if not any(inp for inp in inputs.keys() if inp.startswith("face")):
            raise ValueError("No input named 'face' was found. Check your input naming. "
                             "Current input names: {}".format(inputs))
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

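    # Shape-storage sketch (illustrative; the exact tensor names depend on the
    # Keras backend - an Input named "face" typically surfaces as "face:0"):
    #     {"face:0": [64, 64, 3], "mask:0": [64, 64, 1]}
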
    def set_output_shape(self, model):
        """ Set the output shape for use in training and convert """
        logger.debug("Setting output shape")
        out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs]
        if not out:
            raise ValueError("No outputs found! Check your model.")
        self.output_shape = tuple(out[0])
        logger.debug("Added output shape: %s", self.output_shape)

    def compile_predictors(self):
        """ Compile the predictors """
        logger.debug("Compiling Predictors")
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0)

        for side, model in self.predictors.items():
            loss_names = ["loss"]
            loss_funcs = [self.loss_function(side)]
            mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
            if mask:
                loss_names.insert(0, "mask_loss")
                loss_funcs.insert(0, self.mask_loss_function(mask[0], side))
            model.compile(optimizer=optimizer, loss=loss_funcs)

            if len(loss_names) > 1:
                loss_names.insert(0, "total_loss")
            self.state.add_session_loss_names(side, loss_names)
            self.history[side] = list()
        logger.debug("Compiled Predictors. Losses: %s", loss_names)

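    # Loss wiring sketch (illustrative): for a masked model, loss_funcs ends up as
    # [mask_loss_func, loss_func] and is matched positionally against the model's
    # outputs; loss_names then reads ["total_loss", "mask_loss", "loss"], with
    # "total_loss" reported first because Keras sums the per-output losses.
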
    def loss_function(self, side):
        """ Set the loss function """
        if self.config.get("dssim_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using DSSIM Loss")
            loss_func = DSSIMObjective()
        else:
            if side == "a" and not self.predict:
                logger.verbose("Using Mean Absolute Error Loss")
            loss_func = losses.mean_absolute_error
        logger.debug(loss_func)
        return loss_func

    def mask_loss_function(self, mask, side):
        """ Set the loss function for masks

        Side is input so we only log once """
        if self.config.get("dssim_mask_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using DSSIM Loss for mask")
            mask_loss_func = DSSIMObjective()
        else:
            if side == "a" and not self.predict:
                logger.verbose("Using Mean Absolute Error Loss for mask")
            mask_loss_func = losses.mean_absolute_error

        if self.config.get("penalized_mask_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using Penalized Loss for mask")
            mask_loss_func = PenalizedLoss(mask, mask_loss_func)
        logger.debug(mask_loss_func)
        return mask_loss_func

    def converter(self, swap):
        """ Converter for autoencoder models """
        logger.debug("Getting Converter: (swap: %s)", swap)
        if swap:
            retval = self.predictors["a"].predict
        else:
            retval = self.predictors["b"].predict
        logger.debug("Got Converter: %s", retval)
        return retval

    @property
    def iterations(self):
        """ Get current training iteration number """
        return self.state.iterations

    def map_models(self, swapped):
        """ Map the models for A/B side for swapping """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": dict(), "b": dict()}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Verbose log the model summaries """
        if self.predict:
            return
        for side in sorted(list(self.predictors.keys())):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x))
            for name, nnmeta in self.networks.items():
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x))

    def load_models(self, swapped):
        """ Load models from file """
        logger.debug("Load model: (swapped: %s)", swapped)
        model_mapping = self.map_models(swapped)
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load(predict=self.predict)
            else:
                is_loaded = network.load(fullpath=model_mapping[network.side][network.type],
                                         predict=self.predict)
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Backup and save the models """
        logger.debug("Backing up and saving models")
        should_backup = self.get_save_averages()
        save_threads = list()
        for network in self.networks.values():
            name = "save_{}".format(network.name)
            save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup))
        save_threads.append(MultiThread(self.state.save,
                                        name="save_state", should_backup=should_backup))
        for thread in save_threads:
            thread.start()
        for thread in save_threads:
            if thread.has_error:
                logger.error(thread.errors[0])
            thread.join()
        # Put in a line break to avoid a jumbled console
        print("\n")
        logger.info("saved models")

    def get_save_averages(self):
        """ Return the loss averages since the last save and reset historical losses

        This protects against model corruption by only backing up the model
        if all of the loss values have fallen.
        TODO This is not a perfect system. If the model corrupts at save_iteration - 1
        then the model may still back up
        """
        logger.debug("Getting Average loss since last save")
        avgs = dict()
        backup = True

        for side, loss in self.history.items():
            if not loss:
                backup = False
                break

            avgs[side] = sum(loss) / len(loss)
            self.history[side] = list()  # Reset historical loss

            if not self.state.lowest_avg_loss.get(side, None):
                logger.debug("Setting initial save iteration loss average for '%s': %s",
                             side, avgs[side])
                self.state.lowest_avg_loss[side] = avgs[side]
                continue

            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, avgs[side])

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)
        logger.debug("Average loss since last save: %s", avgs)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in avgs.items():
                logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Check whether total loss has dropped since the lowest loss """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False


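# Backup sketch (illustrative): with history {"a": [0.030, 0.028], "b": [0.041, 0.039]}
# the save averages are {"a": 0.029, "b": 0.040}; a backup only happens when every
# side's average falls below its stored lowest_avg_loss, and the stored lows are
# then updated to the new averages.
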
class NNMeta():
    """ Class to hold a neural network and its metadata

    filename:   The full path and filename of the model file for this network.
    type:       The type of network. For networks that can be swapped
                the type should be identical for the corresponding
                A and B networks, and should be unique for every A/B pair.
                Otherwise the type should be completely unique.
    side:       A, B or None. Used to identify which networks can
                be swapped.
    network:    The network itself.
    """

    def __init__(self, filename, network_type, side, network):
        logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
                     "network: %s)", self.__class__.__name__, filename, network_type,
                     side, network)
        self.filename = filename
        self.type = network_type.lower()
        self.side = side
        self.name = self.set_name()
        self.network = network
        self.network.name = self.name
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_name(self):
        """ Set the network name """
        name = self.type
        if self.side:
            name += "_{}".format(self.side)
        return name

    def load(self, fullpath=None, predict=False):
        """ Load model """
        fullpath = fullpath if fullpath else self.filename
        logger.debug("Loading model: '%s'", fullpath)
        try:
            network = load_model(fullpath, custom_objects=get_custom_objects())
        except ValueError as err:
            if str(err).lower().startswith("cannot create group in read only mode"):
                self.convert_legacy_weights()
                return True
            if predict:
                raise ValueError("Unable to load training data. Error: {}".format(str(err)))
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        except OSError as err:
            if predict:
                raise ValueError("Unable to load training data. Error: {}".format(str(err)))
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        self.network = network  # Update network with saved model
        self.network.name = self.type
        return True

    def save(self, fullpath=None, should_backup=False):
        """ Save model """
        fullpath = fullpath if fullpath else self.filename
        if should_backup:
            self.backup(fullpath=fullpath)
        logger.debug("Saving model: '%s'", fullpath)
        self.network.save(fullpath)

    def backup(self, fullpath=None):
        """ Backup Model """
        origfile = fullpath if fullpath else self.filename
        backupfile = origfile + ".bk"
        logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
        if os.path.exists(backupfile):
            os.remove(backupfile)
        if os.path.exists(origfile):
            os.rename(origfile, backupfile)

    def convert_legacy_weights(self):
        """ Convert legacy weights files to hold the model topology """
        logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
        self.network.load_weights(self.filename)
        self.save(should_backup=False)
        self.network.name = self.type


class State():
    """ Class to hold the model's current state and autoencoder structure """
    def __init__(self, model_dir, model_name, no_logs, training_image_size):
        logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
                     "training_image_size: '%s')", self.__class__.__name__, model_dir,
                     model_name, no_logs, training_image_size)
        self.serializer = Serializer.get_serializer("json")
        filename = "{}_state.{}".format(model_name, self.serializer.ext)
        self.filename = str(model_dir / filename)
        self.iterations = 0
        self.session_iterations = 0
        self.training_size = training_image_size
        self.sessions = dict()
        self.lowest_avg_loss = dict()
        self.inputs = dict()
        self.config = dict()
        self.load()
        self.session_id = self.new_session_id()
        self.create_new_session(no_logs)
        logger.debug("Initialized %s:", self.__class__.__name__)

    @property
    def face_shapes(self):
        """ Return a list of stored face shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]

    @property
    def mask_shapes(self):
        """ Return a list of stored mask shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]

    @property
    def loss_names(self):
        """ Return the loss names for this session """
        return self.sessions[self.session_id]["loss_names"]

    @property
    def current_session(self):
        """ Return the current session dict """
        return self.sessions[self.session_id]

    def new_session_id(self):
        """ Return a new session_id """
        if not self.sessions:
            session_id = 1
        else:
            session_id = max(int(key) for key in self.sessions.keys()) + 1
        logger.debug(session_id)
        return session_id

    def create_new_session(self, no_logs):
        """ Create a new session """
        logger.debug("Creating new session. id: %s", self.session_id)
        self.sessions[self.session_id] = {"timestamp": time.time(),
                                          "no_logs": no_logs,
                                          "loss_names": dict(),
                                          "batchsize": 0,
                                          "iterations": 0}

    def add_session_loss_names(self, side, loss_names):
        """ Add the session loss names to the sessions dictionary """
        logger.debug("Adding session loss_names. (side: '%s', loss_names: %s)", side, loss_names)
        self.sessions[self.session_id]["loss_names"][side] = loss_names

    def add_session_batchsize(self, batchsize):
        """ Add the session batchsize to the sessions dictionary """
        logger.debug("Adding session batchsize: %s", batchsize)
        self.sessions[self.session_id]["batchsize"] = batchsize

    def increment_iterations(self):
        """ Increment total and session iterations """
        self.iterations += 1
        self.sessions[self.session_id]["iterations"] += 1

    def load(self):
        """ Load state file """
        logger.debug("Loading State")
        try:
            with open(self.filename, "rb") as inp:
                state = self.serializer.unmarshal(inp.read().decode("utf-8"))
                self.sessions = state.get("sessions", dict())
                self.lowest_avg_loss = state.get("lowest_avg_loss", dict())
                self.iterations = state.get("iterations", 0)
                self.training_size = state.get("training_size", 256)
                self.inputs = state.get("inputs", dict())
                self.config = state.get("config", dict())
                logger.debug("Loaded state: %s", state)
                self.replace_config()
        except IOError as err:
            logger.warning("No existing state file found. Generating.")
            logger.debug("IOError: %s", str(err))
        except JSONDecodeError as err:
            logger.debug("JSONDecodeError: %s:", str(err))

    def save(self, should_backup=False):
        """ Save the state file """
        logger.debug("Saving State")
        if should_backup:
            self.backup()
        try:
            with open(self.filename, "wb") as out:
                state = {"sessions": self.sessions,
                         "lowest_avg_loss": self.lowest_avg_loss,
                         "iterations": self.iterations,
                         "inputs": self.inputs,
                         "training_size": self.training_size,
                         "config": _CONFIG}
                state_json = self.serializer.marshal(state)
                out.write(state_json.encode("utf-8"))
        except IOError as err:
            logger.error("Unable to save model state: %s", str(err.strerror))
        logger.debug("Saved State")

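    # State-file sketch (illustrative values; keys taken from save() above):
    #     {"sessions": {"1": {"timestamp": 1546300800.0, "no_logs": false,
    #                         "loss_names": {"a": ["loss"]}, "batchsize": 64,
    #                         "iterations": 10000}},
    #      "lowest_avg_loss": {"a": 0.028, "b": 0.039},
    #      "iterations": 10000,
    #      "inputs": {"face:0": [64, 64, 3]},
    #      "training_size": 256,
    #      "config": {...}}
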
    def backup(self):
        """ Backup state file """
        origfile = self.filename
        backupfile = origfile + ".bk"
        logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
        if os.path.exists(backupfile):
            os.remove(backupfile)
        if os.path.exists(origfile):
            os.rename(origfile, backupfile)

    def replace_config(self):
        """ Replace the loaded config with the one contained within the state file """
        global _CONFIG  # pylint: disable=global-statement
        # Add any new items to state config for legacy purposes
        for key, val in _CONFIG.items():
            if key not in self.config.keys():
                logger.info("Adding new config item to state file: '%s': '%s'", key, val)
                self.config[key] = val
        logger.debug("Replacing config. Old config: %s", _CONFIG)
        _CONFIG = self.config
        logger.debug("Replaced config. New config: %s", _CONFIG)
        logger.info("Using configuration saved in state file")

62   plugins/train/model/dfaker.py   Normal file

@@ -0,0 +1,62 @@
#!/usr/bin/env python3
""" DFaker Model
    Based on the dfaker model: https://github.com/dfaker """


from keras.initializers import RandomNormal
from keras.layers import Conv2D, Input
from keras.models import Model as KerasModel

from .original import logger, Model as OriginalModel


class Model(OriginalModel):
    """ Dfaker Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)
        kwargs["input_shape"] = (64, 64, 3)
        kwargs["encoder_dim"] = 1024
        self.kernel_initializer = RandomNormal(0, 0.02)
        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def build_autoencoders(self):
        """ Initialize Dfaker model """
        logger.debug("Initializing model")
        inputs = [Input(shape=self.input_shape, name="face")]
        if self.config.get("mask_type", None):
            mask_shape = (self.input_shape[0] * 2, self.input_shape[1] * 2, 1)
            inputs.append(Input(shape=mask_shape, name="mask"))

        for side in ("a", "b"):
            decoder = self.networks["decoder_{}".format(side)].network
            output = decoder(self.networks["encoder"].network(inputs[0]))
            autoencoder = KerasModel(inputs, output)
            self.add_predictor(side, autoencoder)
        logger.debug("Initialized model")

    def decoder(self):
        """ Decoder Network """
        input_ = Input(shape=(8, 8, 512))
        var_x = input_

        var_x = self.blocks.upscale(var_x, 512, res_block_follows=True)
        var_x = self.blocks.res_block(var_x, 512, kernel_initializer=self.kernel_initializer)
        var_x = self.blocks.upscale(var_x, 256, res_block_follows=True)
        var_x = self.blocks.res_block(var_x, 256, kernel_initializer=self.kernel_initializer)
        var_x = self.blocks.upscale(var_x, 128, res_block_follows=True)
        var_x = self.blocks.res_block(var_x, 128, kernel_initializer=self.kernel_initializer)
        var_x = self.blocks.upscale(var_x, 64)
        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = input_
            var_y = self.blocks.upscale(var_y, 512)
            var_y = self.blocks.upscale(var_y, 256)
            var_y = self.blocks.upscale(var_y, 128)
            var_y = self.blocks.upscale(var_y, 64)
            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
            outputs.append(var_y)
        return KerasModel([input_], outputs=outputs)

53   plugins/train/model/dfl_h128.py   Normal file

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
""" DeepFakesLab H128 Model
    Based on https://github.com/iperov/DeepFaceLab
"""

from keras.layers import Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel

from .original import logger, Model as OriginalModel


class Model(OriginalModel):
    """ DFL H128 Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        kwargs["input_shape"] = (128, 128, 3)
        kwargs["encoder_dim"] = 256 if self.config["lowmem"] else 512

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def encoder(self):
        """ DFL H128 Encoder """
        input_ = Input(shape=self.input_shape)
        var_x = input_
        var_x = self.blocks.conv(var_x, 128)
        var_x = self.blocks.conv(var_x, 256)
        var_x = self.blocks.conv(var_x, 512)
        var_x = self.blocks.conv(var_x, 1024)
        var_x = Dense(self.encoder_dim)(Flatten()(var_x))
        var_x = Dense(8 * 8 * self.encoder_dim)(var_x)
        var_x = Reshape((8, 8, self.encoder_dim))(var_x)
        var_x = self.blocks.upscale(var_x, self.encoder_dim)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ DFL H128 Decoder """
        input_ = Input(shape=(16, 16, self.encoder_dim))
        var = input_
        var = self.blocks.upscale(var, self.encoder_dim)
        var = self.blocks.upscale(var, self.encoder_dim // 2)
        var = self.blocks.upscale(var, self.encoder_dim // 4)

        # Face
        var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var)
        outputs = [var_x]
        # Mask
        if self.config.get("mask_type", None):
            var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

84   plugins/train/model/iae.py   Normal file

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
""" Improved autoencoder for faceswap """

from keras.layers import Concatenate, Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel

from ._base import ModelBase, logger


class Model(ModelBase):
    """ Improved Autoencoder Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)
        kwargs["input_shape"] = (64, 64, 3)
        kwargs["encoder_dim"] = 1024
        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def add_networks(self):
        """ Add the IAE model weights """
        logger.debug("Adding networks")
        self.add_network("encoder", None, self.encoder())
        self.add_network("decoder", None, self.decoder())
        self.add_network("intermediate", "a", self.intermediate())
        self.add_network("intermediate", "b", self.intermediate())
        self.add_network("inter", None, self.intermediate())
        logger.debug("Added networks")

    def build_autoencoders(self):
        """ Initialize IAE model """
        logger.debug("Initializing model")
        inputs = [Input(shape=self.input_shape, name="face")]
        if self.config.get("mask_type", "none") != "none":
            mask_shape = (self.input_shape[:2] + (1, ))
            inputs.append(Input(shape=mask_shape, name="mask"))

        decoder = self.networks["decoder"].network
        encoder = self.networks["encoder"].network
        inter_both = self.networks["inter"].network
        for side in ("a", "b"):
            inter_side = self.networks["intermediate_{}".format(side)].network
            output = decoder(Concatenate()([inter_side(encoder(inputs[0])),
                                            inter_both(encoder(inputs[0]))]))

            autoencoder = KerasModel(inputs, output)
            self.add_predictor(side, autoencoder)
        logger.debug("Initialized model")

    def encoder(self):
        """ Encoder Network """
        input_ = Input(shape=self.input_shape)
        var_x = input_
        var_x = self.blocks.conv(var_x, 128)
        var_x = self.blocks.conv(var_x, 256)
        var_x = self.blocks.conv(var_x, 512)
        var_x = self.blocks.conv(var_x, 1024)
        var_x = Flatten()(var_x)
        return KerasModel(input_, var_x)

    def intermediate(self):
        """ Intermediate Network """
        input_ = Input(shape=(None, 4 * 4 * 1024))
        var_x = input_
        var_x = Dense(self.encoder_dim)(var_x)
        var_x = Dense(4 * 4 * int(self.encoder_dim / 2))(var_x)
        var_x = Reshape((4, 4, int(self.encoder_dim / 2)))(var_x)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network """
        input_ = Input(shape=(4, 4, self.encoder_dim))
        var_x = input_
        var_x = self.blocks.upscale(var_x, 512)
        var_x = self.blocks.upscale(var_x, 256)
        var_x = self.blocks.upscale(var_x, 128)
        var_x = self.blocks.upscale(var_x, 64)
        var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var_x)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

83   plugins/train/model/original.py   Normal file

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
""" Original Model
    Based on the original https://www.reddit.com/r/deepfakes/
    code sample + contribs """

from keras.layers import Conv2D, Dense, Flatten, Input, Reshape

from keras.models import Model as KerasModel

from ._base import ModelBase, logger


class Model(ModelBase):
    """ Original Faceswap Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        if "input_shape" not in kwargs:
            kwargs["input_shape"] = (64, 64, 3)
        if "encoder_dim" not in kwargs:
            kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def add_networks(self):
        """ Add the original model weights """
        logger.debug("Adding networks")
        self.add_network("decoder", "a", self.decoder())
        self.add_network("decoder", "b", self.decoder())
        self.add_network("encoder", None, self.encoder())
        logger.debug("Added networks")

    def build_autoencoders(self):
        """ Initialize original model """
        logger.debug("Initializing model")
        inputs = [Input(shape=self.input_shape, name="face")]
        if self.config.get("mask_type", None):
            mask_shape = (self.input_shape[:2] + (1, ))
            inputs.append(Input(shape=mask_shape, name="mask"))

        for side in ("a", "b"):
            logger.debug("Adding Autoencoder. Side: %s", side)
            decoder = self.networks["decoder_{}".format(side)].network
            output = decoder(self.networks["encoder"].network(inputs[0]))
            autoencoder = KerasModel(inputs, output)
            self.add_predictor(side, autoencoder)
        logger.debug("Initialized model")

    def encoder(self):
        """ Encoder Network """
        input_ = Input(shape=self.input_shape)
        var_x = input_
        var_x = self.blocks.conv(var_x, 128)
        var_x = self.blocks.conv(var_x, 256)
        var_x = self.blocks.conv(var_x, 512)
        if not self.config.get("lowmem", False):
            var_x = self.blocks.conv(var_x, 1024)
        var_x = Dense(self.encoder_dim)(Flatten()(var_x))
        var_x = Dense(4 * 4 * 1024)(var_x)
        var_x = Reshape((4, 4, 1024))(var_x)
        var_x = self.blocks.upscale(var_x, 512)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network """
        input_ = Input(shape=(8, 8, 512))
        var_x = input_
        var_x = self.blocks.upscale(var_x, 256)
        var_x = self.blocks.upscale(var_x, 128)
        var_x = self.blocks.upscale(var_x, 64)
        var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = input_
            var_y = self.blocks.upscale(var_y, 256)
            var_y = self.blocks.upscale(var_y, 128)
            var_y = self.blocks.upscale(var_y, 64)
            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

130  plugins/train/model/unbalanced.py   Normal file

@@ -0,0 +1,130 @@
#!/usr/bin/env python3
""" Unbalanced Model
    Based on the original https://www.reddit.com/r/deepfakes/
    code sample + contribs """

from keras.initializers import RandomNormal
from keras.layers import Conv2D, Dense, Flatten, Input, Reshape, SpatialDropout2D
from keras.models import Model as KerasModel

from .original import logger, Model as OriginalModel


class Model(OriginalModel):
    """ Unbalanced Faceswap Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        self.lowmem = self.config.get("lowmem", False)
        kwargs["input_shape"] = (self.config["input_size"], self.config["input_size"], 3)
        kwargs["encoder_dim"] = 512 if self.lowmem else self.config["nodes"]
        self.kernel_initializer = RandomNormal(0, 0.02)

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def add_networks(self):
        """ Add the unbalanced model weights """
        logger.debug("Adding networks")
        self.add_network("decoder", "a", self.decoder_a())
        self.add_network("decoder", "b", self.decoder_b())
        self.add_network("encoder", None, self.encoder())
        logger.debug("Added networks")

    def encoder(self):
        """ Unbalanced Encoder """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        encoder_complexity = 128 if self.lowmem else self.config["complexity_encoder"]
        dense_dim = 384 if self.lowmem else 512
        dense_shape = self.input_shape[0] // 16
        input_ = Input(shape=self.input_shape)

        var_x = input_
        var_x = self.blocks.conv(var_x, encoder_complexity, use_instance_norm=True, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 2, use_instance_norm=True, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 4, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 6, **kwargs)
        var_x = self.blocks.conv(var_x, encoder_complexity * 8, **kwargs)
        var_x = Dense(self.encoder_dim,
                      kernel_initializer=self.kernel_initializer)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * dense_dim,
                      kernel_initializer=self.kernel_initializer)(var_x)
        var_x = Reshape((dense_shape, dense_shape, dense_dim))(var_x)
        return KerasModel(input_, var_x)

    def decoder_a(self):
        """ Decoder for side A """
        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
        decoder_complexity = 320 if self.lowmem else self.config["complexity_decoder_a"]
        dense_dim = 384 if self.lowmem else 512
        decoder_shape = self.input_shape[0] // 16
        input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))

        var_x = input_

        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
        var_x = SpatialDropout2D(0.25)(var_x)
        var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
        if self.lowmem:
            var_x = SpatialDropout2D(0.15)(var_x)
        else:
            var_x = SpatialDropout2D(0.25)(var_x)
        var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
        var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = input_
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

    def decoder_b(self):
        """ Decoder for side B """
        kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer)
        decoder_complexity = 384 if self.lowmem else self.config["complexity_decoder_b"]
        dense_dim = 384 if self.lowmem else 512
        decoder_shape = self.input_shape[0] // 16
        input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim))

        var_x = input_
        if self.lowmem:
            var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 8, **kwargs)
        else:
            var_x = self.blocks.upscale(var_x, decoder_complexity,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 2,
                                        res_block_follows=True, **kwargs)
            var_x = self.blocks.res_block(var_x, decoder_complexity // 2,
                                          kernel_initializer=self.kernel_initializer)
            var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs)
        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = input_
            var_y = self.blocks.upscale(var_y, decoder_complexity)
            if not self.lowmem:
                var_y = self.blocks.upscale(var_y, decoder_complexity)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 2)
            var_y = self.blocks.upscale(var_y, decoder_complexity // 4)
            if self.lowmem:
                var_y = self.blocks.upscale(var_y, decoder_complexity // 8)
            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

83   plugins/train/model/villain.py   Normal file

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
""" Original - VillainGuy model
    Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs
    Adapted from a model by VillainGuy (https://github.com/VillainGuy) """

from keras.initializers import RandomNormal
from keras.layers import add, Conv2D, Dense, Flatten, Input, Reshape
from keras.models import Model as KerasModel

from lib.model.layers import PixelShuffler
from .original import logger, Model as OriginalModel


class Model(OriginalModel):
    """ Villain Faceswap Model """
    def __init__(self, *args, **kwargs):
        logger.debug("Initializing %s: (args: %s, kwargs: %s)",
                     self.__class__.__name__, args, kwargs)

        kwargs["input_shape"] = (128, 128, 3)
        kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024
        self.kernel_initializer = RandomNormal(0, 0.02)

        super().__init__(*args, **kwargs)
        logger.debug("Initialized %s", self.__class__.__name__)

    def encoder(self):
        """ Encoder Network """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        input_ = Input(shape=self.input_shape)
        in_conv_filters = self.input_shape[0]
        if self.input_shape[0] > 128:
            in_conv_filters = 128 + (self.input_shape[0] - 128) // 4
        dense_shape = self.input_shape[0] // 16

        var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs)
        tmp_x = var_x
        res_cycles = 8 if self.config.get("lowmem", False) else 16
        for _ in range(res_cycles):
            nn_x = self.blocks.res_block(var_x, 128, **kwargs)
            var_x = nn_x
        # consider adding scale before this layer to scale the residual chain
        var_x = add([var_x, tmp_x])
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = PixelShuffler()(var_x)
        var_x = self.blocks.conv(var_x, 128, **kwargs)
        var_x = self.blocks.conv_sep(var_x, 256, **kwargs)
        var_x = self.blocks.conv(var_x, 512, **kwargs)
        if not self.config.get("lowmem", False):
            var_x = self.blocks.conv_sep(var_x, 1024, **kwargs)

        var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x))
        var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x)
        var_x = Reshape((dense_shape, dense_shape, 1024))(var_x)
        var_x = self.blocks.upscale(var_x, 512, **kwargs)
        return KerasModel(input_, var_x)

    def decoder(self):
        """ Decoder Network """
        kwargs = dict(kernel_initializer=self.kernel_initializer)
        decoder_shape = self.input_shape[0] // 8
        input_ = Input(shape=(decoder_shape, decoder_shape, 512))

        var_x = input_
        var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 512, **kwargs)
        var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, 256, **kwargs)
        var_x = self.blocks.upscale(var_x, self.input_shape[0], res_block_follows=True, **kwargs)
        var_x = self.blocks.res_block(var_x, self.input_shape[0], **kwargs)
        var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x)
        outputs = [var_x]

        if self.config.get("mask_type", None):
            var_y = input_
            var_y = self.blocks.upscale(var_y, 512)
            var_y = self.blocks.upscale(var_y, 256)
            var_y = self.blocks.upscale(var_y, self.input_shape[0])
            var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y)
            outputs.append(var_y)
        return KerasModel(input_, outputs=outputs)

0    plugins/train/trainer/__init__.py   Normal file
576  plugins/train/trainer/_base.py      Normal file

@@ -0,0 +1,576 @@
#!/usr/bin/env python3


""" Base Trainer Class for Faceswap

    Trainers should be inherited from this class.

    A training_opts dictionary can be set in the corresponding model.
    Accepted values:
        alignments:        dict containing paths to alignments files for keys 'a' and 'b'
        preview_scaling:   How much to scale the preview output by
        training_size:     Size of the training images
        coverage_ratio:    Ratio of the face to be cropped out for training
        mask_type:         Type of mask to use. See lib.model.masks for valid mask names.
                           Set to None if not used
        no_logs:           Disable tensorboard logging
        warp_to_landmarks: Use random_warp_landmarks instead of random_warp
        no_flip:           Don't perform a random flip on the image
"""

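# training_opts sketch (illustrative values; keys as documented above, paths hypothetical):
#     {"alignments": {"a": "/faces/a_alignments.json", "b": "/faces/b_alignments.json"},
#      "preview_scaling": 0.5, "training_size": 256, "coverage_ratio": 0.625,
#      "mask_type": "dfl_full", "no_logs": False,
#      "warp_to_landmarks": False, "no_flip": False}
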
import logging
import os
import time

import cv2
import numpy as np

from tensorflow import keras as tf_keras

from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.training_data import TrainingDataGenerator, stack_images
from lib.utils import get_folder, get_image_paths

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class TrainerBase():
    """ Base Trainer """

    def __init__(self, model, images, batch_size):
        logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
                     self.__class__.__name__, model, batch_size)
        self.batch_size = batch_size
        self.model = model
        self.model.state.add_session_batchsize(batch_size)
        self.images = images

        self.process_training_opts()

        self.batchers = {side: Batcher(side,
                                       images[side],
                                       self.model,
                                       self.use_mask,
                                       batch_size)
                         for side in images.keys()}

        self.tensorboard = self.set_tensorboard()
        self.samples = Samples(self.model,
                               self.use_mask,
                               self.model.training_opts["coverage_ratio"],
                               self.model.training_opts["preview_scaling"])
        self.timelapse = Timelapse(self.model,
                                   self.use_mask,
                                   self.model.training_opts["coverage_ratio"],
                                   self.batchers)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def timestamp(self):
        """ Standardised timestamp for loss reporting """
        return time.strftime("%H:%M:%S")

    @property
    def landmarks_required(self):
        """ Return True if landmarks are required """
        opts = self.model.training_opts
        retval = bool(opts.get("mask_type", None) or opts["warp_to_landmarks"])
        logger.debug(retval)
        return retval

    @property
    def use_mask(self):
        """ Return True if a mask is requested """
        retval = bool(self.model.training_opts.get("mask_type", None))
        logger.debug(retval)
        return retval

    def process_training_opts(self):
        """ Override for processing model specific training options """
        logger.debug(self.model.training_opts)
        if self.landmarks_required:
            landmarks = Landmarks(self.model.training_opts).landmarks
            self.model.training_opts["landmarks"] = landmarks

    def set_tensorboard(self):
        """ Set up the tensorboard callback """
        if self.model.training_opts["no_logs"]:
            logger.verbose("TensorBoard logging disabled")
            return None

        logger.debug("Enabling TensorBoard Logging")
        tensorboard = dict()
        for side in self.images.keys():
            logger.debug("Setting up TensorBoard Logging. Side: %s", side)
            log_dir = os.path.join(str(self.model.model_dir),
                                   "{}_logs".format(self.model.name),
                                   side,
                                   "session_{}".format(self.model.state.session_id))
            tbs = tf_keras.callbacks.TensorBoard(log_dir=log_dir,
                                                 histogram_freq=0,  # Must be 0 or hangs
                                                 batch_size=self.batch_size,
                                                 write_graph=True,
                                                 write_grads=True)
            tbs.set_model(self.model.predictors[side])
            tensorboard[side] = tbs
        logger.info("Enabled TensorBoard Logging")
        return tensorboard

    def print_loss(self, loss):
        """ Override for specific model loss formatting """
        output = list()
        for side in sorted(list(loss.keys())):
            display = ", ".join(["{}_{}: {:.5f}".format(self.model.state.loss_names[side][idx],
                                                        side.capitalize(),
                                                        this_loss)
                                 for idx, this_loss in enumerate(loss[side])])
            output.append(display)
        print("[{}] [#{:05d}] {}, {}".format(
            self.timestamp, self.model.iterations, output[0], output[1]), end='\r')

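    # Output sketch (illustrative): with loss_names ["loss"] on each side, the
    # printed line looks like:
    #     [14:02:31] [#00042] loss_A: 0.02154, loss_B: 0.02317
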
def train_one_step(self, viewer, timelapse_kwargs):
|
||||||
|
""" Train a batch """
|
||||||
|
logger.trace("Training one step: (iteration: %s)", self.model.iterations)
|
||||||
|
is_preview_iteration = False if viewer is None else True
|
||||||
|
loss = dict()
|
||||||
|
for side, batcher in self.batchers.items():
|
||||||
|
loss[side] = batcher.train_one_batch(is_preview_iteration)
|
||||||
|
if not is_preview_iteration:
|
||||||
|
continue
|
||||||
|
self.samples.images[side] = batcher.compile_sample(self.batch_size)
|
||||||
|
if timelapse_kwargs:
|
||||||
|
self.timelapse.get_sample(side, timelapse_kwargs)
|
||||||
|
|
||||||
|
self.model.state.increment_iterations()
|
||||||
|
|
||||||
|
for side, side_loss in loss.items():
|
||||||
|
self.store_history(side, side_loss)
|
||||||
|
self.log_tensorboard(side, side_loss)
|
||||||
|
self.print_loss(loss)
|
||||||
|
|
||||||
|
if viewer is not None:
|
||||||
|
viewer(self.samples.show_sample(),
|
||||||
|
"Training - 'S': Save Now. 'ENTER': Save and Quit")
|
||||||
|
|
||||||
|
if timelapse_kwargs is not None:
|
||||||
|
self.timelapse.output_timelapse()
|
||||||
|
|
||||||
|
    def store_history(self, side, loss):
        """ Store the history of this step """
        logger.trace("Updating loss history: '%s'", side)
        self.model.history[side].append(loss[0])  # Either only loss or total loss
        logger.trace("Updated loss history: '%s'", side)

    def log_tensorboard(self, side, loss):
        """ Log loss to TensorBoard log """
        if not self.tensorboard:
            return
        logger.trace("Updating TensorBoard log: '%s'", side)
        logs = {log[0]: log[1]
                for log in zip(self.model.state.loss_names[side], loss)}
        self.tensorboard[side].on_batch_end(self.model.state.iterations, logs)
        logger.trace("Updated TensorBoard log: '%s'", side)

    def clear_tensorboard(self):
        """ Indicate training end to TensorBoard """
        if not self.tensorboard:
            return
        for side, tensorboard in self.tensorboard.items():
            logger.debug("Ending TensorBoard. Side: '%s'", side)
            tensorboard.on_train_end(None)


class Batcher():
    """ Batch images from a single side """
    def __init__(self, side, images, model, use_mask, batch_size):
        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)",
                     self.__class__.__name__, side, len(images), batch_size)
        self.model = model
        self.use_mask = use_mask
        self.side = side
        self.target = None
        self.samples = None
        self.mask = None

        self.feed = self.load_generator().minibatch_ab(images, batch_size, self.side)
        self.timelapse_feed = None

    def load_generator(self):
        """ Pass arguments to TrainingDataGenerator and return object """
        logger.debug("Loading generator: %s", self.side)
        input_size = self.model.input_shape[0]
        output_size = self.model.output_shape[0]
        logger.debug("input_size: %s, output_size: %s", input_size, output_size)
        generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts)
        return generator

    def train_one_batch(self, is_preview_iteration):
        """ Train a batch """
        logger.trace("Training one step: (side: %s)", self.side)
        batch = self.get_next(is_preview_iteration)
        loss = self.model.predictors[self.side].train_on_batch(*batch)
        loss = loss if isinstance(loss, list) else [loss]
        return loss

    def get_next(self, is_preview_iteration):
        """ Return the next batch from the generator

        Items should come out as: (warped, target [, mask]) """
        batch = next(self.feed)
        self.samples = batch[0] if is_preview_iteration else None
        batch = batch[1:]  # Remove full size samples from batch
        if self.use_mask:
            batch = self.compile_mask(batch)
        self.target = batch[1] if is_preview_iteration else None
        return batch

    def compile_mask(self, batch):
        """ Compile the mask into training data """
        logger.trace("Compiling Mask: (side: '%s')", self.side)
        mask = batch[-1]
        retval = list()
        for idx in range(len(batch) - 1):
            image = batch[idx]
            retval.append([image, mask])
        return retval
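# How compile_mask() reshapes a batch, sketched with made-up shapes: every
# image array gets paired with the mask so a masked model sees both inputs.
#
#     import numpy as np
#
#     batch_size, size = 4, 64
#     warped = np.zeros((batch_size, size, size, 3), dtype="float32")
#     target = np.zeros((batch_size, size, size, 3), dtype="float32")
#     mask = np.zeros((batch_size, size, size, 1), dtype="float32")
#
#     batch = [warped, target, mask]                 # as the generator emits it
#     masked = [[item, mask] for item in batch[:-1]] # [[warped, mask], [target, mask]]
#     print([len(pair) for pair in masked])          # [2, 2]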
    def compile_sample(self, batch_size, samples=None, images=None):
        """ Training samples to display in the viewer """
        num_images = self.model.training_opts.get("preview_images", 14)
        num_images = min(batch_size, num_images)
        logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
        images = images if images is not None else self.target
        samples = [samples[0:num_images]] if samples is not None else [self.samples[0:num_images]]
        if self.use_mask:
            retval = [tgt[0:num_images] for tgt in images]
        else:
            retval = [images[0:num_images]]
        retval = samples + retval
        return retval

    def compile_timelapse_sample(self):
        """ Timelapse samples """
        batch = next(self.timelapse_feed)
        samples = batch[0]
        batch = batch[1:]  # Remove full size samples from batch
        batchsize = len(samples)
        if self.use_mask:
            batch = self.compile_mask(batch)
        images = batch[1]
        sample = self.compile_sample(batchsize, samples=samples, images=images)
        return sample

    def set_timelapse_feed(self, images, batchsize):
        """ Set the timelapse dictionary """
        logger.debug("Setting timelapse feed: (side: '%s', input_images: '%s', batchsize: %s)",
                     self.side, images, batchsize)
        self.timelapse_feed = self.load_generator().minibatch_ab(images[:batchsize],
                                                                 batchsize, self.side,
                                                                 do_shuffle=False,
                                                                 is_timelapse=True)
        logger.debug("Set timelapse feed")


class Samples():
    """ Display samples for preview and timelapse """
    def __init__(self, model, use_mask, coverage_ratio, scaling=1.0):
        logger.debug("Initializing %s: model: '%s', use_mask: %s, coverage_ratio: %s)",
                     self.__class__.__name__, model, use_mask, coverage_ratio)
        self.model = model
        self.use_mask = use_mask
        self.images = dict()
        self.coverage_ratio = coverage_ratio
        self.scaling = scaling
        logger.debug("Initialized %s", self.__class__.__name__)

    def show_sample(self):
        """ Display preview data """
        logger.debug("Showing sample")
        feeds = dict()
        figures = dict()
        headers = dict()
        for side, samples in self.images.items():
            faces = samples[1]
            if self.model.input_shape[0] / faces.shape[1] != 1.0:
                feeds[side] = self.resize_sample(side, faces, self.model.input_shape[0])
                feeds[side] = feeds[side].reshape((-1, ) + self.model.input_shape)
            else:
                feeds[side] = faces
            if self.use_mask:
                mask = samples[-1]
                feeds[side] = [feeds[side], mask]

        preds = self.get_predictions(feeds["a"], feeds["b"])

        for side, samples in self.images.items():
            other_side = "a" if side == "b" else "b"
            predictions = [preds["{}_{}".format(side, side)],
                           preds["{}_{}".format(other_side, side)]]
            display = self.to_full_frame(side, samples, predictions)
            headers[side] = self.get_headers(side, other_side, display[0].shape[1])
            figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
            if self.images[side][0].shape[0] % 2 == 1:
                figures[side] = np.concatenate([figures[side],
                                                np.expand_dims(figures[side][0], 0)])

        width = 4
        side_cols = width // 2
        if side_cols != 1:
            headers = self.duplicate_headers(headers, side_cols)

        header = np.concatenate([headers["a"], headers["b"]], axis=1)
        figure = np.concatenate([figures["a"], figures["b"]], axis=0)
        height = int(figure.shape[0] / width)
        figure = figure.reshape((width, height) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.vstack((header, figure))

        logger.debug("Compiled sample")
        return np.clip(figure * 255, 0, 255).astype('uint8')
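# A toy version of the figure assembly above (stack_images does the real
# tiling): three columns per sample - target, reconstruction, swap - joined
# horizontally, then samples stacked vertically.
#
#     import numpy as np
#
#     size = 8
#     target, recon, swap = (np.zeros((size, size, 3)) for _ in range(3))
#     row = np.concatenate([target, recon, swap], axis=1)  # one sample, 3 columns
#     grid = np.concatenate([row, row], axis=0)            # more samples, stacked
#     print(grid.shape)                                    # (16, 24, 3)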

    @staticmethod
    def resize_sample(side, sample, target_size):
        """ Resize samples where predictor expects different shape from processed image """
        scale = target_size / sample.shape[1]
        if scale == 1.0:
            return sample
        logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)",
                     side, sample.shape, target_size, scale)
        interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA  # pylint: disable=no-member
        retval = np.array([cv2.resize(img,  # pylint: disable=no-member
                                      (target_size, target_size),
                                      interpolation=interpn)
                           for img in sample])
        logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
        return retval
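# The interpolation rule in resize_sample(), isolated: bicubic when growing,
# area averaging when shrinking, which avoids aliasing on downscale. Note
# that the interpolation flag must be passed by keyword - the third
# positional argument of cv2.resize is dst, not interpolation.
#
#     import cv2
#     import numpy as np
#
#     img = (np.random.rand(32, 32, 3) * 255).astype("uint8")
#     for target_size in (64, 16):
#         scale = target_size / img.shape[1]
#         interp = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
#         resized = cv2.resize(img, (target_size, target_size), interpolation=interp)
#         print(resized.shape)  # (64, 64, 3) then (16, 16, 3)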

    def get_predictions(self, feed_a, feed_b):
        """ Return the sample predictions from the model """
        logger.debug("Getting Predictions")
        preds = dict()
        preds["a_a"] = self.model.predictors["a"].predict(feed_a)
        preds["b_a"] = self.model.predictors["b"].predict(feed_a)
        preds["a_b"] = self.model.predictors["a"].predict(feed_b)
        preds["b_b"] = self.model.predictors["b"].predict(feed_b)

        # Get the returned image from predictors that emit multiple items
        if not isinstance(preds["a_a"], np.ndarray):
            for key, val in preds.items():
                preds[key] = val[0]
        logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
        return preds
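# The four prediction keys spelled out with stand-in predictors: the key is
# "<model side>_<feed side>", so "b_a" is decoder B applied to side A's
# faces - the swapped column in the preview.
#
#     predictors = {"a": lambda feed: feed, "b": lambda feed: feed}  # stubs
#     feeds = {"a": "faces_a", "b": "faces_b"}
#     preds = {"{}_{}".format(model_side, feed_side):
#              predictors[model_side](feeds[feed_side])
#              for model_side in ("a", "b") for feed_side in ("a", "b")}
#     print(sorted(preds))  # ['a_a', 'a_b', 'b_a', 'b_b']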

    def to_full_frame(self, side, samples, predictions):
        """ Patch the images into the full frame """
        logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)",
                     side, len(samples), [pred.shape for pred in predictions])
        full, faces = samples[:2]
        images = [faces] + predictions
        full_size = full.shape[1]
        target_size = int(full_size * self.coverage_ratio)
        if target_size != full_size:
            frame = self.frame_overlay(full, target_size, (0, 0, 255))

        if self.use_mask:
            images = self.compile_masked(images, samples[-1])
        images = [self.resize_sample(side, image, target_size) for image in images]
        if target_size != full_size:
            images = [self.overlay_foreground(frame, image) for image in images]
        if self.scaling != 1.0:
            new_size = int(full_size * self.scaling)
            images = [self.resize_sample(side, image, new_size) for image in images]
        return images

    @staticmethod
    def frame_overlay(images, target_size, color):
        """ Add roi frame to a background image """
        logger.debug("full_size: %s, target_size: %s, color: %s",
                     images.shape[1], target_size, color)
        new_images = list()
        full_size = images.shape[1]
        padding = (full_size - target_size) // 2
        length = target_size // 4
        t_l, b_r = (padding, full_size - padding)
        for img in images:
            cv2.rectangle(img,  # pylint: disable=no-member
                          (t_l, t_l),
                          (t_l + length, t_l + length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (b_r, t_l),
                          (b_r - length, t_l + length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (b_r, b_r),
                          (b_r - length, b_r - length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (t_l, b_r),
                          (t_l + length, b_r - length),
                          color,
                          3)
            new_images.append(img)
        retval = np.array(new_images)
        logger.debug("Overlaid background. Shape: %s", retval.shape)
        return retval

    @staticmethod
    def compile_masked(faces, masks):
        """ Add the mask to the faces for masked preview """
        retval = list()
        masks3 = np.tile(1 - np.rint(masks), 3)
        for mask in masks3:
            mask[np.where((mask == [1., 1., 1.]).all(axis=2))] = [0., 0., 1.]
        for previews in faces:
            images = np.array([cv2.addWeighted(img, 1.0,  # pylint: disable=no-member
                                               masks3[idx], 0.3,
                                               0)
                               for idx, img in enumerate(previews)])
            retval.append(images)
        logger.debug("masked shapes: %s", [faces.shape for faces in retval])
        return retval

    @staticmethod
    def overlay_foreground(backgrounds, foregrounds):
        """ Overlay the training images into the center of the background """
        offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2
        new_images = list()
        for idx, img in enumerate(backgrounds):
            img[offset:offset + foregrounds[idx].shape[0],
                offset:offset + foregrounds[idx].shape[1]] = foregrounds[idx]
            new_images.append(img)
        retval = np.array(new_images)
        logger.debug("Overlaid foreground. Shape: %s", retval.shape)
        return retval

    def get_headers(self, side, other_side, width):
        """ Set headers for images """
        logger.debug("side: '%s', other_side: '%s', width: %s",
                     side, other_side, width)
        side = side.upper()
        other_side = other_side.upper()
        height = int(64 * self.scaling)
        total_width = width * 3
        logger.debug("height: %s, total_width: %s", height, total_width)
        font = cv2.FONT_HERSHEY_SIMPLEX  # pylint: disable=no-member
        texts = ["Target {}".format(side),
                 "{} > {}".format(side, side),
                 "{} > {}".format(side, other_side)]
        text_sizes = [cv2.getTextSize(texts[idx],  # pylint: disable=no-member
                                      font,
                                      self.scaling,
                                      1)[0]
                      for idx in range(len(texts))]
        text_y = int((height + text_sizes[0][1]) / 2)
        text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx
                  for idx in range(len(texts))]
        logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s",
                     texts, text_sizes, text_x, text_y)
        header_box = np.ones((height, total_width, 3), np.float32)
        for idx, text in enumerate(texts):
            cv2.putText(header_box,  # pylint: disable=no-member
                        text,
                        (text_x[idx], text_y),
                        font,
                        self.scaling,
                        (0, 0, 0),
                        1,
                        lineType=cv2.LINE_AA)  # pylint: disable=no-member
        logger.debug("header_box.shape: %s", header_box.shape)
        return header_box

    @staticmethod
    def duplicate_headers(headers, columns):
        """ Duplicate headers for the number of columns displayed """
        for side, header in headers.items():
            duped = tuple(header for _ in range(columns))
            headers[side] = np.concatenate(duped, axis=1)
            logger.debug("side: %s header.shape: %s", side, header.shape)
        return headers


class Timelapse():
    """ Create the timelapse """
    def __init__(self, model, use_mask, coverage_ratio, batchers):
        logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, "
                     "batchers: '%s')", self.__class__.__name__, model, use_mask,
                     coverage_ratio, batchers)
        self.samples = Samples(model, use_mask, coverage_ratio)
        self.model = model
        self.batchers = batchers
        self.output_file = None
        logger.debug("Initialized %s", self.__class__.__name__)

    def get_sample(self, side, timelapse_kwargs):
        """ Perform timelapse """
        logger.debug("Getting timelapse samples: '%s'", side)
        if not self.output_file:
            self.setup(**timelapse_kwargs)
        self.samples.images[side] = self.batchers[side].compile_timelapse_sample()
        logger.debug("Got timelapse samples: '%s' - %s", side, len(self.samples.images[side]))

    def setup(self, input_a=None, input_b=None, output=None):
        """ Set the timelapse output folder """
        logger.debug("Setting up timelapse")
        if output is None:
            output = str(get_folder(os.path.join(str(self.model.model_dir),
                                                 "{}_timelapse".format(self.model.name))))
        self.output_file = str(output)
        logger.debug("Timelapse output set to '%s'", self.output_file)

        images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
        batchsize = min(len(images["a"]),
                        len(images["b"]),
                        self.model.training_opts.get("preview_images", 14))
        for side, image_files in images.items():
            self.batchers[side].set_timelapse_feed(image_files, batchsize)
        logger.debug("Set up timelapse")

    def output_timelapse(self):
        """ Write the current timelapse sample to disk """
        logger.debug("Outputting timelapse")
        image = self.samples.show_sample()
        filename = os.path.join(self.output_file, str(int(time.time())) + ".jpg")

        cv2.imwrite(filename, image)  # pylint: disable=no-member
        logger.debug("Created timelapse: '%s'", filename)


class Landmarks():
    """ Set Landmarks for training into the model's training options """
    def __init__(self, training_opts):
        logger.debug("Initializing %s: (training_opts: '%s')",
                     self.__class__.__name__, training_opts)
        self.size = training_opts.get("training_size", 256)
        self.paths = training_opts["alignments"]
        self.landmarks = self.get_alignments()
        logger.debug("Initialized %s", self.__class__.__name__)

    def get_alignments(self):
        """ Obtain the landmarks for each faceset """
        landmarks = dict()
        for side, fullpath in self.paths.items():
            path, filename = os.path.split(fullpath)
            filename, extension = os.path.splitext(filename)
            serializer = extension[1:]
            alignments = Alignments(
                path,
                filename=filename,
                serializer=serializer)
            landmarks[side] = self.transform_landmarks(alignments)
        return landmarks

    def transform_landmarks(self, alignments):
        """ For each face transform landmarks and return """
        landmarks = dict()
        for _, faces, _, _ in alignments.yield_faces():
            for face in faces:
                detected_face = DetectedFace()
                detected_face.from_alignment(face)
                detected_face.load_aligned(None, size=self.size, align_eyes=False)
                landmarks[detected_face.hash] = detected_face.aligned_landmarks
        return landmarks
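# The structure Landmarks builds, sketched with invented values: aligned
# landmark points keyed by face-image hash, so warp-to-landmarks can match
# any training image to its alignment regardless of filename.
#
#     landmarks_by_hash = {
#         "9f86d081884c7d65...": [(112, 140), (144, 141)],  # hash -> points
#     }
#     points = landmarks_by_hash.get("9f86d081884c7d65...")
#     print(points)  # [(112, 140), (144, 141)]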

plugins/train/trainer/original.py (new file, +4)
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+""" Original Trainer """
+
+from ._base import TrainerBase as Trainer

@@ -15,7 +15,6 @@ from lib.faces_detect import DetectedFace
 from lib.multithreading import BackgroundGenerator, SpawnProcess
 from lib.queue_manager import queue_manager
 from lib.utils import get_folder, get_image_paths, hash_image_file
-
 from plugins.plugin_loader import PluginLoader

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -43,7 +42,7 @@ class Convert():
         logger.debug("Initialized %s", self.__class__.__name__)

     def process(self):
-        """ Original & LowMem models go with Adjust or Masked converter
+        """ Original & LowMem models go with converter

        Note: GAN prediction outputs a mask + an image, while other
        predicts only an image. """

@@ -103,37 +102,19 @@ class Convert():

     def load_model(self):
         """ Load the model requested for conversion """
-        model_name = self.args.trainer
+        logger.debug("Loading Model")
         model_dir = get_folder(self.args.model_dir)
-        num_gpus = self.args.gpus
-        model = PluginLoader.get_model(model_name)(model_dir, num_gpus)
-
-        if not model.load(self.args.swap_model):
-            logger.error("Model Not Found! A valid model "
-                         "must be provided to continue!")
-            exit(1)
-
+        model = PluginLoader.get_model(self.args.trainer)(model_dir, self.args.gpus, predict=True)
+        logger.debug("Loaded Model")
         return model

     def load_converter(self, model):
         """ Load the requested converter for conversion """
-        args = self.args
-        conv = args.converter
-
+        conv = self.args.converter
         converter = PluginLoader.get_converter(conv)(
-            model.converter(False),
-            trainer=args.trainer,
-            blur_size=args.blur_size,
-            seamless_clone=args.seamless_clone,
-            sharpen_image=args.sharpen_image,
-            mask_type=args.mask_type,
-            erosion_kernel_size=args.erosion_kernel_size,
-            match_histogram=args.match_histogram,
-            smooth_mask=args.smooth_mask,
-            avg_color_adjust=args.avg_color_adjust,
-            draw_transparent=args.draw_transparent)
-
+            model.converter(self.args.swap_model),
+            model=model,
+            arguments=self.args)
         return converter

     def prepare_images(self):
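# The idea behind PluginLoader.get_model/get_converter, in a generic sketch
# (module layout and naming here are illustrative, not faceswap's actual
# loader): resolve a class from a plugin package by name, then instantiate
# it, which is what keeps converters and models swappable from the CLI.
#
#     import importlib
#
#     def get_plugin(package, name, attr):
#         module = importlib.import_module("plugins.{}.{}".format(package, name))
#         return getattr(module, attr)
#
#     # converter_cls = get_plugin("convert", "masked", "Convert")
#     # converter = converter_cls(model.converter(False), model=model, arguments=args)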
@@ -205,25 +186,13 @@ class Convert():

             if not skip:
                 for face in faces:
-                    image = self.convert_one_face(converter, image, face)
+                    image = converter.patch_image(image, face)
                 filename = str(self.output_dir / Path(filename).name)
                 cv2.imwrite(filename, image)  # pylint: disable=no-member
         except Exception as err:
             logger.error("Failed to convert image: '%s'. Reason: %s", filename, err)
             raise

-    def convert_one_face(self, converter, image, face):
-        """ Perform the conversion on the given frame for a single face """
-        # TODO: This switch between 64 and 128 is a hack for now.
-        # We should have a separate cli option for size
-        size = 128 if (self.args.trainer.strip().lower()
-                       in ('gan128', 'originalhighres')) else 64
-
-        image = converter.patch_image(image,
-                                      face,
-                                      size)
-        return image
-

 class OptionalActions():
     """ Process the optional actions for convert """

@@ -305,10 +274,8 @@ class OptionalActions():

 class Legacy():
     """ Update legacy alignments:

-    - Add frame dimensions
     - Rotate landmarks and bounding boxes on legacy alignments
       and remove the 'r' parameter
     - Add face hashes to alignments file
     """
     def __init__(self, alignments, frames, faces_dir):

@@ -319,15 +286,10 @@ class Legacy():

     def process(self, faces_dir):
         """ Run the rotate alignments process """
-        no_dims = self.alignments.get_legacy_no_dims()
         rotated = self.alignments.get_legacy_rotation()
         hashes = self.alignments.get_legacy_no_hashes()
-        if not no_dims and not rotated and not hashes:
+        if not rotated and not hashes:
             return
-        if no_dims:
-            logger.info("Legacy landmarks found. Adding frame dimensions...")
-            self.add_dimensions(no_dims)
-            self.alignments.save()
         if rotated:
             logger.info("Legacy rotated frames found. Converting...")
             self.rotate_landmarks(rotated)

@@ -337,22 +299,14 @@ class Legacy():
             self.add_hashes(hashes, faces_dir)
             self.alignments.save()

-    def add_dimensions(self, no_dims):
-        """ Add width and height of original frame to alignments """
-        for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"):
-            if no_dim not in self.frames.keys():
-                continue
-            filename = self.frames[no_dim]
-            dims = cv2.imread(filename).shape[:2]  # pylint: disable=no-member
-            self.alignments.add_dimensions(no_dim, dims)
-
     def rotate_landmarks(self, rotated):
         """ Rotate the landmarks """
         for rotate_item in tqdm(rotated, desc="Rotating Landmarks"):
-            if rotate_item not in self.frames.keys():
+            frame = self.frames.get(rotate_item, None)
+            if frame is None:
                 logger.debug("Skipping missing frame: '%s'", rotate_item)
                 continue
-            self.alignments.rotate_existing_landmarks(rotate_item)
+            self.alignments.rotate_existing_landmarks(rotate_item, frame)

     def add_hashes(self, hashes, faces_dir):
         """ Add Face Hashes to the alignments file """
@@ -54,7 +54,7 @@ class Extract():
                          self.verify_output)

     def threaded_io(self, task, io_args=None):
-        """ Load images in a background thread """
+        """ Perform I/O task in a background thread """
         logger.debug("Threading task: (Task: '%s')", task)
         io_args = tuple() if io_args is None else (io_args, )
         if task == "load":

@@ -211,7 +211,7 @@ class Extract():

         self.threaded_io("reload", detected_faces)

-    def align_face(self, faces, align_eyes, size, filename, padding=48):
+    def align_face(self, faces, align_eyes, size, filename):
         """ Align the detected face and add the destination file path """
         final_faces = list()
         image = faces["image"]

@@ -221,11 +221,7 @@ class Extract():
             detected_face = DetectedFace()
             detected_face.from_dlib_rect(face, image)
             detected_face.landmarksXY = landmarks[idx]
-            detected_face.frame_dims = image.shape[:2]
-            detected_face.load_aligned(image,
-                                       size=size,
-                                       padding=padding,
-                                       align_eyes=align_eyes)
+            detected_face.load_aligned(image, size=size, align_eyes=align_eyes)
             final_faces.append({"file_location": self.output_dir / Path(filename).stem,
                                 "face": detected_face})
         faces["detected_faces"] = final_faces

@@ -262,7 +258,7 @@ class Plugins():
         logger.debug("Initialized %s", self.__class__.__name__)

     def set_parallel_processing(self):
-        """ Set whether to run detect and align together or seperately """
+        """ Set whether to run detect and align together or separately """
         detector_vram = self.detector.vram
         aligner_vram = self.aligner.vram
         gpu_stats = GPUStats()

@@ -356,11 +352,6 @@ class Plugins():
         kwargs = {"in_queue": queue_manager.get_queue("load"),
                   "out_queue": out_queue}

-        if self.args.detector == "mtcnn":
-            mtcnn_kwargs = self.detector.validate_kwargs(
-                self.get_mtcnn_kwargs())
-            kwargs["mtcnn_kwargs"] = mtcnn_kwargs
-
         mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess
         self.process_detect = mp_func(self.detector.run, **kwargs)

@@ -384,14 +375,6 @@ class Plugins():

         logger.debug("Launched Detector")

-    def get_mtcnn_kwargs(self):
-        """ Add the mtcnn arguments into a kwargs dictionary """
-        mtcnn_threshold = [float(thr.strip())
-                           for thr in self.args.mtcnn_threshold]
-        return {"minsize": self.args.mtcnn_minsize,
-                "threshold": mtcnn_threshold,
-                "factor": self.args.mtcnn_scalefactor}
-
     def detect_faces(self, extract_pass="detect"):
         """ Detect faces in an image """
         logger.debug("Running Detection. Pass: '%s'", extract_pass)

@@ -180,7 +180,7 @@ class Images():

     def load_disk_frames(self):
         """ Load frames from disk """
-        logger.debug("Input is Seperate Frames. Loading images")
+        logger.debug("Input is separate Frames. Loading images")
         for filename in self.input_images:
             logger.trace("Loading image: '%s'", filename)
             try:

@@ -314,9 +314,10 @@ class BlurryFaceFilter(PostProcessAction):  # pylint: disable=too-few-public-met
         aligned_landmarks = face.aligned_landmarks
         resized_face = face.aligned_face
         size = face.aligned["size"]
+        padding = int(size * 0.1875)
         feature_mask = extractor.get_feature_mask(
             aligned_landmarks / size,
-            size, 48)
+            size, padding)
         feature_mask = cv2.blur(  # pylint: disable=no-member
             feature_mask, (10, 10))
         isolated_face = cv2.multiply(  # pylint: disable=no-member
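# Why 0.1875 in the BlurryFaceFilter change above: the old hard-coded padding
# of 48px assumed a 256px aligned face, and 48 / 256 == 0.1875, so the new
# form keeps the same proportion at any aligned size.
#
#     for size in (256, 128, 64):
#         print(size, int(size * 0.1875))  # 256 48 / 128 24 / 64 12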

scripts/gui.py
@@ -1,101 +1,77 @@
 #!/usr/bin python3
 """ The optional GUI for faceswap """

-# NB: The GUI can't currently log as it is a wrapper for the python scripts, so don't
-# implement logging unless you can handle the conflicts
+import logging

 import os
 import sys
 import tkinter as tk

 from tkinter import messagebox, ttk

-from lib.gui import (CliOptions, CurrentSession, CommandNotebook, Config,
-                     ConsoleOut, DisplayNotebook, Images, ProcessWrapper,
-                     StatusBar)
+from lib.gui import (CliOptions, CommandNotebook, ConsoleOut, Session, DisplayNotebook,
+                     get_config, get_images, initialize_images, initialize_config, MainMenuBar,
+                     ProcessWrapper, StatusBar)
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 class FaceswapGui(tk.Tk):
     """ The Graphical User Interface """

     def __init__(self, pathscript):
-        tk.Tk.__init__(self)
-        self.scaling_factor = self.get_scaling()
-
+        logger.debug("Initializing %s", self.__class__.__name__)
+        super().__init__()
+
+        self.initialize_globals(pathscript)
         self.set_geometry()
+        self.wrapper = ProcessWrapper(pathscript)

-        pathcache = os.path.join(pathscript, "lib", "gui", ".cache")
-        self.images = Images(pathcache)
-        self.cliopts = CliOptions()
-        self.session = CurrentSession()
-        statusbar = StatusBar(self)
-        self.wrapper = ProcessWrapper(statusbar,
-                                      self.session,
-                                      pathscript,
-                                      self.cliopts)
-
-        self.images.delete_preview()
+        get_images().delete_preview()
         self.protocol("WM_DELETE_WINDOW", self.close_app)
+        logger.debug("Initialized %s", self.__class__.__name__)
+
+    def initialize_globals(self, pathscript):
+        """ Initialize config and images global constants """
+        cliopts = CliOptions()
+        scaling_factor = self.get_scaling()
+        pathcache = os.path.join(pathscript, "lib", "gui", ".cache")
+        statusbar = StatusBar(self)
+        session = Session()
+        initialize_config(cliopts, scaling_factor, pathcache, statusbar, session)
+        initialize_images()

     def get_scaling(self):
         """ Get the display DPI """
         dpi = self.winfo_fpixels("1i")
-        return dpi / 72.0
+        scaling = dpi / 72.0
+        logger.debug("dpi: %s, scaling: %s", dpi, scaling)
+        return scaling

     def set_geometry(self):
         """ Set GUI geometry """
-        self.tk.call("tk", "scaling", self.scaling_factor)
-        width = int(1200 * self.scaling_factor)
-        height = int(640 * self.scaling_factor)
+        scaling_factor = get_config().scaling_factor
+        self.tk.call("tk", "scaling", scaling_factor)
+        width = int(1200 * scaling_factor)
+        height = int(640 * scaling_factor)
+        logger.debug("Geometry: %sx%s", width, height)
         self.geometry("{}x{}+80+80".format(str(width), str(height)))

     def build_gui(self, debug_console):
         """ Build the GUI """
+        logger.debug("Building GUI")
         self.title("Faceswap.py")
-        self.menu()
+        self.configure(menu=MainMenuBar(self))

         topcontainer, bottomcontainer = self.add_containers()

-        CommandNotebook(topcontainer,
-                        self.cliopts,
-                        self.wrapper.tk_vars,
-                        self.scaling_factor)
-        DisplayNotebook(topcontainer,
-                        self.session,
-                        self.wrapper.tk_vars,
-                        self.scaling_factor)
-        ConsoleOut(bottomcontainer, debug_console, self.wrapper.tk_vars)
-
-    def menu(self):
-        """ Menu bar for loading and saving configs """
-        menubar = tk.Menu(self)
-        filemenu = tk.Menu(menubar, tearoff=0)
-
-        config = Config(self.cliopts, self.wrapper.tk_vars)
-
-        filemenu.add_command(label="Load full config...",
-                             underline=0,
-                             command=config.load)
-        filemenu.add_command(label="Save full config...",
-                             underline=0,
-                             command=config.save)
-        filemenu.add_separator()
-        filemenu.add_command(label="Reset all to default",
-                             underline=0,
-                             command=self.cliopts.reset)
-        filemenu.add_command(label="Clear all",
-                             underline=0,
-                             command=self.cliopts.clear)
-        filemenu.add_separator()
-        filemenu.add_command(label="Quit",
-                             underline=0,
-                             command=self.close_app)
-
-        menubar.add_cascade(label="File", menu=filemenu, underline=0)
-        self.config(menu=menubar)
+        CommandNotebook(topcontainer)
+        DisplayNotebook(topcontainer)
+        ConsoleOut(bottomcontainer, debug_console)
+        logger.debug("Built GUI")

     def add_containers(self):
         """ Add the paned window containers that
         hold each main area of the gui """
+        logger.debug("Adding containers")
         maincontainer = tk.PanedWindow(self,
                                        sashrelief=tk.RAISED,
                                        orient=tk.VERTICAL)
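# The DPI arithmetic behind get_scaling() above: Tk reports pixels per inch,
# 72 points per inch is the baseline, so a 144 DPI display scales by 2.0 and
# the default 1200x640 window becomes 2400x1280.
#
#     dpi = 144.0  # example value; the real figure comes from winfo_fpixels("1i")
#     scaling = dpi / 72.0
#     print(scaling, int(1200 * scaling), int(640 * scaling))  # 2.0 2400 1280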
@@ -109,21 +85,26 @@ class FaceswapGui(tk.Tk):
         bottomcontainer = ttk.Frame(maincontainer, height=150)
         maincontainer.add(bottomcontainer)

+        logger.debug("Added containers")
         return topcontainer, bottomcontainer

     def close_app(self):
         """ Close Python. This is here because the graph
         animation function continues to run even when
         tkinter has gone away """
+        logger.debug("Close Requested")
         confirm = messagebox.askokcancel
         confirmtxt = "Processes are still running. Are you sure...?"
-        if (self.wrapper.tk_vars["runningtask"].get()
+        tk_vars = get_config().tk_vars
+        if (tk_vars["runningtask"].get()
                 and not confirm("Close", confirmtxt)):
+            logger.debug("Close Cancelled")
             return
-        if self.wrapper.tk_vars["runningtask"].get():
+        if tk_vars["runningtask"].get():
             self.wrapper.task.terminate()
-        self.images.delete_preview()
+        get_images().delete_preview()
         self.quit()
+        logger.debug("Closed GUI")
         exit()

scripts/train.py
@@ -4,14 +4,18 @@
 import logging
 import os
 import sys
-import threading
+
+from threading import Lock
+from time import sleep

 import cv2
 import tensorflow as tf
 from keras.backend.tensorflow_backend import set_session

-from lib.utils import (get_folder, get_image_paths, set_system_verbosity,
-                       Timelapse)
+from lib.keypress import KBHit
+from lib.multithreading import MultiThread
+from lib.queue_manager import queue_manager
+from lib.utils import (get_folder, get_image_paths, set_system_verbosity)
 from plugins.plugin_loader import PluginLoader

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -20,38 +24,48 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 class Train():
     """ The training process. """
     def __init__(self, arguments):
+        logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments)
         self.args = arguments
+        self.timelapse = self.set_timelapse()
         self.images = self.get_images()
         self.stop = False
         self.save_now = False
         self.preview_buffer = dict()
-        self.lock = threading.Lock()
+        self.lock = Lock()

-        # this is so that you can enter case insensitive values for trainer
-        trainer_name = self.args.trainer
-        self.trainer_name = trainer_name
-        if trainer_name.lower() == "lowmem":
-            self.trainer_name = "LowMem"
-        self.timelapse = None
+        self.trainer_name = self.args.trainer
+        logger.debug("Initialized %s", self.__class__.__name__)

-    def process(self):
-        """ Call the training process object """
-        logger.info("Training data directory: %s", self.args.model_dir)
-        set_system_verbosity(self.args.loglevel)
-        thread = self.start_thread()
-
-        if self.args.preview:
-            self.monitor_preview()
-        else:
-            self.monitor_console()
-
-        self.end_thread(thread)
+    def set_timelapse(self):
+        """ Set timelapse paths if requested """
+        if (not self.args.timelapse_input_a and
+                not self.args.timelapse_input_b and
+                not self.args.timelapse_output):
+            return None
+        if not self.args.timelapse_input_a or not self.args.timelapse_input_b:
+            raise ValueError("To enable the timelapse, you have to supply "
+                             "all the parameters (--timelapse-input-A and "
+                             "--timelapse-input-B).")
+
+        for folder in (self.args.timelapse_input_a,
+                       self.args.timelapse_input_b,
+                       self.args.timelapse_output):
+            if folder is not None and not os.path.isdir(folder):
+                raise ValueError("The Timelapse path '{}' does not exist".format(folder))
+
+        kwargs = {"input_a": self.args.timelapse_input_a,
+                  "input_b": self.args.timelapse_input_b,
+                  "output": self.args.timelapse_output}
+        logger.debug("Timelapse enabled: %s", kwargs)
+        return kwargs

     def get_images(self):
         """ Check the image dirs exist, contain images and return the image
         objects """
-        images = []
-        for image_dir in [self.args.input_A, self.args.input_B]:
+        logger.debug("Getting image paths")
+        images = dict()
+        for side in ("a", "b"):
+            image_dir = getattr(self.args, "input_{}".format(side))
             if not os.path.isdir(image_dir):
                 logger.error("Error: '%s' does not exist", image_dir)
                 exit(1)
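# The rule set_timelapse() enforces, in isolation: either no timelapse
# arguments at all, or both input sides supplied (paths below are invented).
#
#     def check_timelapse(input_a, input_b, output):
#         if not input_a and not input_b and not output:
#             return None
#         if not input_a or not input_b:
#             raise ValueError("To enable the timelapse, supply both "
#                              "--timelapse-input-A and --timelapse-input-B")
#         return {"input_a": input_a, "input_b": input_b, "output": output}
#
#     print(check_timelapse(None, None, None))                  # None - disabled
#     print(check_timelapse("/data/tl_a", "/data/tl_b", None))  # enabled, default output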
@@ -60,30 +74,60 @@ class Train():
             logger.error("Error: '%s' contains no images", image_dir)
             exit(1)

-            images.append(get_image_paths(image_dir))
-        logger.info("Model A Directory: %s", self.args.input_A)
-        logger.info("Model B Directory: %s", self.args.input_B)
+            images[side] = get_image_paths(image_dir)
+        logger.info("Model A Directory: %s", self.args.input_a)
+        logger.info("Model B Directory: %s", self.args.input_b)
+        logger.debug("Got image paths: %s", [(key, str(len(val)) + " images")
+                                             for key, val in images.items()])
         return images

+    def process(self):
+        """ Call the training process object """
+        logger.debug("Starting Training Process")
+        logger.info("Training data directory: %s", self.args.model_dir)
+        set_system_verbosity(self.args.loglevel)
+        thread = self.start_thread()
+        # queue_manager.debug_monitor(1)
+
+        if self.args.preview:
+            err = self.monitor_preview(thread)
+        else:
+            err = self.monitor_console(thread)
+
+        self.end_thread(thread, err)
+        logger.debug("Completed Training Process")
+
     def start_thread(self):
         """ Put the training process in a thread so we can keep control """
-        thread = threading.Thread(target=self.process_thread)
+        logger.debug("Launching Trainer thread")
+        thread = MultiThread(target=self.training)
         thread.start()
+        logger.debug("Launched Trainer thread")
         return thread

-    def end_thread(self, thread):
+    def end_thread(self, thread, err):
         """ On termination output message and join thread back to main """
-        logger.info("Exit requested! The trainer will complete its current cycle, "
-                    "save the models and quit (it can take up a couple of seconds "
-                    "depending on your training speed). If you want to kill it now, "
-                    "press Ctrl + c")
+        logger.debug("Ending Training thread")
+        if err:
+            msg = "Error caught! Exiting..."
+            log = logger.critical
+        else:
+            msg = ("Exit requested! The trainer will complete its current cycle, "
+                   "save the models and quit (it can take up to a couple of seconds "
+                   "depending on your training speed). If you want to kill it now, "
+                   "press Ctrl + c")
+            log = logger.info
+        log(msg)
         self.stop = True
         thread.join()
         sys.stdout.flush()
+        logger.debug("Ended Training thread")

-    def process_thread(self):
+    def training(self):
         """ The training process to be run inside a thread """
         try:
+            sleep(1)  # Let preview instructions flush out to logger
+            logger.debug("Commencing Training")
             logger.info("Loading data, this may take a while...")

             if self.args.allow_growth:

@@ -91,17 +135,12 @@ class Train():

             model = self.load_model()
             trainer = self.load_trainer(model)

-            self.timelapse = Timelapse.create_timelapse(
-                self.args.timelapse_input_A,
-                self.args.timelapse_input_B,
-                self.args.timelapse_output,
-                trainer)
-
             self.run_training_cycle(model, trainer)
         except KeyboardInterrupt:
             try:
-                model.save_weights()
+                logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting")
+                model.save_models()
+                trainer.clear_tensorboard()
             except KeyboardInterrupt:
                 logger.info("Saving model weights has been cancelled!")
                 exit(0)
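# The monitor/worker contract from the hunk above, sketched generically: the
# training thread records any exception on an attribute, the foreground
# monitor polls it and returns err so end_thread() can pick logger.critical
# over logger.info. MultiThread is faceswap's own wrapper; this stand-in
# shows only the polling pattern.
#
#     import threading
#     import time
#
#     class Worker:
#         def __init__(self, target):
#             self.has_error = False
#             self._target = target
#             self._thread = threading.Thread(target=self._run)
#
#         def _run(self):
#             try:
#                 self._target()
#             except Exception:  # demo only: record the failure for the monitor
#                 self.has_error = True
#
#         def start(self):
#             self._thread.start()
#
#         def is_alive(self):
#             return self._thread.is_alive()
#
#         def join(self):
#             self._thread.join()
#
#     def monitor(worker):
#         while worker.is_alive():
#             if worker.has_error:
#                 return True  # mirrors "err = True; break"
#             time.sleep(0.1)
#         return worker.has_error
#
#     def fail():
#         raise RuntimeError("boom")
#
#     worker = Worker(fail)
#     worker.start()
#     print(monitor(worker))  # True
#     worker.join()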
@ -110,105 +149,192 @@ class Train():
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
""" Load the model requested for training """
|
""" Load the model requested for training """
|
||||||
|
logger.debug("Loading Model")
|
||||||
model_dir = get_folder(self.args.model_dir)
|
model_dir = get_folder(self.args.model_dir)
|
||||||
model = PluginLoader.get_model(self.trainer_name)(model_dir,
|
model = PluginLoader.get_model(self.trainer_name)(
|
||||||
self.args.gpus)
|
model_dir,
|
||||||
|
self.args.gpus,
|
||||||
model.load(swapped=False)
|
no_logs=self.args.no_logs,
|
||||||
|
warp_to_landmarks=self.args.warp_to_landmarks,
|
||||||
|
no_flip=self.args.no_flip,
|
||||||
|
training_image_size=self.image_size,
|
||||||
|
alignments_paths=self.alignments_paths,
|
||||||
|
preview_scale=self.args.preview_scale)
|
||||||
|
logger.debug("Loaded Model")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
""" Get the training set image size for storing in model data """
|
||||||
|
image = cv2.imread(self.images["a"][0]) # pylint: disable=no-member
|
||||||
|
size = image.shape[0]
|
||||||
|
logger.debug("Training image size: %s", size)
|
||||||
|
return size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alignments_paths(self):
|
||||||
|
""" Set the alignments path to input dirs if not provided """
|
||||||
|
alignments_paths = dict()
|
||||||
|
for side in ("a", "b"):
|
||||||
|
alignments_path = getattr(self.args, "alignments_path_{}".format(side))
|
||||||
|
if not alignments_path:
|
||||||
|
image_path = getattr(self.args, "input_{}".format(side))
|
||||||
|
alignments_path = os.path.join(image_path, "alignments.json")
|
||||||
|
alignments_paths[side] = alignments_path
|
||||||
|
logger.debug("Alignments paths: %s", alignments_paths)
|
||||||
|
return alignments_paths
|
||||||
|
|
||||||
def load_trainer(self, model):
|
def load_trainer(self, model):
|
||||||
""" Load the trainer requested for training """
|
""" Load the trainer requested for training """
|
||||||
images_a, images_b = self.images
|
logger.debug("Loading Trainer")
|
||||||
|
trainer = PluginLoader.get_trainer(model.trainer)
|
||||||
trainer = PluginLoader.get_trainer(self.trainer_name)
|
|
||||||
trainer = trainer(model,
|
trainer = trainer(model,
|
||||||
images_a,
|
self.images,
|
||||||
images_b,
|
self.args.batch_size)
|
||||||
self.args.batch_size,
|
logger.debug("Loaded Trainer")
|
||||||
self.args.perceptual_loss)
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def run_training_cycle(self, model, trainer):
|
def run_training_cycle(self, model, trainer):
|
||||||
""" Perform the training cycle """
|
""" Perform the training cycle """
|
||||||
|
logger.debug("Running Training Cycle")
|
||||||
|
if self.args.write_image or self.args.redirect_gui or self.args.preview:
|
||||||
|
display_func = self.show
|
||||||
|
else:
|
||||||
|
display_func = None
|
||||||
|
|
||||||
for iteration in range(0, self.args.iterations):
|
for iteration in range(0, self.args.iterations):
|
||||||
|
logger.trace("Training iteration: %s", iteration)
|
||||||
save_iteration = iteration % self.args.save_interval == 0
|
save_iteration = iteration % self.args.save_interval == 0
|
||||||
viewer = self.show if save_iteration or self.save_now else None
|
viewer = display_func if save_iteration or self.save_now else None
|
||||||
if save_iteration and self.timelapse is not None:
|
timelapse = self.timelapse if save_iteration else None
|
||||||
self.timelapse.work()
|
trainer.train_one_step(viewer, timelapse)
|
||||||
trainer.train_one_step(iteration, viewer)
|
|
||||||
if self.stop:
|
if self.stop:
|
||||||
|
logger.debug("Stop received. Terminating")
|
||||||
break
|
break
|
||||||
elif save_iteration:
|
elif save_iteration:
|
||||||
model.save_weights()
|
logger.trace("Save Iteration: (iteration: %s", iteration)
|
||||||
|
model.save_models()
|
||||||
elif self.save_now:
|
elif self.save_now:
|
||||||
model.save_weights()
|
logger.trace("Save Requested: (iteration: %s", iteration)
|
||||||
|
model.save_models()
|
||||||
self.save_now = False
|
self.save_now = False
|
||||||
model.save_weights()
|
logger.debug("Training cycle complete")
|
||||||
|
model.save_models()
|
||||||
|
trainer.clear_tensorboard()
|
||||||
self.stop = True
|
self.stop = True
|
||||||
|
|
||||||
def monitor_preview(self):
|
def monitor_preview(self, thread):
|
||||||
""" Generate the preview window and wait for keyboard input """
|
""" Generate the preview window and wait for keyboard input """
|
||||||
logger.info("Using live preview.\n"
|
logger.debug("Launching Preview Monitor")
|
||||||
"Press 'ENTER' on the preview window to save and quit.\n"
|
logger.info("R|=====================================================================")
|
||||||
"Press 'S' on the preview window to save model weights "
|
logger.info("R|- Using live preview -")
|
||||||
"immediately")
|
logger.info("R|- Press 'ENTER' on the preview window to save and quit -")
|
||||||
|
logger.info("R|- Press 'S' on the preview window to save model weights immediately -")
|
||||||
|
logger.info("R|=====================================================================")
|
||||||
|
err = False
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
for name, image in self.preview_buffer.items():
|
for name, image in self.preview_buffer.items():
|
||||||
cv2.imshow(name, image)
|
cv2.imshow(name, image) # pylint: disable=no-member
|
||||||
|
|
||||||
key = cv2.waitKey(1000)
|
key = cv2.waitKey(1000) # pylint: disable=no-member
|
||||||
|
if self.stop:
|
||||||
|
logger.debug("Stop received")
|
||||||
|
break
|
||||||
|
if thread.has_error:
|
||||||
|
logger.debug("Thread error detected")
|
||||||
|
err = True
|
||||||
|
break
|
||||||
if key == ord("\n") or key == ord("\r"):
|
if key == ord("\n") or key == ord("\r"):
|
||||||
|
logger.debug("Exit requested")
|
||||||
break
|
break
|
||||||
if key == ord("s"):
|
if key == ord("s"):
|
||||||
|
logger.info("Save requested")
|
||||||
self.save_now = True
|
self.save_now = True
|
||||||
if self.stop:
|
|
||||||
break
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
logger.debug("Keyboard Interrupt received")
|
||||||
break
|
break
|
||||||
|
logger.debug("Closed Preview Monitor")
|
||||||
|
return err
|
||||||
|
|
||||||
+    def monitor_console(self, thread):
+        """ Monitor the console
+            NB: A custom function needs to be used for this because
+            input() blocks """
+        logger.debug("Launching Console Monitor")
+        logger.info("R|===============================================")
+        logger.info("R|- Starting -")
+        logger.info("R|- Press 'ENTER' to save and quit -")
+        logger.info("R|- Press 'S' to save model weights immediately -")
+        logger.info("R|===============================================")
+        keypress = KBHit(is_gui=self.args.redirect_gui)
+        err = False
+        while True:
+            try:
+                if thread.has_error:
+                    logger.debug("Thread error detected")
+                    err = True
+                    break
+                if self.stop:
+                    logger.debug("Stop received")
+                    break
+                if keypress.kbhit():
+                    key = keypress.getch()
+                    if key in ("\n", "\r"):
+                        logger.debug("Exit requested")
+                        break
+                    if key in ("s", "S"):
+                        logger.info("Save requested")
+                        self.save_now = True
+            except KeyboardInterrupt:
+                logger.debug("Keyboard Interrupt received")
+                break
+        keypress.set_normal_term()
+        logger.debug("Closed Console Monitor")
+        return err

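monitor_console needs non-blocking keyboard reads because input() would hang the loop until ENTER arrives. On POSIX, a KBHit-style helper boils down to putting the terminal in cbreak mode and probing stdin with a zero-timeout select(); a rough sketch of that idea (the project's real helper also has to cover Windows and GUI-redirected stdin):

    import atexit
    import select
    import sys
    import termios
    import tty

    class KBHitSketch():
        """ Non-blocking single-key reads on a POSIX terminal """
        def __init__(self):
            self.old_term = termios.tcgetattr(sys.stdin.fileno())
            tty.setcbreak(sys.stdin.fileno())  # Disable line buffering
            atexit.register(self.set_normal_term)

        def set_normal_term(self):
            """ Restore the saved terminal settings """
            termios.tcsetattr(sys.stdin.fileno(), termios.TCSAFLUSH, self.old_term)

        @staticmethod
        def kbhit():
            """ True if a keypress is waiting; never blocks """
            ready, _, _ = select.select([sys.stdin], [], [], 0)
            return bool(ready)

        @staticmethod
        def getch():
            """ Read one pending character """
            return sys.stdin.read(1)
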
     @staticmethod
-    def monitor_console():
-        """ Monitor the console for any input followed by enter or ctrl+c """
-        # TODO: how to catch a specific key instead of Enter?
-        # there isn't a good multiplatform solution:
-        # https://stackoverflow.com/questions/3523174
-        # TODO: Find a way to interrupt input() if the target iterations are
-        # reached. At the moment, setting a target iteration and using the -p
-        # flag is the only guaranteed way to exit the training loop on
-        # hitting target iterations.
-        logger.info("Starting. Press 'ENTER' to stop training and save model")
-        try:
-            input()
-        except KeyboardInterrupt:
-            pass
+    def keypress_monitor(keypress_queue):
+        """ Monitor stdin for keypress """
+        while True:
+            keypress_queue.put(sys.stdin.read(1))

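keypress_monitor exists for the case where stdin is a pipe rather than a terminal (the GUI redirect): a daemon thread does the blocking sys.stdin.read(1) and feeds a queue that the main loop can poll without ever blocking. A usage sketch with hypothetical wiring:

    import queue
    import sys
    import threading

    def keypress_monitor(keypress_queue):
        """ Blocking reads happen here, off the main thread """
        while True:
            keypress_queue.put(sys.stdin.read(1))

    keys = queue.Queue()
    threading.Thread(target=keypress_monitor, args=(keys, ), daemon=True).start()
    try:
        key = keys.get(timeout=1)  # Main loop polls; no blocking on stdin
    except queue.Empty:
        key = None
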
     @staticmethod
     def set_tf_allow_growth():
         """ Allow TensorFlow to manage VRAM growth """
+        # pylint: disable=no-member
+        logger.debug("Setting Tensorflow 'allow_growth' option")
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
         config.gpu_options.visible_device_list = "0"
         set_session(tf.Session(config=config))
+        logger.debug("Set Tensorflow 'allow_growth' option")

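One ordering caveat with set_tf_allow_growth (this is the TF 1.x session API): it should run before any Keras model is built, so that the session carrying the ConfigProto options is the one Keras actually uses. A minimal call sequence, assuming standalone Keras:

    import tensorflow as tf
    from keras.backend.tensorflow_backend import set_session

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Claim VRAM on demand, not all upfront
    set_session(tf.Session(config=config))
    # ... only now build and train the Keras models ...
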
     def show(self, image, name=""):
         """ Generate the preview and write preview file output """
+        logger.trace("Updating preview: (name: %s)", name)
         try:
             scriptpath = os.path.realpath(os.path.dirname(sys.argv[0]))
             if self.args.write_image:
-                img = "_sample_{}.jpg".format(name)
+                logger.trace("Saving preview to disk")
+                img = "training_preview.jpg"
                 imgfile = os.path.join(scriptpath, img)
-                cv2.imwrite(imgfile, image)
+                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
+                logger.trace("Saved preview to: '%s'", img)
             if self.args.redirect_gui:
-                img = ".gui_preview_{}.jpg".format(name)
+                logger.trace("Generating preview for GUI")
+                img = ".gui_training_preview.jpg"
                 imgfile = os.path.join(scriptpath, "lib", "gui",
                                        ".cache", "preview", img)
-                cv2.imwrite(imgfile, image)
+                cv2.imwrite(imgfile, image)  # pylint: disable=no-member
+                logger.trace("Generated preview for GUI: '%s'", img)
             if self.args.preview:
+                logger.trace("Generating preview for display: '%s'", name)
                 with self.lock:
                     self.preview_buffer[name] = image
+                logger.trace("Generated preview for display: '%s'", name)
         except Exception as err:
             logging.error("could not preview sample")
             raise err
+        logger.trace("Updated preview: (name: %s)", name)

tools/cli.py (21 changes)
@@ -2,7 +2,7 @@
""" Command Line Arguments for tools """
|
""" Command Line Arguments for tools """
|
||||||
from lib.cli import FaceSwapArgs
|
from lib.cli import FaceSwapArgs
|
||||||
from lib.cli import (ContextFullPaths, DirFullPaths,
|
from lib.cli import (ContextFullPaths, DirFullPaths,
|
||||||
FileFullPaths, SaveFileFullPaths)
|
FileFullPaths, SaveFileFullPaths, Slider)
|
||||||
from lib.utils import _image_extensions
|
from lib.utils import _image_extensions
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,9 +47,9 @@ class AlignmentsArgs(FaceSwapArgs):
|
||||||
"\n\tfile." + output_opts + frames_dir +
|
"\n\tfile." + output_opts + frames_dir +
|
||||||
"\n'missing-frames': Identify frames in the alignments file that do no "
|
"\n'missing-frames': Identify frames in the alignments file that do no "
|
||||||
"\n\tappear within the frames folder/video." + output_opts + frames_dir +
|
"\n\tappear within the frames folder/video." + output_opts + frames_dir +
|
||||||
"\n'legacy': This updates legacy alignments to the latest format by adding"
|
"\n'legacy': This updates legacy alignments to the latest format by rotating"
|
||||||
"\n\tframe dimensions, rotating the landmarks and bounding boxes and adding"
|
"\n\tthe landmarks and bounding boxes and adding face_hashes." +
|
||||||
"\n\tface_hashes" + frames_and_faces_dir +
|
frames_and_faces_dir +
|
||||||
"\n'leftover-faces': Identify faces in the faces folder that do not exist in"
|
"\n'leftover-faces': Identify faces in the faces folder that do not exist in"
|
||||||
"\n\tthe alignments file." + output_opts + faces_dir +
|
"\n\tthe alignments file." + output_opts + faces_dir +
|
||||||
"\n'multi-faces': Identify where multiple faces exist within the alignments"
|
"\n'multi-faces': Identify where multiple faces exist within the alignments"
|
||||||
|
@ -123,6 +123,13 @@ class AlignmentsArgs(FaceSwapArgs):
|
||||||
"\n\tdirectory)."
|
"\n\tdirectory)."
|
||||||
"\n'move': Move the discovered items to a sub-folder within the source"
|
"\n'move': Move the discovered items to a sub-folder within the source"
|
||||||
"\n\tdirectory."})
|
"\n\tdirectory."})
|
||||||
|
argument_list.append({"opts": ("-sz", "--size"),
|
||||||
|
"type": int,
|
||||||
|
"action": Slider,
|
||||||
|
"min_max": (128, 512),
|
||||||
|
"default": 256,
|
||||||
|
"rounding": 64,
|
||||||
|
"help": "The output size of extracted faces. (extract only)"})
|
||||||
argument_list.append({"opts": ("-ae", "--align-eyes"),
|
argument_list.append({"opts": ("-ae", "--align-eyes"),
|
||||||
"action": "store_true",
|
"action": "store_true",
|
||||||
"dest": "align_eyes",
|
"dest": "align_eyes",
|
||||||
@@ -409,6 +416,9 @@ class SortArgs(FaceSwapArgs):
             "Default: hist"})

         argument_list.append({"opts": ('-t', '--ref_threshold'),
+                              "action": Slider,
+                              "min_max": (-1.0, 10.0),
+                              "rounding": 2,
                               "type": float,
                               "dest": 'min_threshold',
                               "default": -1.0,
@@ -433,6 +443,9 @@ class SortArgs(FaceSwapArgs):
             "hist 0.3"})

         argument_list.append({"opts": ('-b', '--bins'),
+                              "action": Slider,
+                              "min_max": (1, 100),
+                              "rounding": 1,
                               "type": int,
                               "dest": 'num_bins',
                               "default": 5,

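The Slider action used for --size, --ref_threshold and --bins behaves as a plain store on the command line; its min_max and rounding keywords are metadata the GUI reads to render a slider. A hypothetical sketch of such an argparse action (the real class lives in lib.cli and is not shown in this diff):

    import argparse

    class Slider(argparse.Action):
        """ Stores the value; carries slider metadata for a GUI front-end """
        def __init__(self, option_strings, dest, min_max=None, rounding=None, **kwargs):
            self.min_max = min_max    # GUI slider bounds; argparse ignores them
            self.rounding = rounding  # GUI step size or decimal places
            super().__init__(option_strings, dest, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, values)

    parser = argparse.ArgumentParser()
    parser.add_argument("-sz", "--size", type=int, action=Slider,
                        min_max=(128, 512), rounding=64, default=256)
    print(parser.parse_args(["--size", "320"]).size)  # 320
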
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-""" Tools for manipulating the alignments seralized file """
+""" Tools for manipulating the alignments serialized file """

 import logging
 import os
@@ -32,7 +32,7 @@ class Check():
         logger.debug("Initialized %s", self.__class__.__name__)

     def get_source_dir(self, arguments):
-        """ Set the correct source dir """
+        """ Set the correct source folder """
         if hasattr(arguments, "faces_dir") and arguments.faces_dir:
             self.type = "faces"
             source_dir = arguments.faces_dir
@@ -195,7 +195,7 @@ class Check():
         os.rename(src, dst)

     def move_faces(self, output_folder, items_output):
-        """ Make additional subdirs for each face that appears
+        """ Make additional subfolders for each face that appears
             Enables easier manual sorting """
         logger.info("Moving %s face(s) to '%s'", len(items_output), output_folder)
         for frame, idx in items_output:
@@ -239,7 +239,7 @@ class Draw():
         legacy.process()

         logger.info("[DRAW LANDMARKS]")  # Tidy up cli output
-        self.extracted_faces = ExtractedFaces(self.frames, self.alignments,
+        self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256,
                                               align_eyes=self.arguments.align_eyes)
         frames_drawn = 0
         for frame in tqdm(self.frames.file_list_sorted, desc="Drawing landmarks"):
@@ -281,7 +281,7 @@ class Extract():
         self.type = arguments.job.replace("extract-", "")
         self.faces_dir = arguments.faces_dir
         self.frames = Frames(arguments.frames_dir)
-        self.extracted_faces = ExtractedFaces(self.frames, self.alignments,
+        self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=arguments.size,
                                               align_eyes=arguments.align_eyes)
         logger.debug("Initialized %s", self.__class__.__name__)

@@ -352,8 +352,7 @@ class Extract():
             valid_faces = faces
         else:
             sizes = self.extracted_faces.get_roi_size_for_frame(frame)
-            valid_faces = [faces[idx]
-                           for idx, size in enumerate(sizes)
+            valid_faces = [faces[idx] for idx, size in enumerate(sizes)
                            if size >= self.extracted_faces.size]
         logger.trace("frame: '%s', total_faces: %s, valid_faces: %s",
                      frame, len(faces), len(valid_faces))
@@ -362,8 +361,6 @@

 class Legacy():
     """ Update legacy alignments:

-        - Add frame dimensions
         - Rotate landmarks and bounding boxes on legacy alignments
           and remove the 'r' parameter
         - Add face hashes to alignments file
@@ -383,16 +380,11 @@ class Legacy():

     def process(self):
         """ Run the rotate alignments process """
-        no_dims = self.alignments.get_legacy_no_dims()
         rotated = self.alignments.get_legacy_rotation()
         hashes = self.alignments.get_legacy_no_hashes()
-        if (not self.frames or (not no_dims and not rotated)) and (not self.faces or not hashes):
+        if (not self.frames or not rotated) and (not self.faces or not hashes):
             return
         logger.info("[UPDATE LEGACY LANDMARKS]")  # Tidy up cli output
-        if no_dims and self.frames:
-            logger.info("Legacy landmarks found. Adding frame dimensions...")
-            self.add_dimensions(no_dims)
-            self.alignments.save()
         if rotated and self.frames:
             logger.info("Legacy rotated frames found. Converting...")
             self.rotate_landmarks(rotated)
@@ -402,20 +394,13 @@ class Legacy():
         self.add_hashes(hashes)
         self.alignments.save()

-    def add_dimensions(self, no_dims):
-        """ Add width and height of original frame to alignments """
-        for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"):
-            if no_dim not in self.frames.items.keys():
-                continue
-            dims = self.frames.load_image(no_dim).shape[:2]
-            self.alignments.add_dimensions(no_dim, dims)

     def rotate_landmarks(self, rotated):
         """ Rotate the landmarks """
         for rotate_item in tqdm(rotated, desc="Rotating Landmarks"):
-            if rotate_item not in self.frames.items.keys():
+            frame = self.frames.get(rotate_item, None)
+            if frame is None:
                 continue
-            self.alignments.rotate_existing_landmarks(rotate_item)
+            self.alignments.rotate_existing_landmarks(rotate_item, frame)

     def add_hashes(self, hashes):
         """ Add Face Hashes to the alignments file """
@@ -838,19 +823,19 @@ class Spatial():
             "alignments -j extract -a %s -fr <path_to_frames_dir> -fc "
             "<output_folder>", self.arguments.alignments_file)

-    # define shape normalization utility functions
+    # Define shape normalization utility functions
     @staticmethod
     def normalize_shapes(shapes_im_coords):
         """ Normalize a 2D or 3D shape """
         logger.debug("Normalize shapes")
         (num_pts, num_dims, _) = shapes_im_coords.shape

-        # calc mean coords and subtract from shapes
+        # Calculate mean coordinates and subtract from shapes
         mean_coords = shapes_im_coords.mean(axis=0)
         shapes_centered = np.zeros(shapes_im_coords.shape)
         shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1])

-        # calc scale factors and divide shapes
+        # Calculate scale factors and divide shapes
         scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0)
         shapes_normalized = np.zeros(shapes_centered.shape)
         shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1])
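normalize_shapes centres each frame's landmark set on its centroid and divides by the mean point radius, so shapes become comparable across frames regardless of face position and scale. A worked toy example on a (num_points, num_dims, num_frames) array, mirroring the arithmetic above:

    import numpy as np

    shapes = np.random.rand(68, 2, 10) * 100  # 68 2D landmarks over 10 frames
    mean_coords = shapes.mean(axis=0)  # (2, 10): per-frame centroid
    centered = shapes - np.tile(mean_coords, [68, 1, 1])
    scale_factors = np.sqrt((centered ** 2).sum(axis=1)).mean(axis=0)  # (10, )
    normalized = centered / np.tile(scale_factors, [68, 2, 1])

    assert np.allclose(normalized.mean(axis=0), 0)  # Zero centroid per frame
    assert np.allclose(np.sqrt((normalized ** 2).sum(axis=1)).mean(axis=0), 1)
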
@@ -889,12 +874,12 @@ class Spatial():
             landmarks = np.array(val[0]["landmarksXY"]).reshape(68, 2, 1)
             start = end
             end = start + landmarks.shape[2]
-            # store in one big array
+            # Store in one big array
             landmarks_all[:, :, start:end] = landmarks
-            # make sure we keep track of the mapping to the original frame
+            # Make sure we keep track of the mapping to the original frame
             self.mappings[start] = key

-        # normalize shapes
+        # Normalize shapes
         normalized_shape = self.normalize_shapes(landmarks_all)
         self.normalized["landmarks"] = normalized_shape[0]
         self.normalized["scale_factors"] = normalized_shape[1]
@@ -920,15 +905,15 @@ class Spatial():
             (project and reconstruct) """
         logger.debug("Spatially Filter")
         landmarks_norm = self.normalized["landmarks"]
-        # convert to matrix form
+        # Convert to matrix form
         landmarks_norm_table = np.reshape(landmarks_norm, [68 * 2, landmarks_norm.shape[2]]).T
-        # project onto shapes model and reconstruct
+        # Project onto shapes model and reconstruct
         landmarks_norm_table_rec = self.shapes_model.inverse_transform(
             self.shapes_model.transform(landmarks_norm_table))
-        # convert back to shapes (numKeypoint, num_dims, numFrames)
+        # Convert back to shapes (numKeypoint, num_dims, numFrames)
         landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T,
                                         [68, 2, landmarks_norm.shape[2]])
-        # transform back to image coords
+        # Transform back to image coords
         retval = self.normalized_to_original(landmarks_norm_rec,
                                              self.normalized["scale_factors"],
                                              self.normalized["mean_coords"])
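The spatial filter is a classic project-and-reconstruct denoiser: flatten each frame's 68x2 landmarks into a 136-vector, project onto a low-rank linear shapes model, and map back, discarding the jitter the model cannot express. A minimal equivalent with scikit-learn's PCA standing in for self.shapes_model (an assumption; this diff does not show how the model is fitted):

    import numpy as np
    from sklearn.decomposition import PCA

    num_frames = 200
    landmarks = np.random.rand(68, 2, num_frames)  # Toy landmark track
    table = np.reshape(landmarks, [68 * 2, num_frames]).T  # (frames, 136)

    shapes_model = PCA(n_components=20).fit(table)  # Keep 20 shape modes
    table_rec = shapes_model.inverse_transform(shapes_model.transform(table))
    smoothed = np.reshape(table_rec.T, [68, 2, num_frames])
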
@@ -294,7 +294,7 @@ class Interface():
     def get_state_color(self):
         """ Return a color based on current state
             white - View Mode
-            yellow - Edit Mide
+            yellow - Edit Mode
             red - Unsaved alignments """
         color = (255, 255, 255)
         if self.state["edit"]["updated"]:
@@ -446,7 +446,7 @@ class Manual():
         legacy.process()

         logger.info("[MANUAL PROCESSING]")  # Tidy up cli output
-        self.extracted_faces = ExtractedFaces(self.frames, self.alignments,
+        self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256,
                                               align_eyes=self.align_eyes)
         self.interface = Interface(self.alignments, self.frames)
         self.help = Help(self.interface)
@@ -510,8 +510,8 @@ class Manual():
             MS Windows doesn't appear to read the window state property
             properly, so we check for a negative key press.

-            Conda (tested on Windows) doesn't sppear to read the window
-            state property or negative key press properly, so we arbitarily
+            Conda (tested on Windows) doesn't appear to read the window
+            state property or negative key press properly, so we arbitrarily
             use another property """
         # pylint: disable=no-member
         logger.trace("Commencing closed window check")
@@ -790,7 +790,7 @@ class MouseHandler():
         a_event = align_process.event
         align_process.start()

-        # Wait for Aligner to take init
+        # Wait for Aligner to initialize
         # The first ever load of the model for FAN has reportedly taken
         # up to 3-4 minutes, hence high timeout.
         a_event.wait(300)
@@ -977,7 +977,8 @@ class MouseHandler():
         self.interface.state["edit"]["updated"] = True
         self.interface.state["edit"]["update_faces"] = True

-    def extracted_to_alignment(self, extract_data):
+    @staticmethod
+    def extracted_to_alignment(extract_data):
         """ Convert Extracted Tuple to Alignments data """
         alignment = dict()
         d_rect, landmarks = extract_data
@@ -985,6 +986,5 @@ class MouseHandler():
         alignment["w"] = d_rect.right() - d_rect.left()
         alignment["y"] = d_rect.top()
         alignment["h"] = d_rect.bottom() - d_rect.top()
-        alignment["frame_dims"] = self.media["image"].shape[:2]
         alignment["landmarksXY"] = landmarks
         return alignment
@@ -53,7 +53,7 @@ class AlignmentData(Alignments):
         self.set_destination_format(destination_format)

     def set_destination_format(self, destination_format):
-        """ Standardise the destination format to the correct extension """
+        """ Standardize the destination format to the correct extension """
         extensions = {".json": "json",
                       ".p": "pickle",
                       ".yml": "yaml",
@@ -274,12 +274,11 @@ class Frames(MediaLoader):
 class ExtractedFaces():
     """ Holds the extracted faces and matrix for
         alignments """
-    def __init__(self, frames, alignments, size=256,
-                 padding=48, align_eyes=False):
-        logger.trace("Initializing %s: (size: %s, padding: %s, align_eyes: %s)",
-                     self.__class__.__name__, size, padding, align_eyes)
+    def __init__(self, frames, alignments, size=256, align_eyes=False):
+        logger.trace("Initializing %s: (size: %s, align_eyes: %s)",
+                     self.__class__.__name__, size, align_eyes)
         self.size = size
-        self.padding = padding
+        self.padding = int(size * 0.1875)
         self.align_eyes = align_eyes
         self.alignments = alignments
         self.frames = frames
@@ -309,10 +308,7 @@ class ExtractedFaces():
                      self.current_frame, alignment)
         face = DetectedFace()
         face.from_alignment(alignment, image=image)
-        face.load_aligned(image,
-                          size=self.size,
-                          padding=self.padding,
-                          align_eyes=self.align_eyes)
+        face.load_aligned(image, size=self.size, align_eyes=self.align_eyes)
         return face

     def get_faces_in_frame(self, frame, update=False):
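The padding change above is worth spelling out: the old fixed padding=48 only matched the old fixed 256px extract size, and 48 / 256 == 0.1875, so deriving padding from size keeps the same framing at every output size:

    for size in (128, 256, 512):
        print(size, int(size * 0.1875))  # -> 24, 48 and 96 pixels respectively
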
@@ -47,7 +47,7 @@ class Sort():

         # Assigning default threshold values based on grouping method
         if (self.args.final_process == "folders"
-                and self.args.min_threshold == -1.0):
+                and self.args.min_threshold < 0.0):
             method = self.args.group_method.lower()
             if method == 'face':
                 self.args.min_threshold = 0.6
@@ -767,9 +767,9 @@ class Sort():
             Normalize by pixel number to offset the effect
             of image size on pixel gradients & variance
         """
-        image = cv2.imread(image_file,cv2.IMREAD_GRAYSCALE)
+        image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
         blur_map = cv2.Laplacian(image, cv2.CV_32F)
         score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1])
         return score

     @staticmethod
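The blur estimator above is the variance-of-Laplacian sharpness measure, with the sqrt(width * height) divisor keeping large images from out-scoring small ones on raw variance alone. Usage sketch (hypothetical file name; higher scores mean sharper):

    import cv2
    import numpy as np

    image = cv2.imread("face.png", cv2.IMREAD_GRAYSCALE)
    blur_map = cv2.Laplacian(image, cv2.CV_32F)  # Edge response map
    score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1])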