mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/lib/training_data.py
torzdf cd00859c40
model_refactor (#571) (#572)
* model_refactor (#571)

* original model to new structure

* IAE model to new structure

* OriginalHiRes to new structure

* Fix trainer for different resolutions

* Initial config implementation

* Configparse library added

* improved training data loader

* dfaker model working

* Add logging to training functions

* Non blocking input for cli training

* Add error handling to threads. Add non-mp queues to queue_handler

* Improved Model Building and NNMeta

* refactor lib/models

* training refactor. DFL H128 model Implementation

* Dfaker - use hashes

* Move timelapse. Remove perceptual loss arg

* Update INSTALL.md. Add logger formatting. Update Dfaker training

* DFL h128 partially ported

* Add mask to dfaker (#573)

* Remove old models. Add mask to dfaker

* dfl mask. Make masks selectable in config (#575)

* DFL H128 Mask. Mask type selectable in config.

* remove gan_v2_2

* Creating Input Size config for models

Creating Input Size config for models

Will be used downstream in converters.

Also name change of image_shape to input_shape to clarify ( for future models with potentially different output_shapes)

* Add mask loss options to config

* MTCNN options to config.ini. Remove GAN config. Update USAGE.md

* Add sliders for numerical values in GUI

* Add config plugins menu to gui. Validate config

* Only backup model if loss has dropped. Get training working again

* bugfixes

* Standardise loss printing

* GUI idle cpu fixes. Graph loss fix.

* multi-gpu logging bugfix

* Merge branch 'staging' into train_refactor

* backup state file

* Crash protection: Only backup if both total losses have dropped

* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)

* Load and save model structure with weights

* Slight code update

* Improve config loader. Add subpixel opt to all models. Config to state

* Show samples... wrong input

* Remove AE topology. Add input/output shapes to State

* Port original_villain (birb/VillainGuy) model to faceswap

* Add plugin info to GUI config pages

* Load input shape from state. IAE Config options.

* Fix transform_kwargs.
Coverage to ratio.
Bugfix mask detection

* Suppress keras userwarnings.
Automate zoom.
Coverage_ratio to model def.

* Consolidation of converters & refactor (#574)

* Consolidation of converters & refactor

Initial Upload of alpha

Items
- consolidate convert_masked & convert_adjust into one converter
- add average color adjust to convert_masked
- allow mask transition blur size to be a fixed integer of pixels and a fraction of the facial mask size
- allow erosion/dilation size to be a fixed integer of pixels and a fraction of the facial mask size
- eliminate redundant type conversions to avoid multiple round-off errors
- refactor loops for vectorization/speed
- reorganize for clarity & style changes

TODO
- bug/issues with warping the new face onto a transparent old image...use a cleanup mask for now
- issues with mask border giving black ring at zero erosion .. investigate
- remove GAN ??
- test enlargement factors of umeyama standard face .. match to coverage factor
- make enlargement factor a model parameter
- remove convert_adjusted and referencing code when finished

* Update Convert_Masked.py

default blur size of 2 to match original...
description of enlargement tests
breakout matrix scaling into def

* Enlargement scale as a cli parameter

* Update cli.py

* dynamic interpolation algorithm

Compute x & y scale factors from the affine matrix on the fly by QR decomp.
Choose interpolation algorithm for the affine warp based on an upsample or downsample for each image

* input size
input size from config

* fix issues with <1.0 erosion

* Update convert.py

* Update Convert_Adjust.py

more work on the way to merging

* Clean up help note on sharpen

* cleanup seamless

* Delete Convert_Adjust.py

* Update umeyama.py

* Update training_data.py

* swapping

* segmentation stub

* changes to convert.str

* Update masked.py

* Backwards compatibility fix for models
Get converter running

* Convert:
Move masks to class.
bugfix blur_size
some linting

* mask fix

* convert fixes

- missing facehull_rect re-added
- coverage to %
- corrected coverage logic
- cleanup of gui option ordering

* Update cli.py

* default for blur

* Update masked.py

* added preliminary low_mem version of OriginalHighRes model plugin

* Code cleanup, minor fixes

* Update masked.py

* Update masked.py

* Add dfl mask to convert

* histogram fix & seamless location

* update

* revert

* bugfix: Load actual configuration in gui

* Standardize nn_blocks

* Update cli.py

* Minor code amends

* Fix Original HiRes model

* Add masks to preview output for mask trainers
refactor trainer.__base.py

* Masked trainers converter support

* convert bugfix

* Bugfix: Converter for masked (dfl/dfaker) trainers

* Additional Losses (#592)

* initial upload

* Delete blur.py

* default initializer = He instead of Glorot (#588)

* Allow kernel_initializer to be overridable

* Add ICNR Initializer option for upscale on all models.

* Hopefully fixes RSoDs with original-highres model plugin

* remove debug line

* Original-HighRes model plugin Red Screen of Death fix, take #2

* Move global options to _base. Rename Villain model

* clipnorm and res block biases

* scale the end of res block

* res block

* dfaker pre-activation res

* OHRES pre-activation

* villain pre-activation

* tabs/space in nn_blocks

* fix for histogram with mask all set to zero

* fix to prevent two networks with same name

* GUI: Wider tooltips. Improve TQDM capture

* Fix regex bug

* Convert padding=48 to ratio of image size

* Add size option to alignments tool extract

* Pass through training image size to convert from model

* Convert: Pull training coverage from model

* convert: coverage, blur and erode to percent

* simplify matrix scaling

* ordering of sliders in train

* Add matrix scaling to utils. Use interpolation in lib.aligner transform

* masked.py Import get_matrix_scaling from utils

* fix circular import

* Update masked.py

* quick fix for matrix scaling

* testing thus for now

* tqdm regex capture bugfix

* Minor amends

* blur size cleanup

* Remove coverage option from convert (Now cascades from model)

* Implement convert for all model types

* Add mask option and coverage option to all existing models

* bugfix for model loading on convert

* debug print removal

* Bugfix for masks in dfl_h128 and iae

* Update preview display. Add preview scaling to cli

* mask notes

* Delete training_data_v2.py

errant file

* training data variables

* Fix timelapse function

* Add new config items to state file for legacy purposes

* Slight GUI tweak

* Raise exception if problem with loaded model

* Add Tensorboard support (Logs stored in model directory)

* ICNR fix

* loss bugfix

* convert bugfix

* Move ini files to config folder. Make TensorBoard optional

* Fix training data for unbalanced inputs/outputs

* Fix config "none" test

* Keep helptext in .ini files when saving config from GUI

* Remove frame_dims from alignments

* Add no-flip and warp-to-landmarks cli options

* Revert OHR to RC4_fix version

* Fix lowmem mode on OHR model

* padding to variable

* Save models in parallel threads

* Speed-up of res_block stability

* Automated Reflection Padding

* Reflect Padding as a training option

Includes auto-calculation of proper padding shapes, input_shapes, output_shapes

Flag included in config now

* rest of reflect padding

* Move TB logging to cli. Session info to state file

* Add session iterations to state file

* Add recent files to menu. GUI code tidy up

* [GUI] Fix recent file list update issue

* Add correct loss names to TensorBoard logs

* Update live graph to use TensorBoard and remove animation

* Fix analysis tab. GUI optimizations

* Analysis Graph popup to Tensorboard Logs

* [GUI] Bug fix for graphing for models with hyphens in name

* [GUI] Correctly split loss to tabs during training

* [GUI] Add loss type selection to analysis graph

* Fix store command name in recent files. Switch to correct tab on open

* [GUI] Disable training graph when 'no-logs' is selected

* Fix graphing race condition

* rename original_hires model to unbalanced
2019-02-09 18:35:12 +00:00

401 lines
18 KiB
Python

#!/usr/bin/env python3
""" Process training data for model training """
import logging
from hashlib import sha1
from random import shuffle
import cv2
import numpy as np
from scipy.interpolate import griddata
from lib.model import masks
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.umeyama import umeyama
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class TrainingDataGenerator():
    """ Generate training data for models """
    def __init__(self, model_input_size, model_output_size, training_opts):
        logger.debug("Initializing %s: (model_input_size: %s, model_output_size: %s, "
                     "training_opts: %s, landmarks: %s)",
                     self.__class__.__name__, model_input_size, model_output_size,
                     {key: val for key, val in training_opts.items() if key != "landmarks"},
                     bool(training_opts.get("landmarks", None)))
        self.batchsize = 0
        self.model_input_size = model_input_size
        self.training_opts = training_opts
        self.mask_function = self.set_mask_function()
        self.landmarks = self.training_opts.get("landmarks", None)
        self.processing = ImageManipulation(model_input_size,
                                            model_output_size,
                                            training_opts.get("coverage_ratio", 0.625))
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_mask_function(self):
        """ Set the mask function to use if using mask """
        mask_type = self.training_opts.get("mask_type", None)
        if mask_type:
            logger.debug("Mask type: '%s'", mask_type)
            mask_func = getattr(masks, mask_type)
        else:
            mask_func = None
        logger.debug("Mask function: %s", mask_func)
        return mask_func

    def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False):
        """ Keep a queue filled to 8x Batch Size """
        logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                     "is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse)
        self.batchsize = batchsize
        q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side)
        q_size = batchsize * 8
        # Don't use a multiprocessing queue because sometimes the MP Manager borks on numpy arrays
        queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False)
        load_thread = MultiThread(self.load_batches,
                                  images,
                                  q_name,
                                  side,
                                  is_timelapse,
                                  do_shuffle)
        load_thread.start()
        logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name)
        return self.minibatch(q_name, load_thread)

    def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True):
        """ Load the warped images and target images to queue """
        logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', "
                     "is_timelapse: %s, do_shuffle: %s)",
                     len(images), q_name, side, is_timelapse, do_shuffle)
        epoch = 0
        queue = queue_manager.get_queue(q_name)
        self.validate_samples(images)
        while True:
            if do_shuffle:
                shuffle(images)
            for img in images:
                logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side)
                queue.put(self.process_face(img, side, is_timelapse))
            epoch += 1
            logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')",
                         epoch, q_name, side)

    def validate_samples(self, data):
        """ Check the total number of images against batchsize and raise an
            AssertionError if there are fewer images than the batchsize """
        length = len(data)
        msg = ("Number of images is lower than batch-size (Note that too few "
               "images may lead to bad training). # images: {}, "
               "batch-size: {}".format(length, self.batchsize))
        assert length >= self.batchsize, msg

    def minibatch(self, q_name, load_thread):
        """ A generator function that yields epoch, batchsize of warped_img
            and batchsize of target_img from the load queue """
        logger.debug("Launching minibatch generator for queue: '%s'", q_name)
        queue = queue_manager.get_queue(q_name)
        while True:
            if load_thread.has_error:
                logger.debug("Thread error detected")
                break
            batch = list()
            for _ in range(self.batchsize):
                images = queue.get()
                for idx, image in enumerate(images):
                    if len(batch) < idx + 1:
                        batch.append(list())
                    batch[idx].append(image)
            batch = [np.float32(image) for image in batch]
            logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s')",
                         len(batch), [item.shape for item in batch], q_name)
            yield batch
        logger.debug("Finished minibatch generator for queue: '%s'", q_name)
        load_thread.join()

    def process_face(self, filename, side, is_timelapse):
        """ Load an image and perform transformation and warping """
        logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)",
                     filename, side, is_timelapse)
        try:
            image = cv2.imread(filename)  # pylint: disable=no-member
        except TypeError:
            raise Exception("Error while reading image", filename)

        if self.mask_function or self.training_opts["warp_to_landmarks"]:
            src_pts = self.get_landmarks(filename, image, side)
        if self.mask_function:
            image = self.mask_function(src_pts, image, channels=4)

        image = self.processing.color_adjust(image)

        if not is_timelapse:
            image = self.processing.random_transform(image)
            if not self.training_opts["no_flip"]:
                image = self.processing.do_random_flip(image)
        sample = image.copy()[:, :, :3]

        if self.training_opts["warp_to_landmarks"]:
            dst_pts = self.get_closest_match(filename, side, src_pts)
            processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
        else:
            processed = self.processing.random_warp(image)

        processed.insert(0, sample)
        logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
                     filename, side, [img.shape for img in processed])
        return processed

    def get_landmarks(self, filename, image, side):
        """ Return the landmarks for this face """
        logger.trace("Retrieving landmarks: (filename: '%s', side: '%s')", filename, side)
        lm_key = sha1(image).hexdigest()
        try:
            src_points = self.landmarks[side][lm_key]
        except KeyError:
            raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key,
                                                                                   filename))
        logger.trace("Returning: (src_points: %s)", src_points)
        return src_points

    def get_closest_match(self, filename, side, src_points):
        """ Return closest matched landmarks from opposite set """
        logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s')",
                     filename, src_points)
        dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"]
        dst_points = list(dst_points.values())
        closest = (np.mean(np.square(src_points - dst_points),
                           axis=(1, 2))).argsort()[:10]
        closest = np.random.choice(closest)
        dst_points = dst_points[closest]
        logger.trace("Returning: (dst_points: %s)", dst_points)
        return dst_points
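
The closest-match lookup in `get_closest_match` can be tried in isolation. The following is a minimal sketch, assuming each landmark set is a `(68, 2)` array; the helper name `closest_landmark_indices` and the synthetic data are illustrative only and are not part of faceswap. It ranks candidate landmark sets by mean squared distance and keeps the ten nearest, from which the original code then picks one at random.

```python
import numpy as np


def closest_landmark_indices(src_points, candidates, top_n=10):
    """Return indices of the top_n candidate landmark sets nearest to src_points.

    Hypothetical helper: mirrors the mean-squared-distance ranking only.
    """
    candidates = np.asarray(candidates, dtype="float32")
    # Broadcasting (68, 2) against (N, 68, 2) gives one distance per candidate
    distances = np.mean(np.square(src_points - candidates), axis=(1, 2))
    return distances.argsort()[:top_n]


rng = np.random.default_rng(0)
src = rng.uniform(0, 256, size=(68, 2)).astype("float32")
others = rng.uniform(0, 256, size=(50, 68, 2)).astype("float32")
others[7] = src + 0.5  # plant a near-duplicate so the ranking is easy to check

closest = closest_landmark_indices(src, others)
chosen = rng.choice(closest)  # the original picks randomly among the nearest sets
```

Choosing randomly among the nearest sets, rather than always taking the single best, presumably keeps some variety in the warp targets.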


class ImageManipulation():
    """ Manipulations to be performed on training images """
    def __init__(self, input_size, output_size, coverage_ratio):
        """ input_size: Size of the face input into the model
            output_size: Size of the face that comes out of the model
            coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160
        """
        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
                     self.__class__.__name__, input_size, output_size, coverage_ratio)
        # Transform args
        self.rotation_range = 10  # Range to randomly rotate the image by
        self.zoom_range = 0.05  # Range to randomly zoom the image by
        self.shift_range = 0.05  # Range to randomly translate the image by
        self.random_flip = 0.5  # Chance to flip the image horizontally
        # Transform and Warp args
        self.input_size = input_size
        self.output_size = output_size
        # Warp args
        self.coverage_ratio = coverage_ratio  # Coverage ratio of full image. Eg: 256 * 0.625 = 160
        self.scale = 5  # Normal random variable scale
        logger.debug("Initialized %s", self.__class__.__name__)

    @staticmethod
    def color_adjust(img):
        """ Color adjust RGB image """
        logger.trace("Color adjusting image")
        return img.astype('float32') / 255.0

    @staticmethod
    def separate_mask(image):
        """ Return the image and the mask from a 4 channel image """
        mask = None
        if image.shape[2] == 4:
            logger.trace("Image contains mask")
            mask = np.expand_dims(image[:, :, -1], axis=2)
            image = image[:, :, :3]
        else:
            logger.trace("Image has no mask")
        return image, mask

    def get_coverage(self, image):
        """ Return coverage value for given image """
        coverage = int(image.shape[0] * self.coverage_ratio)
        logger.trace("Coverage: %s", coverage)
        return coverage

    def random_transform(self, image):
        """ Randomly transform an image """
        logger.trace("Randomly transforming image")
        height, width = image.shape[0:2]

        rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
        scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
        tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
        tny = np.random.uniform(-self.shift_range, self.shift_range) * height

        mat = cv2.getRotationMatrix2D(  # pylint: disable=no-member
            (width // 2, height // 2), rotation, scale)
        mat[:, 2] += (tnx, tny)
        result = cv2.warpAffine(  # pylint: disable=no-member
            image, mat, (width, height),
            borderMode=cv2.BORDER_REPLICATE)  # pylint: disable=no-member

        logger.trace("Randomly transformed image")
        return result

    def do_random_flip(self, image):
        """ Perform flip on image if random number is within threshold """
        logger.trace("Randomly flipping image")
        if np.random.random() < self.random_flip:
            logger.trace("Flip within threshold. Flipping")
            retval = image[:, ::-1]
        else:
            logger.trace("Flip outside threshold. Not Flipping")
            retval = image
        logger.trace("Randomly flipped image")
        return retval

    def random_warp(self, image):
        """ get pair of random warped images from aligned face image """
        logger.trace("Randomly warping image")
        height, width = image.shape[0:2]
        coverage = self.get_coverage(image)
        assert height == width and height % 2 == 0

        range_ = np.linspace(height // 2 - coverage // 2,
                             height // 2 + coverage // 2,
                             5, dtype='float32')
        mapx = np.broadcast_to(range_, (5, 5)).copy()
        mapy = mapx.T
        # mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast

        pad = int(1.25 * self.input_size)
        slices = slice(pad // 10, -pad // 10)
        dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4))
        interp = np.empty((2, self.input_size, self.input_size), dtype='float32')
        ####

        for i, map_ in enumerate([mapx, mapy]):
            map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale)
            interp[i] = cv2.resize(map_, (pad, pad))[slices, slices]  # pylint: disable=no-member

        warped_image = cv2.remap(  # pylint: disable=no-member
            image, interp[0], interp[1], cv2.INTER_LINEAR)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)

        src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
        dst_points = np.mgrid[dst_slice, dst_slice]
        mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2]
        target_image = cv2.warpAffine(  # pylint: disable=no-member
            image, mat, (self.output_size, self.output_size))
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]

    def random_warp_landmarks(self, image, src_points=None, dst_points=None):
        """ get warped image, target image and target mask
            From DFAKER plugin """
        logger.trace("Randomly warping landmarks")
        size = image.shape[0]
        coverage = self.get_coverage(image)

        p_mx = size - 1
        p_hf = (size // 2) - 1

        edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0),
                        (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]
        grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)]

        source = src_points
        destination = (dst_points.copy().astype('float32') +
                       np.random.normal(size=dst_points.shape, scale=2.0))
        destination = destination.astype('uint8')

        face_core = cv2.convexHull(np.concatenate(  # pylint: disable=no-member
            [source[17:], destination[17:]], axis=0).astype(int))

        source = [(pty, ptx) for ptx, pty in source] + edge_anchors
        destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors

        indicies_to_remove = set()
        for fpl in source, destination:
            for idx, (pty, ptx) in enumerate(fpl):
                if idx > 17:
                    break
                elif cv2.pointPolygonTest(face_core,  # pylint: disable=no-member
                                          (pty, ptx),
                                          False) >= 0:
                    indicies_to_remove.add(idx)

        for idx in sorted(indicies_to_remove, reverse=True):
            source.pop(idx)
            destination.pop(idx)

        grid_z = griddata(destination, source, (grid_x, grid_y), method="linear")
        map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size)
        map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size)
        map_x_32 = map_x.astype('float32')
        map_y_32 = map_y.astype('float32')

        warped_image = cv2.remap(image,  # pylint: disable=no-member
                                 map_x_32,
                                 map_y_32,
                                 cv2.INTER_LINEAR,  # pylint: disable=no-member
                                 cv2.BORDER_TRANSPARENT)  # pylint: disable=no-member
        target_image = image

        # TODO Make sure this replacement is correct
        slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2)
        # slices = slice(size // 32, size - size // 32)  # 8px on a 256px image
        warped_image = cv2.resize(  # pylint: disable=no-member
            warped_image[slices, slices, :], (self.input_size, self.input_size),
            cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)
        target_image = cv2.resize(  # pylint: disable=no-member
            target_image[slices, slices, :], (self.output_size, self.output_size),
            cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]
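
The coverage arithmetic and the 5x5 control grid used by `random_warp` can be checked without OpenCV. The sketch below is an illustration under stated assumptions, not the faceswap implementation: the function name `jittered_grid` and the explicit `rng` argument are hypothetical, and the later `cv2.resize` / `cv2.remap` steps are left out. It shows how a 256px image with a coverage ratio of 0.625 gives a 160px window spanning pixels 48 to 208, and how Gaussian noise is added to each control point.

```python
import numpy as np


def jittered_grid(image_size, coverage_ratio, scale=5.0, rng=None):
    """Build the regular 5x5 control grid and a randomly jittered copy.

    Hypothetical helper covering only the grid-building part of the warp.
    """
    rng = np.random.default_rng() if rng is None else rng
    coverage = int(image_size * coverage_ratio)  # e.g. 256 * 0.625 = 160
    centre = image_size // 2
    range_ = np.linspace(centre - coverage // 2,
                         centre + coverage // 2,
                         5, dtype="float32")
    mapx = np.broadcast_to(range_, (5, 5)).copy()  # x coordinate of each control point
    mapy = mapx.T                                  # y coordinate of each control point
    # Each control point is displaced by zero-mean Gaussian noise
    warp_x = mapx + rng.normal(size=(5, 5), scale=scale)
    warp_y = mapy + rng.normal(size=(5, 5), scale=scale)
    return mapx, mapy, warp_x, warp_y


mapx, mapy, warp_x, warp_y = jittered_grid(256, 0.625, rng=np.random.default_rng(0))
print(mapx[0])  # control point x positions along one row
```

In the original method the jittered grids are then upsampled and used as the lookup maps for the warped image, while the unjittered grid feeds the `umeyama` alignment for the target image.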


def stack_images(images):
    """ Stack images """
    logger.debug("Stack images")

    def get_transpose_axes(num):
        if num % 2 == 0:
            logger.debug("Even number of images to stack")
            y_axes = list(range(1, num - 1, 2))
            x_axes = list(range(0, num - 1, 2))
        else:
            logger.debug("Odd number of images to stack")
            y_axes = list(range(0, num - 1, 2))
            x_axes = list(range(1, num - 1, 2))
        return y_axes, x_axes, [num - 1]

    images_shape = np.array(images.shape)
    new_axes = get_transpose_axes(len(images_shape))
    new_shape = [np.prod(images_shape[x]) for x in new_axes]
    logger.debug("Stacked images")
    return np.transpose(
        images,
        axes=np.concatenate(new_axes)
        ).reshape(new_shape)
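
The axis shuffling in `stack_images` is easier to follow on a concrete array. The sketch below is a standalone re-statement of the same transpose-and-reshape idea with the logging removed; the name `stack_grid` and the example shapes are assumptions made for illustration. A `(rows, cols, height, width, channels)` array of tiles becomes a single `(rows * height, cols * width, channels)` contact sheet.

```python
import numpy as np


def stack_grid(images):
    """Tile an array of images into one image (illustrative re-statement)."""
    num = images.ndim
    if num % 2 == 0:
        # Even number of axes, e.g. (cols, height, width, channels)
        y_axes = list(range(1, num - 1, 2))
        x_axes = list(range(0, num - 1, 2))
    else:
        # Odd number of axes, e.g. (rows, cols, height, width, channels)
        y_axes = list(range(0, num - 1, 2))
        x_axes = list(range(1, num - 1, 2))
    shape = np.array(images.shape)
    groups = (y_axes, x_axes, [num - 1])
    new_shape = [int(np.prod(shape[axes])) for axes in groups]
    # Bring the row-like axes together, then the column-like axes, then channels
    return images.transpose(y_axes + x_axes + [num - 1]).reshape(new_shape)


tiles = np.zeros((2, 3, 64, 64, 3), dtype="float32")
tiles[1, 2] = 1.0  # mark the bottom-right tile
sheet = stack_grid(tiles)
print(sheet.shape)
```

With a four-axis input such as `(3, 64, 64, 3)`, the same logic yields a single horizontal strip of shape `(64, 192, 3)`.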