#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py

Losses from:
    keras.contrib
    dfaker: https://github.com/dfaker/df
    shaoanlu GAN: https://github.com/shaoanlu/faceswap-GAN """

from __future__ import absolute_import

import numpy

import keras.backend as K
import tensorflow as tf
from keras.layers import Lambda, concatenate
from tensorflow.contrib.distributions import Beta

from .normalization import InstanceNormalization


class DSSIMObjective():
    """ DSSIM Loss Function

    Code copied, with minor amendments, from:
    https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py

    MIT License

    Copyright (c) 2017 Fariz Rahman

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE. """
    # pylint: disable=C0103
    def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0):
        """
        Difference of Structural Similarity (DSSIM loss function). Clipped
        between 0 and 0.5

        Note : You should add a regularization term like an l2 loss in
        addition to this one.
        Note : In Theano, the `kernel_size` must be a factor of the output
        size. So 3 could not be the `kernel_size` for an output of 32.

        # Arguments
            k1: Parameter of the SSIM (default 0.01)
            k2: Parameter of the SSIM (default 0.03)
            kernel_size: Size of the sliding window (default 3)
            max_value: Max value of the output (default 1.0)
        """
        self.__name__ = 'DSSIMObjective'
        self.kernel_size = kernel_size
        self.k1 = k1
        self.k2 = k2
        self.max_value = max_value
        self.c1 = (self.k1 * self.max_value) ** 2
        self.c2 = (self.k2 * self.max_value) ** 2
        self.dim_ordering = K.image_data_format()
        self.backend = K.backend()

    @staticmethod
    def __int_shape(x):
        return K.int_shape(x)

    def __call__(self, y_true, y_pred):
        # There are additional parameters for this function
        # Note: some of the 'modes' for edge behavior do not yet have a
        # gradient definition in the Theano tree and cannot be used for
        # learning

        kernel = [self.kernel_size, self.kernel_size]
        y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:]))
        y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:]))

        patches_pred = self.extract_image_patches(y_pred,
                                                  kernel,
                                                  kernel,
                                                  'valid',
                                                  self.dim_ordering)
        patches_true = self.extract_image_patches(y_true,
                                                  kernel,
                                                  kernel,
                                                  'valid',
                                                  self.dim_ordering)

        # Reshape to get the var in the cells
        _, w, h, c1, c2, c3 = self.__int_shape(patches_pred)
        patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3])
        patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3])
        # Get mean
        u_true = K.mean(patches_true, axis=-1)
        u_pred = K.mean(patches_pred, axis=-1)
        # Get variance
        var_true = K.var(patches_true, axis=-1)
        var_pred = K.var(patches_pred, axis=-1)
        # Get covariance
        covar_true_pred = K.mean(
            patches_true * patches_pred, axis=-1) - u_true * u_pred

        ssim = (2 * u_true * u_pred + self.c1) * (
            2 * covar_true_pred + self.c2)
        denom = (K.square(u_true) + K.square(u_pred) + self.c1) * (
            var_pred + var_true + self.c2)
        ssim /= denom  # no need for clipping, c1 + c2 make the denom non-zero
        return K.mean((1.0 - ssim) / 2.0)

    @staticmethod
    def _preprocess_padding(padding):
        """Convert keras' padding to tensorflow's padding.

        # Arguments
            padding: string, `"same"` or `"valid"`.
        # Returns
            a string, `"SAME"` or `"VALID"`.
        # Raises
            ValueError: if `padding` is invalid.
        """
        if padding == 'same':
            padding = 'SAME'
        elif padding == 'valid':
            padding = 'VALID'
        else:
            raise ValueError('Invalid padding:', padding)
        return padding

    def extract_image_patches(self, x, ksizes, ssizes, padding='same',
                              data_format='channels_last'):
        """
        Extract the patches from an image
        # Parameters
            x : The input image
            ksizes : 2-d tuple with the kernel size
            ssizes : 2-d tuple with the strides size
            padding : 'same' or 'valid'
            data_format : 'channels_last' or 'channels_first'
        # Returns
            The (k_w, k_h) patches extracted
                TF ==> (batch_size, w, h, k_w, k_h, c)
                TH ==> (batch_size, w, h, c, k_w, k_h)
        """
        kernel = [1, ksizes[0], ksizes[1], 1]
        strides = [1, ssizes[0], ssizes[1], 1]
        padding = self._preprocess_padding(padding)
        if data_format == 'channels_first':
            x = K.permute_dimensions(x, (0, 2, 3, 1))
        _, _, _, ch_i = K.int_shape(x)
        patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
                                           padding)
        # Reshaping to fit Theano
        _, w, h, ch = K.int_shape(patches)
        patches = tf.reshape(tf.transpose(tf.reshape(patches,
                                                     [-1, w, h,
                                                      tf.floordiv(ch, ch_i),
                                                      ch_i]),
                                          [0, 1, 2, 4, 3]),
                             [-1, w, h, ch_i, ksizes[0], ksizes[1]])
        if data_format == 'channels_last':
            patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
        return patches
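

# Usage sketch (illustrative, not from the original source): a DSSIMObjective
# instance is a callable with the standard Keras (y_true, y_pred) signature,
# so it can be passed directly wherever Keras accepts a loss. `model` below
# stands in for any trainable Keras model.
def _example_compile_with_dssim(model):
    """ Compile a hypothetical Keras model with DSSIM loss """
    model.compile(optimizer="adam", loss=DSSIMObjective(kernel_size=3))
    return model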


# <<< START: from Dfaker >>> #
class PenalizedLoss():  # pylint: disable=too-few-public-methods
    """ Penalized Loss
    from: https://github.com/dfaker/df """
    def __init__(self, mask, loss_func, mask_prop=1.0):
        self.mask = mask
        self.loss_func = loss_func
        self.mask_prop = mask_prop
        self.mask_as_k_inv_prop = 1 - mask_prop

    def __call__(self, y_true, y_pred):
        # pylint: disable=invalid-name
        tro, tgo, tbo = tf.split(y_true, 3, 3)
        pro, pgo, pbo = tf.split(y_pred, 3, 3)

        tr = tro
        tg = tgo
        tb = tbo

        pr = pro
        pg = pgo
        pb = pbo
        m = self.mask

        m = m * self.mask_prop
        m += self.mask_as_k_inv_prop
        tr *= m
        tg *= m
        tb *= m

        pr *= m
        pg *= m
        pb *= m

        y = tf.concat([tr, tg, tb], 3)
        p = tf.concat([pr, pg, pb], 3)

        # yo = tf.stack([tro, tgo, tbo], 3)
        # po = tf.stack([pro, pgo, pbo], 3)

        return self.loss_func(y, p)
# <<< END: from Dfaker >>> #
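

# Usage sketch (illustrative, not from the original source): PenalizedLoss
# weights errors inside the face mask at full strength while scaling errors
# outside it by (1 - mask_prop), assuming a mask of ones inside the face and
# zeros outside. `model` and `mask` are hypothetical stand-ins for a trainer's
# Keras model and its mask input tensor.
def _example_compile_with_penalized_loss(model, mask):
    """ Wrap a base loss with the mask penalty for a hypothetical model """
    penalized = PenalizedLoss(mask, DSSIMObjective(), mask_prop=1.0)
    model.compile(optimizer="adam", loss=penalized)
    return model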


# <<< START: from Shaoanlu GAN >>> #
def first_order(var_x, axis=1):
    """ First Order Function from Shaoanlu GAN """
    img_nrows = var_x.shape[1]
    img_ncols = var_x.shape[2]
    if axis == 1:
        return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :])
    if axis == 2:
        return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :])
    return None


def calc_loss(pred, target, loss='l2'):
    """ Calculate Loss from Shaoanlu GAN """
    if loss.lower() == "l2":
        return K.mean(K.square(pred - target))
    if loss.lower() == "l1":
        return K.mean(K.abs(pred - target))
    if loss.lower() == "cross_entropy":
        return -K.mean(K.log(pred + K.epsilon()) * target +
                       K.log(1 - pred + K.epsilon()) * (1 - target))
    raise ValueError('Received an unknown loss type: {}.'.format(loss))


def cyclic_loss(net_g1, net_g2, real1):
    """ Cyclic Loss Function from Shaoanlu GAN """
    fake2 = net_g2(real1)[-1]  # fake2 ABGR
    fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2)  # fake2 BGR
    cyclic1 = net_g1(fake2)[-1]  # cyclic1 ABGR
    cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1)  # cyclic1 BGR
    loss = calc_loss(cyclic1, real1, loss='l1')
    return loss


def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights):
    """ Adversarial Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    if gan_training == "mixup_LSGAN":
        dist = Beta(0.2, 0.2)
        lam = dist.sample()
        mixup = (lam * concatenate([real, distorted]) +
                 (1 - lam) * concatenate([fake, distorted]))
        pred_fake = net_d(concatenate([fake, distorted]))
        pred_mixup = net_d(mixup)
        loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2")
        loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2")
        mixup2 = (lam * concatenate([real, distorted]) +
                  (1 - lam) * concatenate([fake_bgr, distorted]))
        pred_fake_bgr = net_d(concatenate([fake_bgr, distorted]))
        pred_mixup2 = net_d(mixup2)
        loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2")
        loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2")
    elif gan_training == "relativistic_avg_LSGAN":
        real_pred = net_d(concatenate([real, distorted]))
        fake_pred = net_d(concatenate([fake, distorted]))
        loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred))) / 2
        loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred))) / 2
        loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred)))

        fake_pred2 = net_d(concatenate([fake_bgr, distorted]))
        loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                  K.ones_like(fake_pred2))) / 2
        loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                  K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) -
                                                   K.zeros_like(fake_pred2))) / 2
        loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) -
                                                   K.ones_like(fake_pred2))) / 2
    else:
        raise ValueError("Received an unknown GAN training method: "
                         "{}".format(gan_training))
    return loss_d, loss_g


def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights):
    """ Reconstruction Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)

    loss_g = 0
    loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1")
    loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes * (fake_bgr - real)))

    for out in model_outputs[:-1]:
        out_size = out.get_shape().as_list()
        resized_real = tf.image.resize_images(real, out_size[1:3])
        loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1")
    return loss_g


def edge_loss(real, fake_abgr, mask_eyes, **weights):
    """ Edge Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)

    loss_g = 0
    loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1),
                                            first_order(real, axis=1), "l1")
    loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2),
                                            first_order(real, axis=2), "l1")
    shape_mask_eyes = mask_eyes.get_shape().as_list()
    resized_mask_eyes = tf.image.resize_images(mask_eyes,
                                               [shape_mask_eyes[1] - 1, shape_mask_eyes[2] - 1])
    loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
                                               (first_order(fake_bgr, axis=1) -
                                                first_order(real, axis=1))))
    loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes *
                                               (first_order(fake_bgr, axis=2) -
                                                first_order(real, axis=2))))
    return loss_g


def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights):
    """ Perceptual Loss Function from Shaoanlu GAN """
    alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr)
    fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr)
    fake = alpha * fake_bgr + (1 - alpha) * distorted

    def preprocess_vggface(var_x):
        var_x = (var_x + 1) / 2 * 255  # channel order: BGR
        var_x -= [91.4953, 103.8827, 131.0912]
        return var_x

    real_sz224 = tf.image.resize_images(real, [224, 224])
    real_sz224 = Lambda(preprocess_vggface)(real_sz224)
    dist = Beta(0.2, 0.2)
    lam = dist.sample()  # use the mixup trick here to reduce the forward pass from 2 times to 1
    mixup = lam * fake_bgr + (1 - lam) * fake
    fake_sz224 = tf.image.resize_images(mixup, [224, 224])
    fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
    real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224)
    fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224)

    # Apply instance norm on VGG(ResNet) features
    # From MUNIT https://github.com/NVlabs/MUNIT
    loss_g = 0

    def instnorm():
        return InstanceNormalization()

    loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7),
                                             instnorm()(real_feat7), "l2")
    loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28),
                                             instnorm()(real_feat28), "l2")
    loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55),
                                             instnorm()(real_feat55), "l2")
    loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112),
                                             instnorm()(real_feat112), "l2")
    return loss_g
# <<< END: from Shaoanlu GAN >>> #
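

# Usage sketch (illustrative, not from the original source): the GAN losses
# above return graph tensors rather than Keras loss callables, so they are
# summed when the training graph is built. Every argument here is a
# hypothetical stand-in (`net_d` a discriminator, image tensors for the rest),
# and the numeric weights are placeholders matching the **weights keys read
# by each function.
def _example_gan_losses(net_d, real, fake_abgr, distorted, mask_eyes, model_outputs):
    """ Build combined discriminator / generator losses for a hypothetical GAN """
    loss_d, loss_g = adversarial_loss(net_d, real, fake_abgr, distorted,
                                      gan_training="mixup_LSGAN", w_D=0.1)
    loss_g += reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs,
                                  w_recon=1.0, w_eyes=30.0)
    loss_g += edge_loss(real, fake_abgr, mask_eyes, w_edge=0.1, w_eyes=30.0)
    return loss_d, loss_g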


def generalized_loss_function(y_true, y_pred, a=1.0, c=1.0/255.0):
    """
    Generalized function used to return a large variety of mathematical loss functions.
    The primary benefit is a smooth, differentiable version of L1 loss.

    Barron, J. A More General Robust Loss Function
    https://arxiv.org/pdf/1701.03077.pdf

    Parameters:
        a: penalty factor. Larger numbers give larger weight to large deviations
        c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256)

    Return:
        A loss value from the results of function(y_pred - y_true)

    Example:
        a=1.0, x>>c, c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss
        a=1.999999 (lim as a->2), c=1.0/255.0 will give L2 / RMSE loss
    """
    x = y_pred - y_true
    loss = (K.abs(2.0 - a) / a) * (K.pow(K.pow(x / c, 2.0) / K.abs(2.0 - a) + 1.0, (a / 2.0)) - 1.0)
    return K.mean(loss, axis=-1) * c
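

# Quick numerical sketch (illustrative, not from the original source): with
# `a` close to 2 the generalized loss collapses to a scaled L2 penalty, while
# a=1.0 behaves like a smooth L1, per the docstring above. Requires a live
# TF1-style session via the Keras backend; values are toy data.
def _example_generalized_loss_values():
    """ Evaluate the generalized loss on toy tensors for both regimes """
    y_true = K.constant([[0.50, 0.50]])
    y_pred = K.constant([[0.50, 0.60]])
    smooth_l1 = K.eval(generalized_loss_function(y_true, y_pred, a=1.0))
    near_l2 = K.eval(generalized_loss_function(y_true, y_pred, a=1.999999))
    return smooth_l1, near_l2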


def staircase_loss(y_true, y_pred, a=16.0, c=1.0/255.0):
    """ Smoothly differentiable staircase approximation of absolute error:
    steps of width and height c, with step steepness controlled by a """
    h = c
    w = c
    x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0)
    loss = h * (K.tanh(a * ((x / w) - tf.floor(x / w) - 0.5)) /
                (2.0 * K.tanh(a / 2.0)) + 0.5 + tf.floor(x / w))
    loss += 1e-10
    return K.mean(loss, axis=-1)


def gradient_loss(y_true, y_pred):
    """
    Calculates the first and second order gradient differences between pixels of an image
    in the x and y dimensions. These gradients are compared between the ground truth and
    the predicted image, and the difference is taken. The difference used is a smooth L1
    norm (approximates MAE but is differentiable at zero). When used as a loss, its
    minimization will result in predicted images approaching the same level of sharpness
    / blurriness as the ground truth.

    TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration,
    Chengwu Lu & Hua Huang, 2014 (http://downloads.hindawi.com/journals/mpe/2014/790547.pdf)

    Parameters:
        y_true: The ground truth frames at each scale
        y_pred: The predicted frames at each scale

    Return:
        The GD loss.
    """

    assert K.ndim(y_true) == 4
    # NB: shapes are hard-coded to 80x80x3 images
    y_true.set_shape([None, 80, 80, 3])
    y_pred.set_shape([None, 80, 80, 3])
    TV_weight = 1.0
    TV2_weight = 1.0
    loss = 0.0

    def diff_x(X):
        Xleft = X[:, :, 1, :] - X[:, :, 0, :]
        Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2)
        Xright = X[:, :, -1, :] - X[:, :, -2, :]
        Xout = [Xleft] + Xinner + [Xright]
        Xout = tf.stack(Xout, axis=2)
        return Xout * 0.5

    def diff_y(X):
        Xtop = X[:, 1, :, :] - X[:, 0, :, :]
        Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1)
        Xbot = X[:, -1, :, :] - X[:, -2, :, :]
        Xout = [Xtop] + Xinner + [Xbot]
        Xout = tf.stack(Xout, axis=1)
        return Xout * 0.5

    def diff_xx(X):
        Xleft = X[:, :, 1, :] + X[:, :, 0, :]
        Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2)
        Xright = X[:, :, -1, :] + X[:, :, -2, :]
        Xout = [Xleft] + Xinner + [Xright]
        Xout = tf.stack(Xout, axis=2)
        return Xout - 2.0 * X

    def diff_yy(X):
        Xtop = X[:, 1, :, :] + X[:, 0, :, :]
        Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1)
        Xbot = X[:, -1, :, :] + X[:, -2, :, :]
        Xout = [Xtop] + Xinner + [Xbot]
        Xout = tf.stack(Xout, axis=1)
        return Xout - 2.0 * X

    def diff_xy(X):
        # Xout1
        top_left = X[:, 1, 1, :] + X[:, 0, 0, :]
        inner_left = tf.unstack(X[:, 2:, 1, :] + X[:, :-2, 0, :], axis=1)
        bot_left = X[:, -1, 1, :] + X[:, -2, 0, :]
        X_left = [top_left] + inner_left + [bot_left]
        X_left = tf.stack(X_left, axis=1)

        top_mid = X[:, 1, 2:, :] + X[:, 0, :-2, :]
        mid_mid = tf.unstack(X[:, 2:, 2:, :] + X[:, :-2, :-2, :], axis=1)
        bot_mid = X[:, -1, 2:, :] + X[:, -2, :-2, :]
        X_mid = [top_mid] + mid_mid + [bot_mid]
        X_mid = tf.stack(X_mid, axis=1)

        top_right = X[:, 1, -1, :] + X[:, 0, -2, :]
        inner_right = tf.unstack(X[:, 2:, -1, :] + X[:, :-2, -2, :], axis=1)
        bot_right = X[:, -1, -1, :] + X[:, -2, -2, :]
        X_right = [top_right] + inner_right + [bot_right]
        X_right = tf.stack(X_right, axis=1)

        X_mid = tf.unstack(X_mid, axis=2)
        Xout1 = [X_left] + X_mid + [X_right]
        Xout1 = tf.stack(Xout1, axis=2)

        # Xout2
        top_left = X[:, 0, 1, :] + X[:, 1, 0, :]
        inner_left = tf.unstack(X[:, :-2, 1, :] + X[:, 2:, 0, :], axis=1)
        bot_left = X[:, -2, 1, :] + X[:, -1, 0, :]
        X_left = [top_left] + inner_left + [bot_left]
        X_left = tf.stack(X_left, axis=1)

        top_mid = X[:, 0, 2:, :] + X[:, 1, :-2, :]
        mid_mid = tf.unstack(X[:, :-2, 2:, :] + X[:, 2:, :-2, :], axis=1)
        bot_mid = X[:, -2, 2:, :] + X[:, -1, :-2, :]
        X_mid = [top_mid] + mid_mid + [bot_mid]
        X_mid = tf.stack(X_mid, axis=1)

        top_right = X[:, 0, -1, :] + X[:, 1, -2, :]
        inner_right = tf.unstack(X[:, :-2, -1, :] + X[:, 2:, -2, :], axis=1)
        bot_right = X[:, -2, -1, :] + X[:, -1, -2, :]
        X_right = [top_right] + inner_right + [bot_right]
        X_right = tf.stack(X_right, axis=1)

        X_mid = tf.unstack(X_mid, axis=2)
        Xout2 = [X_left] + X_mid + [X_right]
        Xout2 = tf.stack(Xout2, axis=2)

        return (Xout1 - Xout2) * 0.25

    loss += TV_weight * (generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) +
                         generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999))
    loss += TV2_weight * (generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) +
                          generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) +
                          2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred),
                                                          a=1.999999))

    return loss / (TV_weight + TV2_weight)


def scharr_edges(image, magnitude):
    """
    Returns a tensor holding modified Scharr edge maps.
    Arguments:
        image: Image tensor with shape [batch_size, h, w, d] and type float32.
            The image(s) must be 2x2 or larger.
        magnitude: Boolean to determine if the edge magnitude or edge direction is returned
    Returns:
        Tensor holding edge maps for each channel. Returns a tensor with shape
        [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]],
        [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter.
    """

    # Define vertical and horizontal Scharr filters.
    static_image_shape = image.get_shape()
    image_shape = tf.shape(image)

    # modified 3x3 Scharr
    # kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]],
    #            [[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]]

    # 5x5 Scharr
    kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0],
                [-1.0, -2.0, -6.0, -2.0, -1.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 2.0, 6.0, 2.0, 1.0],
                [1.0, 2.0, 3.0, 2.0, 1.0]],
               [[-1.0, -1.0, 0.0, 1.0, 1.0],
                [-2.0, -2.0, 0.0, 2.0, 2.0],
                [-3.0, -6.0, 0.0, 6.0, 3.0],
                [-2.0, -2.0, 0.0, 2.0, 2.0],
                [-1.0, -1.0, 0.0, 1.0, 1.0]]]
    num_kernels = len(kernels)
    kernels = numpy.transpose(numpy.asarray(kernels), (1, 2, 0))
    kernels = numpy.expand_dims(kernels, -2) / numpy.sum(numpy.abs(kernels))
    kernels_tf = tf.constant(kernels, dtype=image.dtype)
    kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters')

    # Use depth-wise convolution to calculate edge maps per channel.
    pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]]
    padded = tf.pad(image, pad_sizes, mode='REFLECT')

    # Output tensor has shape [batch_size, h, w, d * num_kernels].
    strides = [1, 1, 1, 1]
    output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID')

    # Reshape to [batch_size, h, w, d, num_kernels].
    shape = tf.concat([image_shape, [num_kernels]], 0)
    output = tf.reshape(output, shape=shape)
    output.set_shape(static_image_shape.concatenate([num_kernels]))

    if magnitude:  # magnitude of edges
        output = tf.sqrt(tf.reduce_sum(tf.square(output), axis=-1))
    else:  # direction of edges
        output = tf.atan(tf.squeeze(tf.div(output[:, :, :, :, 0],
                                           output[:, :, :, :, 1])))

    return output


def gmsd_loss(y_true, y_pred):
    """
    Improved image quality metric over MS-SSIM with a simpler calculation
    http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm
    https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf
    """
    true_edge_mag = scharr_edges(y_true, True)
    pred_edge_mag = scharr_edges(y_pred, True)
    c = 0.002
    upper = 2.0 * tf.multiply(true_edge_mag, pred_edge_mag) + c
    lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c
    GMS = tf.div(upper, lower)
    _mean, _var = tf.nn.moments(GMS, axes=[1, 2], keep_dims=True)
    GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1)  # single metric value per image in tensor [?, 1, 1]
    # need to expand to [?, height, width] dimensions for Keras ... modify to not be hard-coded
    return K.tile(GMSD, [1, 64, 64])
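

# Usage sketch (illustrative, not from the original source): gmsd_loss follows
# the Keras (y_true, y_pred) signature, so it can be passed like any loss
# function. Because the return value is tiled to 64x64 above, the model output
# is assumed to be 64px square here.
def _example_compile_with_gmsd(model):
    """ Compile a hypothetical 64px-output Keras model with GMSD loss """
    model.compile(optimizer="adam", loss=gmsd_loss)
    return model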


def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)):
    """
    Computes the MS-SSIM between img1 and img2.
    This function assumes that `img1` and `img2` are image batches, i.e. the last
    three dimensions are [height, width, channels].
    Note: The true SSIM is only defined on grayscale. This function does not
    perform any colorspace transform. (If input is already YUV, then it will
    compute YUV SSIM average.)
    Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
    structural similarity for image quality assessment." Signals, Systems and
    Computers, 2004.
    Arguments:
        img1: First image batch.
        img2: Second image batch. Must have the same rank as img1.
        max_val: The dynamic range of the images (i.e., the difference between the
            maximum and the minimum allowed values).
        power_factors: Iterable of weights for each of the scales. The number of
            scales used is the length of the list. Index 0 is the unscaled
            resolution's weight and each increasing scale corresponds to the image
            being downsampled by 2. Defaults to (0.0517, 0.3295, 0.3462, 0.2726):
            the first four of the original paper's values (0.0448, 0.2856, 0.3001,
            0.2363, 0.1333), renormalized to sum to one.
    Returns:
        A tensor containing an MS-SSIM value for each image in batch. The values
        are in range [0, 1]. Returns a tensor with shape:
        broadcast(img1.shape[:-3], img2.shape[:-3]).
    """

    def _verify_compatible_image_shapes(img1, img2):
        """
        Checks if two image tensors are compatible for applying SSIM or PSNR.
        This function checks if two sets of images have ranks at least 3, and if the
        last three dimensions match.
        Args:
            img1: Tensor containing the first image batch.
            img2: Tensor containing the second image batch.
        Returns:
            A tuple containing: the first tensor shape, the second tensor shape, and a
            list of control_flow_ops.Assert() ops implementing the checks.
        Raises:
            ValueError: When static shape check fails.
        """
        shape1 = img1.get_shape().with_rank_at_least(3)
        shape2 = img2.get_shape().with_rank_at_least(3)
        shape1[-3:].assert_is_compatible_with(shape2[-3:])

        if shape1.ndims is not None and shape2.ndims is not None:
            for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])):
                if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
                    raise ValueError('Two images are not compatible: %s and %s'
                                     % (shape1, shape2))

        # Now assign shape tensors.
        shape1, shape2 = tf.shape_n([img1, img2])

        # TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable.
        checks = []
        checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),
                                [shape1, shape2], summarize=10))
        checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),
                                [shape1, shape2], summarize=10))

        return shape1, shape2, checks

    def _ssim_per_channel(img1, img2, max_val=1.0):
        """
        Computes SSIM index between img1 and img2 per color channel.
        This function matches the standard SSIM implementation from:
        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
        quality assessment: from error visibility to structural similarity. IEEE
        transactions on image processing.
        Details:
            - 11x11 Gaussian filter of width 1.5 is standard (reduced to 9x9 here).
            - k1 = 0.01, k2 = 0.03 as in the original paper.
        Args:
            img1: First image batch.
            img2: Second image batch.
            max_val: The dynamic range of the images (i.e., the difference between the
                maximum and the minimum allowed values).
        Returns:
            A pair of tensors containing the channel-wise SSIM and contrast-structure
            values. The shape is [..., channels].
        """

        def _fspecial_gauss(size, sigma):
            """
            Function to mimic the 'fspecial' gaussian MATLAB function.
            """
            size = tf.convert_to_tensor(size, 'int32')
            sigma = tf.convert_to_tensor(sigma)

            coords = tf.cast(tf.range(size), sigma.dtype)
            coords -= tf.cast(size - 1, sigma.dtype) / 2.0

            g = tf.square(coords)
            g *= -0.5 / tf.square(sigma)

            g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
            g = tf.reshape(g, shape=[1, -1])  # For tf.nn.softmax().
            g = tf.nn.softmax(g)
            return tf.reshape(g, shape=[size, size, 1, 1])

        def _ssim_helper(x, y, max_val, kernel, compensation=1.0):
            r"""
            Helper function for computing SSIM.
            SSIM estimates covariances with weighted sums. The default parameters
            use a biased estimate of the covariance:
            Suppose `reducer` is a weighted sum, then the mean estimators are
                \mu_x = \sum_i w_i x_i,
                \mu_y = \sum_i w_i y_i,
            where the w_i's are the weighted-sum weights, and the covariance
            estimator is
                cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
            with the assumption \sum_i w_i = 1. This covariance estimator is biased,
            since
                E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
            For an SSIM measure with unbiased covariance estimators, pass
            (1 - \sum_i w_i ^ 2) as the `compensation` argument.
            Arguments:
                x: First set of images.
                y: Second set of images.
                max_val: The dynamic range (i.e., the difference between the maximum
                    possible allowed value and the minimum allowed value).
                kernel: Weighted-sum kernel used by the internal `reducer` to compute
                    'local' averages from the sets of images via depthwise
                    convolution. (A non-convolutional version would usually use
                    tf.reduce_mean(x, [1, 2]) instead.)
                compensation: Compensation factor. See above.
            Returns:
                A pair containing the luminance measure, and the contrast-structure
                measure.
            """

            def reducer(x, kernel):
                shape = tf.shape(x)
                x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0))
                y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
                                           padding='VALID')
                return tf.reshape(y, tf.concat([shape[:-3], tf.shape(y)[1:]], 0))

            _SSIM_K1 = 0.01
            _SSIM_K2 = 0.03

            c1 = (_SSIM_K1 * max_val) ** 2
            c2 = (_SSIM_K2 * max_val) ** 2

            # SSIM luminance measure is
            # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
            mean0 = reducer(x, kernel)
            mean1 = reducer(y, kernel)
            num0 = mean0 * mean1 * 2.0
            den0 = tf.square(mean0) + tf.square(mean1)
            luminance = (num0 + c1) / (den0 + c1)

            # SSIM contrast-structure measure is
            # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
            # Note that `reducer` is a weighted sum with weights w_i, \sum_i w_i = 1,
            # then
            #   cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
            #            = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
            num1 = reducer(x * y, kernel) * 2.0
            den1 = reducer(tf.square(x) + tf.square(y), kernel)
            c2 *= compensation
            cs = (num1 - num0 + c2) / (den1 - den0 + c2)

            # SSIM score is the product of the luminance and contrast-structure
            # measures.
            return luminance, cs

        filter_size = tf.constant(9, dtype='int32')  # changed from the standard 11 to 9
        filter_sigma = tf.constant(1.5, dtype=img1.dtype)

        shape1, shape2 = tf.shape_n([img1, img2])
        checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),
                            [shape1, filter_size], summarize=8),
                  tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),
                            [shape2, filter_size], summarize=8)]

        # Enforce the checks to run before the computation.
        with tf.control_dependencies(checks):
            img1 = tf.identity(img1)

        # TODO(sjhwang): Try to cache kernels and compensation factor.
        kernel = _fspecial_gauss(filter_size, filter_sigma)
        kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1])

        # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
        # but to match the MATLAB implementation of MS-SSIM, we use 1.0 instead.
        compensation = 1.0

        # TODO(sjhwang): Try FFT.
        # TODO(sjhwang): The Gaussian kernel is separable in space. Consider
        #   applying 1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.

        luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation)

        # Average over the second and the third from the last: height, width.
        axes = tf.constant([-3, -2], dtype='int32')
        ssim_val = tf.reduce_mean(luminance * cs, axes)
        cs = tf.reduce_mean(cs, axes)
        return ssim_val, cs

    def do_pad(images, remainder):
        padding = tf.expand_dims(remainder, -1)
        padding = tf.pad(padding, [[1, 0], [1, 0]])
        return [tf.pad(x, padding, mode='SYMMETRIC') for x in images]

    # Shape checking.
    shape1 = img1.get_shape().with_rank_at_least(3)
    shape2 = img2.get_shape().with_rank_at_least(3)
    shape1[-3:].merge_with(shape2[-3:])

    with tf.name_scope(None, 'MS-SSIM', [img1, img2]):
        shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
        with tf.control_dependencies(checks):
            img1 = tf.identity(img1)

        # Need to convert the images to float32. Scale max_val accordingly so that
        # SSIM is computed correctly.
        max_val = tf.cast(max_val, img1.dtype)
        max_val = tf.image.convert_image_dtype(max_val, 'float32')
        img1 = tf.image.convert_image_dtype(img1, 'float32')
        img2 = tf.image.convert_image_dtype(img2, 'float32')

        imgs = [img1, img2]
        shapes = [shape1, shape2]

        # img1 and img2 are assumed to be a (multi-dimensional) batch of
        # 3-dimensional images (height, width, channels). `heads` contain the
        # batch dimensions, and `tails` contain the image dimensions.
        heads = [s[:-3] for s in shapes]
        tails = [s[-3:] for s in shapes]

        divisor = [1, 2, 2, 1]
        divisor_tensor = tf.constant(divisor[1:], dtype='int32')

        mcs = []
        for k in range(len(power_factors)):
            with tf.name_scope(None, 'Scale%d' % k, imgs):
                if k > 0:
                    # Avg pool takes rank 4 tensors. Flatten leading dimensions.
                    flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0))
                                 for x, t in zip(imgs, tails)]

                    remainder = tails[0] % divisor_tensor
                    need_padding = tf.reduce_any(tf.not_equal(remainder, 0))
                    padded = tf.cond(need_padding,
                                     lambda: do_pad(flat_imgs, remainder),
                                     lambda: flat_imgs)

                    downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor,
                                                 padding='VALID')
                                  for x in padded]
                    tails = [x[1:] for x in tf.shape_n(downscaled)]
                    imgs = [tf.reshape(x, tf.concat([h, t], 0))
                            for x, h, t in zip(downscaled, heads, tails)]

                # Overwrite previous ssim value since we only need the last one.
                ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val)
                mcs.append(tf.nn.relu(cs))

        # Remove the cs score for the last scale. In the MS-SSIM calculation,
        # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p).
        mcs.pop()  # Remove the cs score for the last scale.
        mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)], axis=-1)
        # Take the weighted geometric mean across the scale axis.
        ms_ssim = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors), [-1])

    return tf.reduce_mean(ms_ssim, [-1])  # Average over color channels.


def ms_ssim_loss(y_true, y_pred):
    """ MS-SSIM as a Keras loss: 1.0 - ms_ssim, tiled out to image dimensions """
    MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred), axis=-1), axis=-1)
    # need to expand to [?, height, width] dimensions for Keras ... modify to not be hard-coded
    return K.tile(MSSSIM, [1, 64, 64])
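

# Usage sketch (illustrative, not from the original source): as with gmsd_loss,
# ms_ssim_loss is tiled to 64x64, so the model output is assumed to be 64px
# square. Losses can also be blended, e.g. weighting MS-SSIM against L1; the
# 0.8 / 0.2 split below is an arbitrary example.
def _example_compile_with_ms_ssim(model):
    """ Compile a hypothetical 64px-output Keras model with an MS-SSIM mix """
    def mixed_loss(y_true, y_pred):
        return (0.8 * ms_ssim_loss(y_true, y_pred) +
                0.2 * K.mean(K.abs(y_true - y_pred), axis=-1))
    model.compile(optimizer="adam", loss=mixed_loss)
    return model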