1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/extract/align/fan.py

232 lines
11 KiB
Python

#!/usr/bin/env python3
""" Facial landmarks extractor for faceswap.py
Code adapted and modified from:
https://github.com/1adrianb/face-alignment
"""
import cv2
import numpy as np
import keras
from keras import backend as K
from lib.model.session import KSession
from ._base import Aligner, logger
class Align(Aligner):
""" Perform transformation to align and get landmarks """
def __init__(self, **kwargs):
git_model_id = 9
model_filename = "face-alignment-network_2d4_keras_v1.h5"
super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
self.name = "FAN"
self.input_size = 256
self.color_format = "RGB"
self.vram = 2240
self.vram_warnings = 512 # Will run at this with warnings
self.vram_per_batch = 64
self.batchsize = self.config["batch-size"]
self.reference_scale = 200. / 195.
def init_model(self):
""" Initialize FAN model """
model_kwargs = dict(custom_objects={'TorchBatchNorm2D': TorchBatchNorm2D})
self.model = KSession(self.name,
self.model_path,
model_kwargs=model_kwargs,
allow_growth=self.config["allow_growth"])
self.model.load_model()
# Feed a placeholder so Aligner is primed for Manual tool
placeholder_shape = (self.batchsize, 3, self.input_size, self.input_size)
placeholder = np.zeros(placeholder_shape, dtype="float32")
self.model.predict(placeholder)
def process_input(self, batch):
""" Compile the detected faces for prediction """
logger.debug("Aligning faces around center")
batch["center_scale"] = self.get_center_scale(batch["detected_faces"])
faces = self.crop(batch)
logger.trace("Aligned image around center")
faces = self._normalize_faces(faces)
batch["feed"] = np.array(faces, dtype="float32")[..., :3].transpose((0, 3, 1, 2)) / 255.0
return batch
def get_center_scale(self, detected_faces):
""" Get the center and set scale of bounding box """
logger.debug("Calculating center and scale")
center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32')
for index, face in enumerate(detected_faces):
x_center = (face.left + face.right) / 2.0
y_center = (face.top + face.bottom) / 2.0 - face.h * 0.12
scale = (face.w + face.h) * self.reference_scale
center_scale[index, :, 0] = np.full(68, x_center, dtype='float32')
center_scale[index, :, 1] = np.full(68, y_center, dtype='float32')
center_scale[index, :, 2] = np.full(68, scale, dtype='float32')
logger.trace("Calculated center and scale: %s", center_scale)
return center_scale
def crop(self, batch): # pylint:disable=too-many-locals
""" Crop image around the center point """
logger.debug("Cropping images")
sizes = (self.input_size, self.input_size)
batch_shape = batch["center_scale"].shape[:2]
resolutions = np.full(batch_shape, self.input_size, dtype='float32')
matrix_ones = np.ones(batch_shape + (3,), dtype='float32')
matrix_size = np.full(batch_shape + (3,), self.input_size, dtype='float32')
matrix_size[..., 2] = 1.0
upper_left = self.transform(matrix_ones, batch["center_scale"], resolutions)
bot_right = self.transform(matrix_size, batch["center_scale"], resolutions)
# TODO second pass .. convert to matrix
new_images = []
for image, top_left, bottom_right in zip(batch["image"], upper_left, bot_right):
height, width = image.shape[:2]
channels = 3 if image.ndim > 2 else 1
bottom_right_width, bottom_right_height = bottom_right[0].astype('int32')
top_left_width, top_left_height = top_left[0].astype('int32')
new_dim = (bottom_right_height - top_left_height,
bottom_right_width - top_left_width,
channels)
new_img = np.empty(new_dim, dtype=np.uint8)
new_x = slice(max(0, -top_left_width), min(bottom_right_width, width) - top_left_width)
new_y = slice(max(0, -top_left_height),
min(bottom_right_height, height) - top_left_height)
old_x = slice(max(0, top_left_width), min(bottom_right_width, width))
old_y = slice(max(0, top_left_height), min(bottom_right_height, height))
new_img[new_y, new_x] = image[old_y, old_x]
interp = cv2.INTER_CUBIC if new_dim[0] < self.input_size else cv2.INTER_AREA
new_images.append(cv2.resize(new_img, dsize=sizes, interpolation=interp))
logger.trace("Cropped images")
return new_images
@staticmethod
def transform(points, center_scales, resolutions):
""" Transform Image """
logger.debug("Transforming Points")
num_images, num_landmarks = points.shape[:2]
transform_matrix = np.eye(3, dtype='float32')
transform_matrix = np.repeat(transform_matrix[None, :], num_landmarks, axis=0)
transform_matrix = np.repeat(transform_matrix[None, :, :], num_images, axis=0)
scales = center_scales[:, :, 2] / resolutions
translations = center_scales[..., 2:3] * -0.5 + center_scales[..., :2]
transform_matrix[:, :, 0, 0] = scales # x scale
transform_matrix[:, :, 1, 1] = scales # y scale
transform_matrix[:, :, 0, 2] = translations[:, :, 0] # x translation
transform_matrix[:, :, 1, 2] = translations[:, :, 1] # y translation
new_points = np.einsum('abij, abj -> abi', transform_matrix, points, optimize='greedy')
retval = new_points[:, :, :2].astype('float32')
logger.trace("Transformed Points: %s", retval)
return retval
def predict(self, batch):
""" Predict the 68 point landmarks """
logger.debug("Predicting Landmarks")
batch["prediction"] = self.model.predict(batch["feed"])[-1]
logger.trace([pred.shape for pred in batch["prediction"]])
return batch
def process_output(self, batch):
""" Process the output from the model """
self.get_pts_from_predict(batch)
return batch
def get_pts_from_predict(self, batch):
""" Get points from predictor """
logger.debug("Obtain points from prediction")
num_images, num_landmarks, height, width = batch["prediction"].shape
image_slice = np.repeat(np.arange(num_images)[:, None], num_landmarks, axis=1)
landmark_slice = np.repeat(np.arange(num_landmarks)[None, :], num_images, axis=0)
resolution = np.full((num_images, num_landmarks), 64, dtype='int32')
subpixel_landmarks = np.ones((num_images, num_landmarks, 3), dtype='float32')
flat_indices = batch["prediction"].reshape(num_images, num_landmarks, -1).argmax(-1)
indices = np.array(np.unravel_index(flat_indices, (height, width)))
min_clipped = np.minimum(indices + 1, height - 1)
max_clipped = np.maximum(indices - 1, 0)
offsets = [(image_slice, landmark_slice, indices[0], min_clipped[1]),
(image_slice, landmark_slice, indices[0], max_clipped[1]),
(image_slice, landmark_slice, min_clipped[0], indices[1]),
(image_slice, landmark_slice, max_clipped[0], indices[1])]
x_subpixel_shift = batch["prediction"][offsets[0]] - batch["prediction"][offsets[1]]
y_subpixel_shift = batch["prediction"][offsets[2]] - batch["prediction"][offsets[3]]
# TODO improve rudimentary sub-pixel logic to centroid of 3x3 window algorithm
subpixel_landmarks[:, :, 0] = indices[1] + np.sign(x_subpixel_shift) * 0.25 + 0.5
subpixel_landmarks[:, :, 1] = indices[0] + np.sign(y_subpixel_shift) * 0.25 + 0.5
batch["landmarks"] = self.transform(subpixel_landmarks, batch["center_scale"], resolution)
logger.trace("Obtained points from prediction: %s", batch["landmarks"])
class TorchBatchNorm2D(keras.engine.base_layer.Layer):
# pylint:disable=too-many-instance-attributes
"""" Required for FAN_keras model """
def __init__(self, axis=-1, momentum=0.99, epsilon=1e-3, **kwargs):
super(TorchBatchNorm2D, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.momentum = momentum
self.epsilon = epsilon
self._epsilon_const = K.constant(self.epsilon, dtype='float32')
self.built = False
self.gamma = None
self.beta = None
self.moving_mean = None
self.moving_variance = None
def build(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError("Axis {} of input tensor should have a "
"defined dimension but the layer received "
"an input with shape {}."
.format(str(self.axis), str(input_shape)))
shape = (dim,)
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer='ones',
regularizer=None,
constraint=None)
self.beta = self.add_weight(shape=shape,
name='beta',
initializer='zeros',
regularizer=None,
constraint=None)
self.moving_mean = self.add_weight(shape=shape,
name='moving_mean',
initializer='zeros',
trainable=False)
self.moving_variance = self.add_weight(shape=shape,
name='moving_variance',
initializer='ones',
trainable=False)
self.built = True
def call(self, inputs, **kwargs):
input_shape = K.int_shape(inputs)
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
broadcast_moving_mean = K.reshape(self.moving_mean, broadcast_shape)
broadcast_moving_variance = K.reshape(self.moving_variance,
broadcast_shape)
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
broadcast_beta = K.reshape(self.beta, broadcast_shape)
invstd = (
K.ones(shape=broadcast_shape, dtype='float32')
/ K.sqrt(broadcast_moving_variance + self._epsilon_const)
)
return((inputs - broadcast_moving_mean)
* invstd
* broadcast_gamma
+ broadcast_beta)
def get_config(self):
config = {'axis': self.axis,
'momentum': self.momentum,
'epsilon': self.epsilon}
base_config = super(TorchBatchNorm2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))