1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/lib/convert.py
torzdf 43a4d06540
Smart Masks - Training Implementation (#914)
* Smart Masks - Training

- Reinstate smart mask training code
- Reinstate mask_type back to model.config
- Change 'replicate_input_mask' to 'learn_mask'
- Add learn mask option
- Add mask loading from alignments to plugins.train.trainer
- Add mask_blur and mask threshold options
- _base.py - Pass mask options through training_opts dict
- plugins.train.model - check for mask_type not None for learn_mask and penalized_mask_loss
- Limit alignments loading to just those faces that appear in the training folder
- Raise error if not all training images have an alignment, and alignment file is required
- lib.training_data - Mask generation code
- lib.faces_detect - cv2 dimension stripping bugfix
- Remove cv2 linting code

* Update mask helptext in cli.py

* Fix Warp to Landmarks
Remove SHA1 hashing from training data

* Update mask training config

* Capture missing masks at training init

* lib.image.read_image_batch - Return filenames with batch for ordering

* scripts.train - Documentation

* plugins.train.trainer - documentation

* Ensure backward compatibility.
Fix convert for new predicted masks

* Update removed masks to components for legacy models.
2019-12-05 16:02:01 +00:00

217 lines
10 KiB
Python

#!/usr/bin/env python3
""" Converter for faceswap.py
Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955
found on https://www.reddit.com/r/deepfakes/ """
import logging
import cv2
import numpy as np
from plugins.plugin_loader import PluginLoader
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Converter():
    """ Swap a source face with a target.

    Takes predicted faces from an input queue, runs the configured adjustment plugins on them,
    warps them back into the original frame and places the finished frame on an output queue.

    Parameters
    ----------
    output_dir: str
        Stored on the instance as ``output_dir``. Not otherwise used by this class.
    output_size: int
        Passed to the box blend and mask blend plugins when they are created.
    output_has_mask: bool
        Passed to the mask blend plugin when it is created.
    draw_transparent: bool
        When ``True`` the 4 channel warped image is returned as the final frame rather than
        being composited onto the original frame.
    pre_encode: callable or ``None``
        When not ``None`` it is called with the final ``uint8`` frame, and its return value is
        what gets placed on the output queue.
    arguments: object
        Options object. The attributes read by this class are ``output_scale`` (a percentage),
        ``mask_type``, ``color_adjustment`` and ``scaling``.
    configfile: str, optional
        Passed through to each plugin when it is created. Default: ``None``
    """

    def __init__(self, output_dir, output_size, output_has_mask,
                 draw_transparent, pre_encode, arguments, configfile=None):
        logger.debug("Initializing %s: (output_dir: '%s', output_size: %s, output_has_mask: %s, "
                     "draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)",
                     self.__class__.__name__, output_dir, output_size, output_has_mask,
                     draw_transparent, pre_encode, arguments, configfile)
        self.output_dir = output_dir
        self.draw_transparent = draw_transparent
        self.writer_pre_encode = pre_encode
        # ``output_scale`` is supplied as a percentage. Store it as a multiplier
        self.scale = arguments.output_scale / 100
        self.output_size = output_size
        self.output_has_mask = output_has_mask
        self.args = arguments
        self.configfile = configfile
        self.adjustments = self._empty_adjustments()
        self.load_plugins()
        logger.debug("Initialized %s", self.__class__.__name__)

    @staticmethod
    def _empty_adjustments():
        """ Return a fresh adjustments dictionary with every plugin slot set to ``None`` """
        return dict(box=None, mask=None, color=None, seamless=None, scaling=None)

    def reinitialize(self, config):
        """ Discard the loaded plugins and reload them with the given configuration.

        Parameters
        ----------
        config: object
            Passed through to each plugin as its ``config`` keyword argument
        """
        logger.debug("Reinitializing converter")
        self.adjustments = self._empty_adjustments()
        self.load_plugins(config=config, disable_logging=True)
        logger.debug("Reinitialized converter")

    def load_plugins(self, config=None, disable_logging=False):
        """ Load the requested adjustment plugins into :attr:`adjustments`.

        The box and mask plugins are always loaded. The color and scaling plugins are only
        loaded when the corresponding argument is neither ``None`` nor ``"none"``. Nothing in
        this method populates the ``"seamless"`` slot.

        Parameters
        ----------
        config: object, optional
            Passed through to each plugin as its ``config`` keyword argument. Default: ``None``
        disable_logging: bool, optional
            Passed through to the plugin loader. Default: ``False``
        """
        logger.debug("Loading plugins. config: %s", config)
        self.adjustments["box"] = PluginLoader.get_converter(
            "mask",
            "box_blend",
            disable_logging=disable_logging)("none",
                                             self.output_size,
                                             configfile=self.configfile,
                                             config=config)

        self.adjustments["mask"] = PluginLoader.get_converter(
            "mask",
            "mask_blend",
            disable_logging=disable_logging)(self.args.mask_type,
                                             self.output_size,
                                             self.output_has_mask,
                                             configfile=self.configfile,
                                             config=config)

        if self.args.color_adjustment != "none" and self.args.color_adjustment is not None:
            self.adjustments["color"] = PluginLoader.get_converter(
                "color",
                self.args.color_adjustment,
                disable_logging=disable_logging)(configfile=self.configfile, config=config)

        if self.args.scaling != "none" and self.args.scaling is not None:
            self.adjustments["scaling"] = PluginLoader.get_converter(
                "scaling",
                self.args.scaling,
                disable_logging=disable_logging)(configfile=self.configfile, config=config)
        logger.debug("Loaded plugins: %s", self.adjustments)

    def process(self, in_queue, out_queue, completion_queue=None):
        """ Process items from the queue until ``"EOF"`` is received.

        Each item taken from ``in_queue`` is either a single prediction dictionary or a list of
        them. For every prediction a ``(filename, image)`` tuple is put on ``out_queue``. If
        patching a frame fails, the error is logged and the original frame is output instead.

        Parameters
        ----------
        in_queue: queue
            Source of predictions. The ``"EOF"`` marker is put back on this queue once received
            so that other consumers of the same queue also see it.
        out_queue: queue
            Destination for ``(filename, image)`` tuples
        completion_queue: queue, optional
            When not ``None``, ``1`` is put on it once this method has finished.
            Default: ``None``
        """
        logger.debug("Starting convert process. (in_queue: %s, out_queue: %s, completion_queue: "
                     "%s)", in_queue, out_queue, completion_queue)
        while True:
            items = in_queue.get()
            if items == "EOF":
                logger.debug("EOF Received")
                logger.debug("Patch queue finished")
                # Signal EOF to other processes in pool
                logger.debug("Putting EOF back to in_queue")
                in_queue.put(items)
                break

            if isinstance(items, dict):
                items = [items]
            for item in items:
                logger.trace("Patch queue got: '%s'", item["filename"])
                try:
                    image = self.patch_image(item)
                except Exception as err:  # pylint: disable=broad-except
                    # Log error and output original frame
                    logger.error("Failed to convert image: '%s'. Reason: %s",
                                 item["filename"], str(err))
                    # Keep the traceback available without having to edit the code
                    logger.debug("Convert error traceback:", exc_info=True)
                    image = self._fallback_image(item)
                logger.trace("Out queue put: %s", item["filename"])
                out_queue.put((item["filename"], image))
        logger.debug("Completed convert process")
        # Signal that this process has finished
        if completion_queue is not None:
            completion_queue.put(1)

    def _fallback_image(self, item):
        """ Obtain the image to output for a frame that failed to convert.

        Successfully patched frames are passed through the pre-encode function (when one is
        set) before they are queued, so the original frame is passed through it as well. This
        keeps everything placed on the output queue in the same form.

        Parameters
        ----------
        item: dict
            The prediction dictionary that failed. ``"image"`` and ``"filename"`` are read.

        Returns
        -------
        object
            The pre-encoded original frame, or the untouched original frame when there is no
            pre-encode function or when pre-encoding the original frame fails as well
        """
        image = item["image"]
        if self.writer_pre_encode is None:
            return image
        try:
            return self.writer_pre_encode(image)
        except Exception as err:  # pylint: disable=broad-except
            # The encoder itself may be what failed. Never let it kill the processing loop
            logger.error("Failed to pre-encode original image: '%s'. Reason: %s",
                         item["filename"], str(err))
            return image

    def patch_image(self, predicted):
        """ Patch the swapped faces into the frame and prepare it for output.

        Parameters
        ----------
        predicted: dict
            Prediction dictionary. The keys ``"filename"``, ``"image"``, ``"swapped_faces"`` and
            ``"detected_faces"`` are read here and in :func:`get_new_image`

        Returns
        -------
        object
            The patched frame as a ``uint8`` array, or the result of the pre-encode function
            applied to that array when a pre-encode function is set
        """
        logger.trace("Patching image: '%s'", predicted["filename"])
        # cv2 sizes are (width, height), numpy shapes are (height, width, ...)
        frame_size = (predicted["image"].shape[1], predicted["image"].shape[0])
        new_image, background = self.get_new_image(predicted, frame_size)
        patched_face = self.post_warp_adjustments(background, new_image)
        patched_face = self.scale_image(patched_face)
        patched_face *= 255.0
        # Round to nearest and write directly into a uint8 buffer. The frame has been clipped to
        # 0.0 - 1.0 before the multiply, so the unsafe cast cannot wrap around
        patched_face = np.rint(
            patched_face,
            out=np.empty(patched_face.shape, dtype="uint8"),
            casting='unsafe'
        )
        if self.writer_pre_encode is not None:
            patched_face = self.writer_pre_encode(patched_face)
        logger.trace("Patched image: '%s'", predicted["filename"])
        return patched_face

    def get_new_image(self, predicted, frame_size):
        """ Get the new face from the predictor and apply box manipulations.

        Parameters
        ----------
        predicted: dict
            Prediction dictionary. ``"swapped_faces"`` and ``"detected_faces"`` are iterated in
            lockstep. Each detected face must expose ``reference_face``,
            ``reference_interpolators`` and ``reference_matrix``
        frame_size: tuple
            The ``(width, height)`` of the frame

        Returns
        -------
        placeholder: :class:`numpy.ndarray`
            4 channel ``float32`` image, clipped to 0.0 - 1.0. The first 3 channels start out
            as the original frame and the 4th starts at zero. Each adjusted face is warped into
            it.
        background: :class:`numpy.ndarray`
            The original frame divided by 255 as ``float32``
        """
        logger.trace("Getting: (filename: '%s', faces: %s)",
                     predicted["filename"], len(predicted["swapped_faces"]))

        placeholder = np.zeros((frame_size[1], frame_size[0], 4), dtype="float32")
        background = predicted["image"] / np.array(255.0, dtype="float32")
        placeholder[:, :, :3] = background

        for new_face, detected_face in zip(predicted["swapped_faces"],
                                           predicted["detected_faces"]):
            # A 4th channel on the swapped face is treated as the model's predicted mask
            predicted_mask = new_face[:, :, -1] if new_face.shape[2] == 4 else None
            new_face = new_face[:, :, :3]
            src_face = detected_face.reference_face[..., :3] / 255.0
            interpolator = detected_face.reference_interpolators[1]

            new_face = self.pre_warp_adjustments(src_face, new_face, detected_face, predicted_mask)

            # Warp face with the mask. The placeholder is passed as the destination, and with a
            # transparent border the pixels outside of the warped face are left as they were
            cv2.warpAffine(  # pylint: disable=no-member
                new_face,
                detected_face.reference_matrix,
                frame_size,
                placeholder,
                flags=cv2.WARP_INVERSE_MAP | interpolator,  # pylint: disable=no-member
                borderMode=cv2.BORDER_TRANSPARENT)  # pylint: disable=no-member

        np.clip(placeholder, 0.0, 1.0, out=placeholder)
        logger.trace("Got filename: '%s'. (placeholders: %s)",
                     predicted["filename"], placeholder.shape)

        return placeholder, background

    def pre_warp_adjustments(self, old_face, new_face, detected_face, predicted_mask):
        """ Run the pre-warp adjustments.

        Runs the box plugin, attaches the mask, then runs the color and seamless plugins when
        they are loaded.

        Parameters
        ----------
        old_face: :class:`numpy.ndarray`
            The original face, passed to the color and seamless plugins
        new_face: :class:`numpy.ndarray`
            The swapped face
        detected_face: object
            Passed to the mask plugin
        predicted_mask: :class:`numpy.ndarray` or ``None``
            The mask predicted by the model, if there is one. Passed to the mask plugin

        Returns
        -------
        :class:`numpy.ndarray`
            The adjusted swapped face with the mask in its last channel
        """
        logger.trace("old_face shape: %s, new_face shape: %s, predicted_mask shape: %s",
                     old_face.shape, new_face.shape,
                     predicted_mask.shape if predicted_mask is not None else None)
        new_face = self.adjustments["box"].run(new_face)
        new_face, raw_mask = self.get_image_mask(new_face, detected_face, predicted_mask)
        if self.adjustments["color"] is not None:
            new_face = self.adjustments["color"].run(old_face, new_face, raw_mask)
        if self.adjustments["seamless"] is not None:
            new_face = self.adjustments["seamless"].run(old_face, new_face, raw_mask)
        logger.trace("returning: new_face shape %s", new_face.shape)
        return new_face

    def get_image_mask(self, new_face, detected_face, predicted_mask):
        """ Get the image mask and place it in the last channel of the face.

        Parameters
        ----------
        new_face: :class:`numpy.ndarray`
            The swapped face, either 3 channel or already 4 channel
        detected_face: object
            Passed to the mask plugin
        predicted_mask: :class:`numpy.ndarray` or ``None``
            Passed to the mask plugin

        Returns
        -------
        new_face: :class:`numpy.ndarray`
            The face with the mask in its last channel, clipped to 0.0 - 1.0
        raw_mask: object
            The second value returned by the mask plugin, passed back unchanged
        """
        logger.trace("Getting mask. Image shape: %s", new_face.shape)
        mask, raw_mask = self.adjustments["mask"].run(detected_face, predicted_mask)
        if new_face.shape[2] == 4:
            # An alpha channel is already present. Keep whichever mask is more restrictive
            logger.trace("Combining mask with alpha channel box mask")
            new_face[:, :, -1] = np.minimum(new_face[:, :, -1], mask.squeeze())
        else:
            logger.trace("Adding mask to alpha channel")
            new_face = np.concatenate((new_face, mask), -1)
        np.clip(new_face, 0.0, 1.0, out=new_face)
        logger.trace("Got mask. Image shape: %s", new_face.shape)
        return new_face, raw_mask

    def post_warp_adjustments(self, background, new_image):
        """ Apply fixes to the image after warping.

        Runs the scaling plugin when it is loaded. Then either returns the 4 channel image as
        is (``draw_transparent``) or composites its first 3 channels over the background using
        the 4th channel as the blend weight.

        Parameters
        ----------
        background: :class:`numpy.ndarray`
            The original frame as floats. Modified in place when compositing.
        new_image: :class:`numpy.ndarray`
            The 4 channel warped image. Its color channels are modified in place when
            compositing.

        Returns
        -------
        :class:`numpy.ndarray`
            The final frame, clipped to 0.0 - 1.0
        """
        if self.adjustments["scaling"] is not None:
            new_image = self.adjustments["scaling"].run(new_image)

        if self.draw_transparent:
            frame = new_image
        else:
            # The split arrays are views, so the in-place maths avoids extra frame sized copies
            foreground, mask = np.split(new_image, (3, ), axis=-1)
            foreground *= mask
            background *= (1.0 - mask)
            background += foreground
            frame = background
        np.clip(frame, 0.0, 1.0, out=frame)
        return frame

    def scale_image(self, frame):
        """ Scale the image if requested.

        Parameters
        ----------
        frame: :class:`numpy.ndarray`
            The frame to scale

        Returns
        -------
        :class:`numpy.ndarray`
            The same object when the scale is 1, otherwise the resized frame clipped to
            0.0 - 1.0
        """
        if self.scale == 1:
            return frame
        logger.trace("source frame: %s", frame.shape)
        # Cubic when enlarging, area when shrinking
        interp = cv2.INTER_CUBIC if self.scale > 1 else cv2.INTER_AREA  # pylint: disable=no-member
        # NOTE(review): the "/ 2 ... * 2" looks like it was meant to force even dimensions, but
        # the rounding happens after the multiply so odd sizes are still possible. Left as is
        # because other code may rely on the exact output size - confirm before changing.
        dims = (round((frame.shape[1] / 2 * self.scale) * 2),
                round((frame.shape[0] / 2 * self.scale) * 2))
        frame = cv2.resize(frame, dims, interpolation=interp)  # pylint: disable=no-member
        logger.trace("resized frame: %s", frame.shape)
        # Interpolation can push values outside of the valid range
        np.clip(frame, 0.0, 1.0, out=frame)
        return frame