faceswap/lib/model/losses.py
torzdf cd00859c40
model_refactor (#571) (#572)
* model_refactor (#571)

* original model to new structure

* IAE model to new structure

* OriginalHiRes to new structure

* Fix trainer for different resolutions

* Initial config implementation

* Configparse library added

* improved training data loader

* dfaker model working

* Add logging to training functions

* Non blocking input for cli training

* Add error handling to threads. Add non-mp queues to queue_handler

* Improved Model Building and NNMeta

* refactor lib/models

* training refactor. DFL H128 model Implementation

* Dfaker - use hashes

* Move timelapse. Remove perceptual loss arg

* Update INSTALL.md. Add logger formatting. Update Dfaker training

* DFL h128 partially ported

* Add mask to dfaker (#573)

* Remove old models. Add mask to dfaker

* dfl mask. Make masks selectable in config (#575)

* DFL H128 Mask. Mask type selectable in config.

* remove gan_v2_2

* Creating Input Size config for models

Creating Input Size config for models

Will be used downstream in converters.

Also renamed image_shape to input_shape for clarity (for future models with potentially different output_shapes)

* Add mask loss options to config

* MTCNN options to config.ini. Remove GAN config. Update USAGE.md

* Add sliders for numerical values in GUI

* Add config plugins menu to gui. Validate config

* Only backup model if loss has dropped. Get training working again

* bugfixes

* Standardise loss printing

* GUI idle cpu fixes. Graph loss fix.

* multi-gpu logging bugfix

* Merge branch 'staging' into train_refactor

* backup state file

* Crash protection: Only backup if both total losses have dropped

* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)

* Load and save model structure with weights

* Slight code update

* Improve config loader. Add subpixel opt to all models. Config to state

* Show samples... wrong input

* Remove AE topology. Add input/output shapes to State

* Port original_villain (birb/VillainGuy) model to faceswap

* Add plugin info to GUI config pages

* Load input shape from state. IAE Config options.

* Fix transform_kwargs.
Coverage to ratio.
Bugfix mask detection

* Suppress keras userwarnings.
Automate zoom.
Coverage_ratio to model def.

* Consolidation of converters & refactor (#574)

* Consolidation of converters & refactor

Initial Upload of alpha

Items
- consolidate convert_masked & convert_adjust into one converter
- add average color adjust to convert_masked
- allow mask transition blur size to be a fixed integer of pixels or a fraction of the facial mask size
- allow erosion/dilation size to be a fixed integer of pixels or a fraction of the facial mask size
- eliminate redundant type conversions to avoid multiple round-off errors
- refactor loops for vectorization/speed
- reorganize for clarity & style changes

TODO
- bug/issues with warping the new face onto a transparent old image... use a cleanup mask for now
- issues with mask border giving a black ring at zero erosion... investigate
- remove GAN?
- test enlargement factors of the umeyama standard face... match to coverage factor
- make enlargement factor a model parameter
- remove convert_adjust and its referencing code when finished

* Update Convert_Masked.py

default blur size of 2 to match original...
description of enlargement tests
break out matrix scaling into a def

* Enlargement scale as a cli parameter

* Update cli.py

* dynamic interpolation algorithm

Compute x & y scale factors from the affine matrix on the fly by QR decomposition.
Choose the interpolation algorithm for the affine warp based on whether each image is upsampled or downsampled

* input size
input size from config

* fix issues with <1.0 erosion

* Update convert.py

* Update Convert_Adjust.py

more work on the way to merging

* Clean up help note on sharpen

* cleanup seamless

* Delete Convert_Adjust.py

* Update umeyama.py

* Update training_data.py

* swapping

* segmentation stub

* changes to convert.str

* Update masked.py

* Backwards compatibility fix for models
Get converter running

* Convert:
Move masks to class.
bugfix blur_size
some linting

* mask fix

* convert fixes

- missing facehull_rect re-added
- coverage to %
- corrected coverage logic
- cleanup of gui option ordering

* Update cli.py

* default for blur

* Update masked.py

* added preliminary low_mem version of OriginalHighRes model plugin

* Code cleanup, minor fixes

* Update masked.py

* Update masked.py

* Add dfl mask to convert

* histogram fix & seamless location

* update

* revert

* bugfix: Load actual configuration in gui

* Standardize nn_blocks

* Update cli.py

* Minor code amends

* Fix Original HiRes model

* Add masks to preview output for mask trainers
refactor trainer.__base.py

* Masked trainers converter support

* convert bugfix

* Bugfix: Converter for masked (dfl/dfaker) trainers

* Additional Losses (#592)

* initial upload

* Delete blur.py

* default initializer = He instead of Glorot (#588)

* Allow kernel_initializer to be overridable

* Add ICNR Initializer option for upscale on all models.

* Hopefully fixes RSoDs with original-highres model plugin

* remove debug line

* Original-HighRes model plugin Red Screen of Death fix, take #2

* Move global options to _base. Rename Villain model

* clipnorm and res block biases

* scale the end of res block

* res block

* dfaker pre-activation res

* OHRES pre-activation

* villain pre-activation

* tabs/space in nn_blocks

* fix for histogram with mask all set to zero

* fix to prevent two networks with same name

* GUI: Wider tooltips. Improve TQDM capture

* Fix regex bug

* Convert padding=48 to ratio of image size

* Add size option to alignments tool extract

* Pass through training image size to convert from model

* Convert: Pull training coverage from model

* convert: coverage, blur and erode to percent

* simplify matrix scaling

* ordering of sliders in train

* Add matrix scaling to utils. Use interpolation in lib.aligner transform

* masked.py Import get_matrix_scaling from utils

* fix circular import

* Update masked.py

* quick fix for matrix scaling

* testing this for now

* tqdm regex capture bugfix

* Minor amends

* blur size cleanup

* Remove coverage option from convert (Now cascades from model)

* Implement convert for all model types

* Add mask option and coverage option to all existing models

* bugfix for model loading on convert

* debug print removal

* Bugfix for masks in dfl_h128 and iae

* Update preview display. Add preview scaling to cli

* mask notes

* Delete training_data_v2.py

errant file

* training data variables

* Fix timelapse function

* Add new config items to state file for legacy purposes

* Slight GUI tweak

* Raise exception if problem with loaded model

* Add Tensorboard support (Logs stored in model directory)

* ICNR fix

* loss bugfix

* convert bugfix

* Move ini files to config folder. Make TensorBoard optional

* Fix training data for unbalanced inputs/outputs

* Fix config "none" test

* Keep helptext in .ini files when saving config from GUI

* Remove frame_dims from alignments

* Add no-flip and warp-to-landmarks cli options

* Revert OHR to RC4_fix version

* Fix lowmem mode on OHR model

* padding to variable

* Save models in parallel threads

* Speed-up of res_block stability

* Automated Reflection Padding

* Reflect Padding as a training option

Includes auto-calculation of proper padding shapes, input_shapes, output_shapes

Flag included in config now

* rest of reflect padding

* Move TB logging to cli. Session info to state file

* Add session iterations to state file

* Add recent files to menu. GUI code tidy up

* [GUI] Fix recent file list update issue

* Add correct loss names to TensorBoard logs

* Update live graph to use TensorBoard and remove animation

* Fix analysis tab. GUI optimizations

* Analysis Graph popup to Tensorboard Logs

* [GUI] Bug fix for graphing for models with hyphens in name

* [GUI] Correctly split loss to tabs during training

* [GUI] Add loss type selection to analysis graph

* Fix store command name in recent files. Switch to correct tab on open

* [GUI] Disable training graph when 'no-logs' is selected

* Fix graphing race condition

* rename original_hires model to unbalanced
2019-02-09 18:35:12 +00:00
#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py
Losses from:
keras.contrib
dfaker: https://github.com/dfaker/df
shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN"""
from __future__ import absolute_import

import numpy
import keras.backend as K
import tensorflow as tf
from keras.layers import Lambda, concatenate
from tensorflow.contrib.distributions import Beta

from .normalization import InstanceNormalization
class DSSIMObjective():
""" DSSIM Loss Function
Code copied and pasted, with minor amendments, from:
https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
MIT License
Copyright (c) 2017 Fariz Rahman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. """
# pylint: disable=C0103
def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0):
"""
Difference of Structural Similarity (DSSIM loss function). Clipped
between 0 and 0.5
Note : You should add a regularization term like a l2 loss in
addition to this one.
Note : In theano, the `kernel_size` must be a factor of the output
size. So 3 could not be the `kernel_size` for an output of 32.
# Arguments
k1: Parameter of the SSIM (default 0.01)
k2: Parameter of the SSIM (default 0.03)
kernel_size: Size of the sliding window (default 3)
max_value: Max value of the output (default 1.0)
"""
self.__name__ = 'DSSIMObjective'
self.kernel_size = kernel_size
self.k1 = k1
self.k2 = k2
self.max_value = max_value
self.c1 = (self.k1 * self.max_value) ** 2
self.c2 = (self.k2 * self.max_value) ** 2
self.dim_ordering = K.image_data_format()
self.backend = K.backend()
@staticmethod
def __int_shape(x):
return K.int_shape(x)
def __call__(self, y_true, y_pred):
# There are additional parameters for this function
# Note: some of the 'modes' for edge behavior do not yet have a
# gradient definition in the Theano tree and cannot be used for
# learning
kernel = [self.kernel_size, self.kernel_size]
y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:]))
y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:]))
patches_pred = self.extract_image_patches(y_pred,
kernel,
kernel,
'valid',
self.dim_ordering)
patches_true = self.extract_image_patches(y_true,
kernel,
kernel,
'valid',
self.dim_ordering)
# Reshape to get the var in the cells
_, w, h, c1, c2, c3 = self.__int_shape(patches_pred)
patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3])
patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3])
# Get mean
u_true = K.mean(patches_true, axis=-1)
u_pred = K.mean(patches_pred, axis=-1)
# Get variance
var_true = K.var(patches_true, axis=-1)
var_pred = K.var(patches_pred, axis=-1)
# Get std dev
covar_true_pred = K.mean(
patches_true * patches_pred, axis=-1) - u_true * u_pred
ssim = (2 * u_true * u_pred + self.c1) * (
2 * covar_true_pred + self.c2)
denom = (K.square(u_true) + K.square(u_pred) + self.c1) * (
var_pred + var_true + self.c2)
ssim /= denom # no need for clipping, c1 + c2 make the denom non-zero
return K.mean((1.0 - ssim) / 2.0)
@staticmethod
def _preprocess_padding(padding):
"""Convert keras' padding to tensorflow's padding.
# Arguments
padding: string, `"same"` or `"valid"`.
# Returns
a string, `"SAME"` or `"VALID"`.
# Raises
ValueError: if `padding` is invalid.
"""
if padding == 'same':
padding = 'SAME'
elif padding == 'valid':
padding = 'VALID'
else:
raise ValueError('Invalid padding:', padding)
return padding
def extract_image_patches(self, x, ksizes, ssizes, padding='same',
data_format='channels_last'):
'''
Extract the patches from an image
# Parameters
x : The input image
ksizes : 2-d tuple with the kernel size
ssizes : 2-d tuple with the strides size
padding : 'same' or 'valid'
data_format : 'channels_last' or 'channels_first'
# Returns
The (k_w,k_h) patches extracted
TF ==> (batch_size,w,h,k_w,k_h,c)
TH ==> (batch_size,w,h,c,k_w,k_h)
'''
kernel = [1, ksizes[0], ksizes[1], 1]
strides = [1, ssizes[0], ssizes[1], 1]
padding = self._preprocess_padding(padding)
if data_format == 'channels_first':
x = K.permute_dimensions(x, (0, 2, 3, 1))
_, _, _, ch_i = K.int_shape(x)
patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
padding)
# Reshaping to fit Theano
_, w, h, ch = K.int_shape(patches)
patches = tf.reshape(tf.transpose(tf.reshape(patches,
[-1, w, h,
tf.floordiv(ch, ch_i),
ch_i]),
[0, 1, 2, 4, 3]),
[-1, w, h, ch_i, ksizes[0], ksizes[1]])
if data_format == 'channels_last':
patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
return patches
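

# --- Illustrative usage (not part of the original module) ----------------- #
# A minimal, hedged sketch of wiring DSSIMObjective into a Keras model. The
# tiny single-layer model below is a hypothetical stand-in for any model with
# image outputs in [0, 1]; only the `loss=DSSIMObjective()` line is the point.
def _example_dssim_usage():
    """ Sketch: compile a toy model with the DSSIM loss (assumed 64x64x3 I/O) """
    from keras.layers import Conv2D, Input
    from keras.models import Model
    inp = Input(shape=(64, 64, 3))
    out = Conv2D(3, 3, padding="same", activation="sigmoid")(inp)
    model = Model(inp, out)
    model.compile(optimizer="adam", loss=DSSIMObjective())
    return model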
# <<< START: from Dfaker >>> #
class PenalizedLoss(): # pylint: disable=too-few-public-methods
""" Penalized Loss
from: https://github.com/dfaker/df """
def __init__(self, mask, loss_func, mask_prop=1.0):
self.mask = mask
self.loss_func = loss_func
self.mask_prop = mask_prop
self.mask_as_k_inv_prop = 1-mask_prop
def __call__(self, y_true, y_pred):
# pylint: disable=invalid-name
tro, tgo, tbo = tf.split(y_true, 3, 3)
pro, pgo, pbo = tf.split(y_pred, 3, 3)
tr = tro
tg = tgo
tb = tbo
pr = pro
pg = pgo
pb = pbo
m = self.mask
m = m * self.mask_prop
m += self.mask_as_k_inv_prop
tr *= m
tg *= m
tb *= m
pr *= m
pg *= m
pb *= m
y = tf.concat([tr, tg, tb], 3)
p = tf.concat([pr, pg, pb], 3)
# yo = tf.stack([tro,tgo,tbo],3)
# po = tf.stack([pro,pgo,pbo],3)
return self.loss_func(y, p)
# <<< END: from Dfaker >>> #
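

# --- Illustrative usage (not part of the original module) ----------------- #
# A minimal, hedged sketch of how PenalizedLoss is typically used: the mask
# arrives as an extra model input, and the wrapped loss then only "sees" the
# mask-weighted pixels. The mask shape and inner MAE loss are assumptions.
def _example_penalized_loss(mask_shape=(64, 64, 1)):
    """ Sketch: wrap a mean absolute error loss with a mask input """
    from keras.layers import Input
    mask = Input(shape=mask_shape)

    def mae(y_true, y_pred):
        return K.mean(K.abs(y_true - y_pred))

    # Add `mask` to the model's inputs and pass the returned object as the loss
    return mask, PenalizedLoss(mask, mae, mask_prop=1.0)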
# <<< START: from Shaoanlu GAN >>> #
def first_order(var_x, axis=1):
    """ First Order Function from Shaoanlu GAN """
img_nrows = var_x.shape[1]
img_ncols = var_x.shape[2]
if axis == 1:
return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :])
if axis == 2:
return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :])
return None
def calc_loss(pred, target, loss='l2'):
    """ Calculate Loss from Shaoanlu GAN """
    if loss.lower() == "l2":
        return K.mean(K.square(pred - target))
    if loss.lower() == "l1":
        return K.mean(K.abs(pred - target))
    if loss.lower() == "cross_entropy":
        return -K.mean(K.log(pred + K.epsilon()) * target +
                       K.log(1 - pred + K.epsilon()) * (1 - target))
    raise ValueError('Received an unknown loss type: {}.'.format(loss))
def cyclic_loss(net_g1, net_g2, real1):
""" Cyclic Loss Function from Shoanlu GAN """
fake2 = net_g2(real1)[-1] # fake2 ABGR
fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2) # fake2 BGR
cyclic1 = net_g1(fake2)[-1] # cyclic1 ABGR
cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1) # cyclic1 BGR
loss = calc_loss(cyclic1, real1, loss='l1')
return loss
def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
""" Adversarial Loss Function from Shoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
fake = alpha * fake_bgr + (1-alpha) * distorted
if gan_training == "mixup_LSGAN":
dist = Beta(0.2, 0.2)
lam = dist.sample()
mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
pred_fake = net_d(concatenate([fake, distorted]))
pred_mixup = net_d(mixup)
loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
mixup2 = lam * concatenate([real,
distorted]) + (1 - lam) * concatenate([fake_bgr,
distorted])
pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
pred_mixup2 = net_d(mixup2)
loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2")
elif gan_training == "relativistic_avg_LSGAN":
real_pred = net_d(concatenate([real, distorted]))
fake_pred = net_d(concatenate([fake, distorted]))
loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred)))/2
loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred)))/2
loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))
fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
K.ones_like(fake_pred2)))/2
loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
K.zeros_like(fake_pred2)))/2
loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
K.zeros_like(fake_pred2)))/2
loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
K.ones_like(fake_pred2)))/2
    else:
        raise ValueError(
            "Received an unknown GAN training method: {}".format(gan_training))
return loss_d, loss_g
def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights):
""" Reconstruction Loss Function from Shoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
loss_g = 0
loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1")
loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes*(fake_bgr - real)))
for out in model_outputs[:-1]:
out_size = out.get_shape().as_list()
resized_real = tf.image.resize_images(real, out_size[1:3])
loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1")
return loss_g
def edge_loss(real, fake_abgr, mask_eyes, **weights):
""" Edge Loss Function from Shoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
loss_g = 0
loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1),
first_order(real, axis=1), "l1")
loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2),
first_order(real, axis=2), "l1")
shape_mask_eyes = mask_eyes.get_shape().as_list()
resized_mask_eyes = tf.image.resize_images(mask_eyes,
[shape_mask_eyes[1]-1, shape_mask_eyes[2]-1])
loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
(first_order(fake_bgr, axis=1) -
first_order(real, axis=1))))
loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
(first_order(fake_bgr, axis=2) -
first_order(real, axis=2))))
return loss_g
def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights):
""" Perceptual Loss Function from Shoanlu GAN """
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
fake = alpha * fake_bgr + (1-alpha) * distorted
def preprocess_vggface(var_x):
var_x = (var_x + 1) / 2 * 255 # channel order: BGR
var_x -= [91.4953, 103.8827, 131.0912]
return var_x
real_sz224 = tf.image.resize_images(real, [224, 224])
real_sz224 = Lambda(preprocess_vggface)(real_sz224)
dist = Beta(0.2, 0.2)
    lam = dist.sample()  # use the mixup trick here to reduce the forward passes from 2 to 1.
mixup = lam*fake_bgr + (1-lam)*fake
fake_sz224 = tf.image.resize_images(mixup, [224, 224])
fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)
# Apply instance norm on VGG(ResNet) features
# From MUNIT https://github.com/NVlabs/MUNIT
loss_g = 0
def instnorm():
return InstanceNormalization()
loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7),
instnorm()(real_feat7), "l2")
loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28),
instnorm()(real_feat28), "l2")
loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55),
instnorm()(real_feat55), "l2")
loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112),
instnorm()(real_feat112), "l2")
return loss_g
# <<< END: from Shaoanlu GAN >>> #
def generalized_loss_function(y_true, y_pred, a=1.0, c=1.0/255.0):
    '''
    Generalized function used to return a large variety of mathematical loss functions.
    The primary benefit is a smooth, differentiable version of L1 loss.
    Barron, J. A More General Robust Loss Function
    https://arxiv.org/pdf/1701.03077.pdf
    Parameters:
        a: penalty factor. Larger numbers give larger weight to large deviations
        c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256)
    Return:
        a loss value from the results of function(y_pred - y_true)
    Example:
        a=1.0, x >> c, c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss
        a=1.999999 (the limit as a -> 2), c=1.0/255.0 will give L2 / RMSE loss
    '''
    x = y_pred - y_true
    loss = (K.abs(2.0 - a) / a) * (K.pow(K.pow(x / c, 2.0) / K.abs(2.0 - a) + 1.0, a / 2.0) - 1.0)
    return K.mean(loss, axis=-1) * c
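

# --- Illustrative check (not part of the original module) ----------------- #
# A small, hedged sketch of the parameter choices named in the docstring:
# a=1.0 behaves like a smooth L1/MAE, while a=1.999999 approaches L2. The
# random test tensors and their shapes are assumptions for illustration only.
def _example_generalized_loss():
    """ Sketch: evaluate the generalized loss at two `a` settings """
    base = numpy.random.rand(2, 8, 8, 3).astype("float32")
    y_true = K.constant(base)
    y_pred = K.constant(numpy.clip(base + 0.01, 0.0, 1.0))
    smooth_l1 = generalized_loss_function(y_true, y_pred, a=1.0)
    near_l2 = generalized_loss_function(y_true, y_pred, a=1.999999)
    return K.eval(smooth_l1), K.eval(near_l2)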
def staircase_loss(y_true, y_pred, a=16.0, c=1.0/255.0):
    ''' Staircase loss: a tanh-smoothed staircase approximation of the absolute
    error, quantized into steps of size c. `a` controls the sharpness of each step. '''
    h = c
    w = c
    x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0)
    loss = h * (K.tanh(a * ((x / w) - tf.floor(x / w) - 0.5)) /
                (2.0 * K.tanh(a / 2.0)) + 0.5 + tf.floor(x / w))
    loss += 1e-10
    return K.mean(loss, axis=-1)
def gradient_loss(y_true, y_pred):
    '''
    Calculates the first and second order gradient differences between pixels of an image
    in the x and y dimensions. These gradients are compared between the ground truth and
    the predicted image, and the difference is taken with a smooth L1 norm (approximately
    MAE, but differentiable at zero).
    When used as a loss, its minimization will result in predicted images approaching the
    same level of sharpness / blurriness as the ground truth.
    TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration,
    Chengwu Lu & Hua Huang, 2014 (http://downloads.hindawi.com/journals/mpe/2014/790547.pdf)
    Parameters:
        y_true: The ground truth frames
        y_pred: The predicted frames
    Return:
        The gradient difference loss.
    '''
    assert 4 == K.ndim(y_true)
    y_true.set_shape([None, 80, 80, 3])  # NB: input shape is hard-coded to 80x80x3
    y_pred.set_shape([None, 80, 80, 3])
    TV_weight = 1.0
    TV2_weight = 1.0
    loss = 0.0
def diff_x(X):
Xleft = X[:, :, 1, :] - X[:, :, 0, :]
Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2)
Xright = X[:, :, -1, :] - X[:, :, -2, :]
Xout = [Xleft] + Xinner + [Xright]
Xout = tf.stack(Xout,axis=2)
return Xout * 0.5
def diff_y(X):
Xtop = X[:, 1, :, :] - X[:, 0, :, :]
Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1)
Xbot = X[:, -1, :, :] - X[:, -2, :, :]
Xout = [Xtop] + Xinner + [Xbot]
Xout = tf.stack(Xout,axis=1)
return Xout * 0.5
def diff_xx(X):
Xleft = X[:, :, 1, :] + X[:, :, 0, :]
Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2)
Xright = X[:, :, -1, :] + X[:, :, -2, :]
Xout = [Xleft] + Xinner + [Xright]
Xout = tf.stack(Xout,axis=2)
return Xout - 2.0 * X
def diff_yy(X):
Xtop = X[:, 1, :, :] + X[:, 0, :, :]
Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1)
Xbot = X[:, -1, :, :] + X[:, -2, :, :]
Xout = [Xtop] + Xinner + [Xbot]
Xout = tf.stack(Xout,axis=1)
return Xout - 2.0 * X
def diff_xy(X):
#xout1
top_left = X[:, 1, 1, :]+X[:, 0, 0, :]
inner_left = tf.unstack(X[:, 2:, 1, :]+X[:, :-2, 0, :], axis=1)
bot_left = X[:, -1, 1, :]+X[:, -2, 0, :]
X_left = [top_left] + inner_left + [bot_left]
X_left = tf.stack(X_left, axis=1)
top_mid = X[:, 1, 2:, :]+X[:, 0, :-2, :]
mid_mid = tf.unstack(X[:, 2:, 2:, :]+X[:, :-2, :-2, :], axis=1)
bot_mid = X[:, -1, 2:, :]+X[:, -2, :-2, :]
X_mid = [top_mid] + mid_mid + [bot_mid]
X_mid = tf.stack(X_mid, axis=1)
top_right = X[:, 1, -1, :]+X[:, 0, -2, :]
inner_right = tf.unstack(X[:, 2:, -1, :]+X[:, :-2, -2, :], axis=1)
bot_right = X[:, -1, -1, :]+X[:, -2, -2, :]
X_right = [top_right] + inner_right + [bot_right]
X_right = tf.stack(X_right, axis=1)
X_mid = tf.unstack(X_mid, axis=2)
Xout1 = [X_left] + X_mid + [X_right]
Xout1 = tf.stack(Xout1, axis=2)
#Xout2
top_left = X[:, 0, 1, :]+X[:, 1, 0, :]
inner_left = tf.unstack(X[:, :-2, 1, :]+X[:, 2:, 0, :], axis=1)
bot_left = X[:, -2, 1, :]+X[:, -1, 0, :]
X_left = [top_left] + inner_left + [bot_left]
X_left = tf.stack(X_left, axis=1)
top_mid = X[:, 0, 2:, :]+X[:, 1, :-2, :]
mid_mid = tf.unstack(X[:, :-2, 2:, :]+X[:, 2:, :-2, :], axis=1)
bot_mid = X[:, -2, 2:, :]+X[:, -1, :-2, :]
X_mid = [top_mid] + mid_mid + [bot_mid]
X_mid = tf.stack(X_mid, axis=1)
top_right = X[:, 0, -1, :]+X[:, 1, -2, :]
inner_right = tf.unstack(X[:, :-2, -1, :]+X[:, 2:, -2, :], axis=1)
bot_right = X[:, -2, -1, :]+X[:, -1, -2, :]
X_right = [top_right] + inner_right + [bot_right]
X_right = tf.stack(X_right, axis=1)
X_mid = tf.unstack(X_mid, axis=2)
Xout2 = [X_left] + X_mid + [X_right]
Xout2 = tf.stack(Xout2, axis=2)
return (Xout1 - Xout2) * 0.25
    loss += TV_weight * (generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) +
                         generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999))
    loss += TV2_weight * (generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) +
                          generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) +
                          2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred), a=1.999999))
    return loss / (TV_weight + TV2_weight)
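

# --- Illustrative usage (not part of the original module) ----------------- #
# A minimal, hedged sketch for exercising gradient_loss. Note the function
# hard-codes its inputs to 80x80x3 via set_shape, so the toy tensors must
# match; the random data here is an assumption for illustration only.
def _example_gradient_loss():
    """ Sketch: evaluate gradient_loss on random 80x80x3 image batches """
    y_true = K.constant(numpy.random.rand(2, 80, 80, 3).astype("float32"))
    y_pred = K.constant(numpy.random.rand(2, 80, 80, 3).astype("float32"))
    return K.eval(gradient_loss(y_true, y_pred))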
def scharr_edges(image, magnitude):
'''
Returns a tensor holding modified Scharr edge maps.
Arguments:
image: Image tensor with shape [batch_size, h, w, d] and type float32.
The image(s) must be 2x2 or larger.
magnitude: Boolean to determine if the edge magnitude or edge direction is returned
Returns:
Tensor holding edge maps for each channel. Returns a tensor with shape
[batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]],
[dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter.
'''
# Define vertical and horizontal Scharr filters.
static_image_shape = image.get_shape()
image_shape = tf.shape(image)
    # modified 3x3 Scharr (kept for reference):
    # kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]],
    #            [[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]]
# 5x5 Scharr
kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0], [-1.0, -2.0, -6.0, -2.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 6.0, 2.0, 1.0], [1.0, 2.0, 3.0, 2.0, 1.0]],
[[-1.0, -1.0, 0.0, 1.0, 1.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-3.0, -6.0, 0.0, 6.0, 3.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-1.0, -1.0, 0.0, 1.0, 1.0]]]
num_kernels = len(kernels)
kernels = numpy.transpose(numpy.asarray(kernels), (1, 2, 0))
kernels = numpy.expand_dims(kernels, -2) / numpy.sum(numpy.abs(kernels))
kernels_tf = tf.constant(kernels, dtype=image.dtype)
kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters')
# Use depth-wise convolution to calculate edge maps per channel.
pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
padded = tf.pad(image, pad_sizes, mode='REFLECT')
# Output tensor has shape [batch_size, h, w, d * num_kernels].
strides = [1, 1, 1, 1]
output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID')
# Reshape to [batch_size, h, w, d, num_kernels].
shape = tf.concat([image_shape, [num_kernels]], 0)
output = tf.reshape(output, shape=shape)
output.set_shape(static_image_shape.concatenate([num_kernels]))
    if magnitude:  # magnitude of edges
        output = tf.sqrt(tf.reduce_sum(tf.square(output), axis=-1))
    else:  # direction of edges
        output = tf.atan(tf.div(output[:, :, :, :, 0], output[:, :, :, :, 1]))
    return output
def gmsd_loss(y_true, y_pred):
    '''
    Gradient Magnitude Similarity Deviation: an improved image quality metric
    over MS-SSIM, with a simpler calculation.
    http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm
    https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf
    '''
    true_edge_mag = scharr_edges(y_true, True)
    pred_edge_mag = scharr_edges(y_pred, True)
    c = 0.002
    upper = 2.0 * tf.multiply(true_edge_mag, pred_edge_mag) + c
    lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c
    GMS = tf.div(upper, lower)
    _mean, _var = tf.nn.moments(GMS, axes=[1, 2], keep_dims=True)
    GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1)  # single metric value per image in tensor [?, 1, 1]
    return K.tile(GMSD, [1, 64, 64])  # expand to [?, height, width] for Keras... modify to not be hard-coded
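

# --- Illustrative usage (not part of the original module) ----------------- #
# A minimal, hedged sketch of compiling with gmsd_loss. The 64x64 model size
# is an assumption chosen to match the hard-coded K.tile above; other output
# sizes would need that tile shape changed first.
def _example_gmsd_usage():
    """ Sketch: toy 64x64 model compiled with gmsd_loss """
    from keras.layers import Conv2D, Input
    from keras.models import Model
    inp = Input(shape=(64, 64, 3))
    out = Conv2D(3, 3, padding="same", activation="sigmoid")(inp)
    model = Model(inp, out)
    model.compile(optimizer="adam", loss=gmsd_loss)
    return model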
def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)):
'''
Computes the MS-SSIM between img1 and img2.
This function assumes that `img1` and `img2` are image batches, i.e. the last
three dimensions are [height, width, channels].
Note: The true SSIM is only defined on grayscale. This function does not
perform any colorspace transform. (If input is already YUV, then it will
compute YUV SSIM average.)
Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
structural similarity for image quality assessment." Signals, Systems and
Computers, 2004.
Arguments:
img1: First image batch.
img2: Second image batch. Must have the same rank as img1.
      max_val: The dynamic range of the images (i.e., the difference between the
        maximum and minimum allowed values).
      power_factors: Iterable of weights for each of the scales. The number of
        scales used is the length of the list. Index 0 is the unscaled
        resolution's weight and each increasing scale corresponds to the image
        being downsampled by 2. Defaults to (0.0517, 0.3295, 0.3462, 0.2726):
        the first four values obtained in the original paper (0.0448, 0.2856,
        0.3001, 0.2363, 0.1333), renormalized to sum to 1.
Returns:
A tensor containing an MS-SSIM value for each image in batch. The values
are in range [0, 1]. Returns a tensor with shape:
broadcast(img1.shape[:-3], img2.shape[:-3]).
'''
def _verify_compatible_image_shapes(img1, img2):
'''
Checks if two image tensors are compatible for applying SSIM or PSNR.
This function checks if two sets of images have ranks at least 3, and if the
last three dimensions match.
Args:
img1: Tensor containing the first image batch.
img2: Tensor containing the second image batch.
Returns:
A tuple containing: the first tensor shape, the second tensor shape, and a
list of control_flow_ops.Assert() ops implementing the checks.
Raises:
ValueError: When static shape check fails.
'''
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].assert_is_compatible_with(shape2[-3:])
if shape1.ndims is not None and shape2.ndims is not None:
for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])):
if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
raise ValueError('Two images are not compatible: %s and %s' % (shape1, shape2))
# Now assign shape tensors.
shape1, shape2 = tf.shape_n([img1, img2])
# TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable.
checks = []
checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),[shape1, shape2], summarize=10))
checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),[shape1, shape2], summarize=10))
return shape1, shape2, checks
def _ssim_per_channel(img1, img2, max_val=1.0):
'''
Computes SSIM index between img1 and img2 per color channel.
This function matches the standard SSIM implementation from:
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
quality assessment: from error visibility to structural similarity. IEEE
transactions on image processing.
Details:
- 11x11 Gaussian filter of width 1.5 is used.
- k1 = 0.01, k2 = 0.03 as in the original paper.
Args:
img1: First image batch.
img2: Second image batch.
          max_val: The dynamic range of the images (i.e., the difference between the
            maximum and minimum allowed values).
Returns:
          A pair of tensors containing the channel-wise SSIM and contrast-structure
          values. The shape is [..., channels].
'''
def _fspecial_gauss(size, sigma):
'''
Function to mimic the 'fspecial' gaussian MATLAB function.
'''
size = tf.convert_to_tensor(size, 'int32')
sigma = tf.convert_to_tensor(sigma)
coords = tf.cast(tf.range(size), sigma.dtype)
coords -= tf.cast(size - 1, sigma.dtype) / 2.0
g = tf.square(coords)
g *= -0.5 / tf.square(sigma)
g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax().
g = tf.nn.softmax(g)
return tf.reshape(g, shape=[size, size, 1, 1])
def _ssim_helper(x, y, max_val, kernel, compensation=1.0):
'''
Helper function for computing SSIM.
SSIM estimates covariances with weighted sums. The default parameters
use a biased estimate of the covariance:
Suppose `reducer` is a weighted sum, then the mean estimators are
\mu_x = \sum_i w_i x_i,
\mu_y = \sum_i w_i y_i,
where w_i's are the weighted-sum weights, and covariance estimator is
cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
with assumption \sum_i w_i = 1. This covariance estimator is biased, since
E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
For SSIM measure with unbiased covariance estimators, pass as `compensation`
argument (1 - \sum_i w_i ^ 2).
Arguments:
x: First set of images.
y: Second set of images.
            kernel: Weighted-sum kernel with which local averages are computed
              from the sets of images. (A non-convolutional version would
              usually use tf.reduce_mean(x, [1, 2]); this convolutional version
              uses tf.nn.depthwise_conv2d with the weighted-sum kernel.)
max_val: The dynamic range (i.e., the difference between the maximum
possible allowed value and the minimum allowed value).
compensation: Compensation factor. See above.
Returns:
A pair containing the luminance measure, and the contrast-structure measure.
'''
def reducer(x, kernel):
shape = tf.shape(x)
x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0))
y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
return tf.reshape(y, tf.concat([shape[:-3],tf.shape(y)[1:]], 0))
_SSIM_K1 = 0.01
_SSIM_K2 = 0.03
c1 = (_SSIM_K1 * max_val) ** 2
c2 = (_SSIM_K2 * max_val) ** 2
# SSIM luminance measure is
# (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
mean0 = reducer(x, kernel)
mean1 = reducer(y, kernel)
num0 = mean0 * mean1 * 2.0
den0 = tf.square(mean0) + tf.square(mean1)
luminance = (num0 + c1) / (den0 + c1)
# SSIM contrast-structure measure is
# (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
# Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then
# cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
# = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
num1 = reducer(x * y, kernel) * 2.0
den1 = reducer(tf.square(x) + tf.square(y), kernel)
c2 *= compensation
cs = (num1 - num0 + c2) / (den1 - den0 + c2)
# SSIM score is the product of the luminance and contrast-structure measures.
return luminance, cs
        filter_size = tf.constant(9, dtype='int32')  # reduced from 11 in the reference implementation
filter_sigma = tf.constant(1.5, dtype=img1.dtype)
shape1, shape2 = tf.shape_n([img1, img2])
checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),[shape1, filter_size], summarize=8),
tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),[shape2, filter_size], summarize=8)]
# Enforce the check to run before computation.
with tf.control_dependencies(checks):
img1 = tf.identity(img1)
# TODO(sjhwang): Try to cache kernels and compensation factor.
kernel = _fspecial_gauss(filter_size, filter_sigma)
kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1])
# The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
# but to match MATLAB implementation of MS-SSIM, we use 1.0 instead.
compensation = 1.0
# TODO(sjhwang): Try FFT.
# TODO(sjhwang): Gaussian kernel is separable in space. Consider applying
        # 1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.
luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation)
# Average over the second and the third from the last: height, width.
axes = tf.constant([-3, -2], dtype='int32')
ssim_val = tf.reduce_mean(luminance * cs, axes)
cs = tf.reduce_mean(cs, axes)
return ssim_val, cs
def do_pad(images, remainder):
padding = tf.expand_dims(remainder, -1)
padding = tf.pad(padding, [[1, 0], [1, 0]])
return [tf.pad(x, padding, mode='SYMMETRIC') for x in images]
# Shape checking.
shape1 = img1.get_shape().with_rank_at_least(3)
shape2 = img2.get_shape().with_rank_at_least(3)
shape1[-3:].merge_with(shape2[-3:])
with tf.name_scope(None, 'MS-SSIM', [img1, img2]):
shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
with tf.control_dependencies(checks):
img1 = tf.identity(img1)
# Need to convert the images to float32. Scale max_val accordingly so that
# SSIM is computed correctly.
max_val = tf.cast(max_val, img1.dtype)
max_val = tf.image.convert_image_dtype(max_val, 'float32')
img1 = tf.image.convert_image_dtype(img1, 'float32')
img2 = tf.image.convert_image_dtype(img2, 'float32')
imgs = [img1, img2]
shapes = [shape1, shape2]
# img1 and img2 are assumed to be a (multi-dimensional) batch of
# 3-dimensional images (height, width, channels). `heads` contain the batch
# dimensions, and `tails` contain the image dimensions.
heads = [s[:-3] for s in shapes]
tails = [s[-3:] for s in shapes]
divisor = [1, 2, 2, 1]
divisor_tensor = tf.constant(divisor[1:], dtype='int32')
mcs = []
for k in range(len(power_factors)):
with tf.name_scope(None, 'Scale%d' % k, imgs):
if k > 0:
# Avg pool takes rank 4 tensors. Flatten leading dimensions.
flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0)) for x, t in zip(imgs, tails)]
remainder = tails[0] % divisor_tensor
need_padding = tf.reduce_any(tf.not_equal(remainder, 0))
                padded = tf.cond(need_padding,
                                 lambda: do_pad(flat_imgs, remainder),
                                 lambda: flat_imgs)
downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor, padding='VALID')
for x in padded]
tails = [x[1:] for x in tf.shape_n(downscaled)]
imgs = [tf.reshape(x, tf.concat([h, t], 0)) for x, h, t in zip(downscaled, heads, tails)]
# Overwrite previous ssim value since we only need the last one.
ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val)
mcs.append(tf.nn.relu(cs))
        # Remove the cs score for the last scale: in the MS-SSIM calculation,
        # we use the l(p) at the highest scale; l(p) * cs(p) is ssim(p).
        mcs.pop()
        mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)], axis=-1)
        # Take the weighted geometric mean across the scale axis.
        ms_ssim_val = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors), [-1])
        return tf.reduce_mean(ms_ssim_val, [-1])  # average over color channels
def ms_ssim_loss(y_true, y_pred):
    ''' MS-SSIM as a Keras loss: 1.0 - MS-SSIM, tiled out to per-pixel shape '''
    MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred), axis=-1), axis=-1)
    return K.tile(MSSSIM, [1, 64, 64])  # expand to [?, height, width] for Keras... modify to not be hard-coded
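

# --- Illustrative check (not part of the original module) ----------------- #
# A minimal, hedged sketch evaluating ms_ssim directly. The 128x128 random
# images are an assumption chosen so that the smallest of the four scales
# (128 / 2**3 = 16 pixels) still fits the 9-pixel SSIM filter used above.
def _example_ms_ssim():
    """ Sketch: MS-SSIM of an image batch against a noisy copy of itself """
    imgs = numpy.random.rand(2, 128, 128, 3).astype("float32")
    noisy = numpy.clip(imgs + 0.05 * numpy.random.rand(2, 128, 128, 3),
                       0.0, 1.0).astype("float32")
    score = ms_ssim(K.constant(imgs), K.constant(noisy), max_val=1.0)
    return K.eval(score)  # shape [2]; values near 1.0 mean very similar images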