

#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class
When inheriting model_data should be a list of NNMeta objects.
See the class for details.
"""
import logging
import os
import sys
import time
from concurrent import futures
from json import JSONDecodeError
import keras
from keras import losses
from keras import backend as K
from keras.layers import Input
from keras.models import load_model, Model
from keras.utils import get_custom_objects, multi_gpu_model
from lib import Serializer
from lib.model.backup_restore import Backup
from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask_loss_wrapper,
generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
from lib.model.nn_blocks import NNBlocks
from lib.model.optimizers import Adam
from lib.utils import deprecation_warning, FaceswapError
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG = None
class ModelBase():
""" Base class that all models should inherit from """
def __init__(self,
model_dir,
gpus=1,
configfile=None,
snapshot_interval=0,
no_logs=False,
warp_to_landmarks=False,
augment_color=True,
no_flip=False,
training_image_size=256,
alignments_paths=None,
preview_scale=100,
input_shape=None,
encoder_dim=None,
trainer="original",
pingpong=False,
memory_saving_gradients=False,
optimizer_savings=False,
predict=False):
logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
"snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
"%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, "
"preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
"pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, "
"predict: %s)",
self.__class__.__name__, model_dir, gpus, configfile, snapshot_interval,
no_logs, warp_to_landmarks, augment_color, no_flip, training_image_size,
alignments_paths, preview_scale, input_shape, encoder_dim, trainer, pingpong,
memory_saving_gradients, optimizer_savings, predict)
self.predict = predict
self.model_dir = model_dir
self.vram_savings = VRAMSavings(pingpong, optimizer_savings, memory_saving_gradients)
self.backup = Backup(self.model_dir, self.name)
self.gpus = gpus
self.configfile = configfile
self.input_shape = input_shape
self.encoder_dim = encoder_dim
self.trainer = trainer
self.load_config() # Load config if plugin has not already referenced it
self.state = State(self.model_dir,
self.name,
self.config_changeable_items,
no_logs,
self.vram_savings.pingpong,
training_image_size)
self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"],
use_icnr_init=self.config["icnr_init"],
use_convaware_init=self.config["conv_aware_init"],
use_reflect_padding=self.config["reflect_padding"],
first_run=self.state.first_run)
self.is_legacy = False
self.rename_legacy()
self.load_state_info()
self.networks = dict() # Networks for the model
self.predictors = dict() # Predictors for model
self.history = dict()  # Loss history per save iteration
# Training information specific to the model should be placed in this
# dict for reference by the trainer.
self.training_opts = {"alignments": alignments_paths,
"preview_scaling": preview_scale / 100,
"warp_to_landmarks": warp_to_landmarks,
"augment_color": augment_color,
"no_flip": no_flip,
"pingpong": self.vram_savings.pingpong,
"snapshot_interval": snapshot_interval}
if self.multiple_models_in_folder:
deprecation_warning("Support for multiple model types within the same folder",
additional_info="Please split each model into separate folders to "
"avoid issues in future.")
self.build()
self.set_training_data()
logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
@property
def config_section(self):
""" The section name for loading config """
retval = ".".join(self.__module__.split(".")[-2:])
logger.debug(retval)
return retval
@property
def config(self):
""" Return config dict for current plugin """
global _CONFIG # pylint: disable=global-statement
if not _CONFIG:
model_name = self.config_section
logger.debug("Loading config for: %s", model_name)
_CONFIG = Config(model_name, configfile=self.configfile).config_dict
return _CONFIG
@property
def config_changeable_items(self):
""" Return the dict of config items that can be updated after the model
has been created """
return Config(self.config_section, configfile=self.configfile).changeable_items
@property
def name(self):
""" Set the model name based on the subclass """
basename = os.path.basename(sys.modules[self.__module__].__file__)
retval = os.path.splitext(basename)[0].lower()
logger.debug("model name: '%s'", retval)
return retval
@property
def models_exist(self):
""" Return if all files exist and clear session """
retval = all([os.path.isfile(model.filename) for model in self.networks.values()])
logger.debug("Pre-existing models exist: %s", retval)
return retval
@property
def multiple_models_in_folder(self):
""" Return true if there are multiple model types in the same folder, else false """
model_files = [fname for fname in os.listdir(str(self.model_dir)) if fname.endswith(".h5")]
retval = False if not model_files else os.path.commonprefix(model_files) == ""
logger.debug("model_files: %s, retval: %s", model_files, retval)
return retval
@property
def output_shapes(self):
""" Return the output shapes from the main AutoEncoder """
out = list()
for predictor in self.predictors.values():
out.extend([K.int_shape(output)[-3:] for output in predictor.outputs])
break # Only get output from one autoencoder. Shapes are the same
return [tuple(shape) for shape in out]
@property
def output_shape(self):
""" The output shape of the model (shape of largest face output) """
return self.output_shapes[self.largest_face_index]
@property
def largest_face_index(self):
""" Return the index from model.outputs of the largest face
Required for multi-output model prediction. The largest face
is assumed to be the final output
"""
sizes = [shape[1] for shape in self.output_shapes if shape[2] == 3]
if not sizes:
return None
max_face = max(sizes)
retval = [idx for idx, shape in enumerate(self.output_shapes)
if shape[1] == max_face and shape[2] == 3][0]
logger.debug(retval)
return retval
@property
def largest_mask_index(self):
""" Return the index from model.outputs of the largest mask
Required for multi-output model prediction. The largest face
is assumed to be the final output
"""
sizes = [shape[1] for shape in self.output_shapes if shape[2] == 1]
if not sizes:
return None
max_mask = max(sizes)
retval = [idx for idx, shape in enumerate(self.output_shapes)
if shape[1] == max_mask and shape[2] == 1][0]
logger.debug(retval)
return retval
def load_config(self):
""" Load the global config for reference in self.config """
global _CONFIG # pylint: disable=global-statement
if not _CONFIG:
model_name = self.config_section
logger.debug("Loading config for: %s", model_name)
_CONFIG = Config(model_name, configfile=self.configfile).config_dict
def set_training_data(self):
""" Override to set model specific training data.
super() this method for defaults otherwise be sure to add """
logger.debug("Setting training data")
self.training_opts["training_size"] = self.state.training_size
self.training_opts["no_logs"] = self.state.current_session["no_logs"]
self.training_opts["mask_type"] = self.config.get("mask_type", None)
self.training_opts["coverage_ratio"] = self.calculate_coverage_ratio()
logger.debug("Set training data: %s", self.training_opts)
def calculate_coverage_ratio(self):
""" Coverage must be a ratio, leading to a cropped shape divisible by 2 """
coverage_ratio = self.config.get("coverage", 62.5) / 100
logger.debug("Requested coverage_ratio: %s", coverage_ratio)
cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2
coverage_ratio = cropped_size / self.state.training_size
logger.debug("Final coverage_ratio: %s", coverage_ratio)
return coverage_ratio
def build(self):
""" Build the model. Override for custom build methods """
self.add_networks()
self.load_models(swapped=False)
inputs = self.get_inputs()
try:
self.build_autoencoders(inputs)
except ValueError as err:
if "must be from the same graph" in str(err).lower():
msg = ("There was an error loading saved weights. This is most likely due to "
"model corruption during a previous save."
"\nYou should restore weights from a snapshot or from backup files. "
"You can use the 'Restore' Tool to restore from backup.")
raise FaceswapError(msg) from err
if "multi_gpu_model" in str(err).lower():
raise FaceswapError(str(err)) from err
raise err
self.log_summary()
self.compile_predictors(initialize=True)
def get_inputs(self):
""" Return the inputs for the model """
logger.debug("Getting inputs")
inputs = [Input(shape=self.input_shape, name="face_in")]
output_network = [network for network in self.networks.values() if network.is_output][0]
mask_idx = [idx for idx, name in enumerate(output_network.output_names)
if name.startswith("mask")]
if mask_idx:
# Add the final mask shape as input
mask_shape = output_network.output_shapes[mask_idx[0]]
inputs.append(Input(shape=mask_shape[1:], name="mask_in"))
logger.debug("Got inputs: %s", inputs)
return inputs
def build_autoencoders(self, inputs):
""" Override for Model Specific autoencoder builds
Inputs is defined in self.get_inputs() and is standardized for all models.
It will generally be in the order:
[face (the input for image),
mask (the input for mask if it is used)]
"""
raise NotImplementedError
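# A minimal sketch of a typical override (illustrative assumptions: an "encoder"
# network plus per-side "decoder_a"/"decoder_b" networks registered via
# add_network; Model here is keras.models.Model as imported above):
#
#     def build_autoencoders(self, inputs):
#         for side in ("a", "b"):
#             decoder = self.networks["decoder_{}".format(side)].network
#             output = decoder(self.networks["encoder"].network(inputs[0]))
#             self.add_predictor(side, Model(inputs[0], output))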
def add_networks(self):
""" Override to add neural networks """
raise NotImplementedError
def load_state_info(self):
""" Load the input shape from state file if it exists """
logger.debug("Loading Input Shape from State file")
if not self.state.inputs:
logger.debug("No input shapes saved. Using model config")
return
if not self.state.face_shapes:
logger.warning("Input shapes stored in State file, but no matches for 'face'."
"Using model config")
return
input_shape = self.state.face_shapes[0]
logger.debug("Setting input shape from state file: %s", input_shape)
self.input_shape = input_shape
def add_network(self, network_type, side, network, is_output=False):
""" Add a NNMeta object """
logger.debug("network_type: '%s', side: '%s', network: '%s', is_output: %s",
network_type, side, network, is_output)
filename = "{}_{}".format(self.name, network_type.lower())
name = network_type.lower()
if side:
side = side.lower()
filename += "_{}".format(side.upper())
name += "_{}".format(side)
filename += ".h5"
logger.debug("name: '%s', filename: '%s'", name, filename)
self.networks[name] = NNMeta(str(self.model_dir / filename),
network_type,
side,
network,
is_output)
def add_predictor(self, side, model):
""" Add a predictor to the predictors dictionary """
logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
if self.gpus > 1:
logger.debug("Converting to multi-gpu: side %s", side)
model = multi_gpu_model(model, self.gpus)
self.predictors[side] = model
if not self.state.inputs:
self.store_input_shapes(model)
def store_input_shapes(self, model):
""" Store the input and output shapes to state """
logger.debug("Adding input shapes to state for model")
inputs = {tensor.name: K.int_shape(tensor)[-3:] for tensor in model.inputs}
if not any(inp for inp in inputs.keys() if inp.startswith("face")):
raise ValueError("No input named 'face' was found. Check your input naming. "
"Current input names: {}".format(inputs))
# Make sure they are all ints so that it can be json serialized
inputs = {key: tuple(int(i) for i in val) for key, val in inputs.items()}
self.state.inputs = inputs
logger.debug("Added input shapes: %s", self.state.inputs)
def reset_pingpong(self):
""" Reset the models for pingpong training """
logger.debug("Resetting models")
# Clear models and graph
self.predictors = dict()
K.clear_session()
# Load Models for current training run
for model in self.networks.values():
model.network = Model.from_config(model.config)
model.network.set_weights(model.weights)
inputs = self.get_inputs()
self.build_autoencoders(inputs)
self.compile_predictors(initialize=False)
logger.debug("Reset models")
def compile_predictors(self, initialize=True):
""" Compile the predictors """
logger.debug("Compiling Predictors")
learning_rate = self.config.get("learning_rate", 5e-5)
optimizer = self.get_optimizer(lr=learning_rate, beta_1=0.5, beta_2=0.999)
for side, model in self.predictors.items():
loss = Loss(model.inputs, model.outputs)
model.compile(optimizer=optimizer, loss=loss.funcs)
if initialize:
self.state.add_session_loss_names(side, loss.names)
self.history[side] = list()
logger.debug("Compiled Predictors. Losses: %s", loss.names)
def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=invalid-name
""" Build and return Optimizer """
opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2)
if (self.config.get("clipnorm", False) and
keras.backend.backend() != "plaidml.keras.backend"):
# NB: Clipnorm is ballooning VRAM usage, which is not expected behaviour
# and may be a bug in Keras/TF.
# PlaidML has a bug regarding the clipnorm parameter
# See: https://github.com/plaidml/plaidml/issues/228
# Workaround by simply removing it.
# TODO: Remove this as soon as it is fixed in PlaidML.
opt_kwargs["clipnorm"] = 1.0
logger.debug("Optimizer kwargs: %s", opt_kwargs)
return Adam(**opt_kwargs, cpu_mode=self.vram_savings.optimizer_savings)
def converter(self, swap):
""" Converter for autoencoder models """
logger.debug("Getting Converter: (swap: %s)", swap)
side = "a" if swap else "b"
model = self.predictors[side]
if self.predict:
# Must compile the model to be thread safe
model._make_predict_function() # pylint: disable=protected-access
retval = model.predict
logger.debug("Got Converter: %s", retval)
return retval
@property
def iterations(self):
"Get current training iteration number"
return self.state.iterations
def map_models(self, swapped):
""" Map the models for A/B side for swapping """
logger.debug("Map models: (swapped: %s)", swapped)
models_map = {"a": dict(), "b": dict()}
sides = ("a", "b") if not swapped else ("b", "a")
for network in self.networks.values():
if network.side == sides[0]:
models_map["a"][network.type] = network.filename
if network.side == sides[1]:
models_map["b"][network.type] = network.filename
logger.debug("Mapped models: (models_map: %s)", models_map)
return models_map
def log_summary(self):
""" Verbose log the model summaries """
if self.predict:
return
for side in sorted(list(self.predictors.keys())):
logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x))
for name, nnmeta in self.networks.items():
if nnmeta.side is not None and nnmeta.side != side:
continue
logger.verbose("%s:", name.title())
nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x))
def do_snapshot(self):
""" Perform a model snapshot """
logger.debug("Performing snapshot")
self.backup.snapshot_models(self.iterations)
logger.debug("Performed snapshot")
def load_models(self, swapped):
""" Load models from file """
logger.debug("Load model: (swapped: %s)", swapped)
if not self.models_exist and not self.predict:
logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir)
return None
if not self.models_exist and self.predict:
logger.error("Model could not be found in folder '%s'. Exiting", self.model_dir)
sys.exit(1)
if not self.is_legacy:
K.clear_session()
model_mapping = self.map_models(swapped)
for network in self.networks.values():
if not network.side:
is_loaded = network.load()
else:
is_loaded = network.load(fullpath=model_mapping[network.side][network.type])
if not is_loaded:
break
if is_loaded:
logger.info("Loaded model from disk: '%s'", self.model_dir)
return is_loaded
def save_models(self):
""" Backup and save the models """
logger.debug("Backing up and saving models")
save_averages = self.get_save_averages()
backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
if backup_func:
logger.info("Backing up models...")
executor = futures.ThreadPoolExecutor()
save_threads = [executor.submit(network.save, backup_func=backup_func)
for network in self.networks.values()]
save_threads.append(executor.submit(self.state.save, backup_func=backup_func))
futures.wait(save_threads)
# call result() to capture errors
_ = [thread.result() for thread in save_threads]
msg = "[Saved models]"
if save_averages:
lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],
side.capitalize(),
save_averages[side])
for side in sorted(list(save_averages.keys()))]
msg += " - Average since last save: {}".format(", ".join(lossmsg))
logger.info(msg)
def get_save_averages(self):
""" Return the average loss since the last save iteration and reset historical loss """
logger.debug("Getting save averages")
avgs = dict()
for side, loss in self.history.items():
if not loss:
logger.debug("No loss in self.history: %s", side)
break
avgs[side] = sum(loss) / len(loss)
self.history[side] = list() # Reset historical loss
logger.debug("Average losses since last save: %s", avgs)
return avgs
def should_backup(self, save_averages):
""" Check whether the loss averages for all losses is the lowest that has been seen.
This protects against model corruption by only backing up the model
if any of the loss values have fallen.
TODO This is not a perfect system. If the model corrupts on save_iteration - 1
then model may still backup
"""
backup = True
if not save_averages:
logger.debug("No save averages. Not backing up")
return False
for side, loss in save_averages.items():
if not self.state.lowest_avg_loss.get(side, None):
logger.debug("Setting initial save iteration loss average for '%s': %s",
side, loss)
self.state.lowest_avg_loss[side] = loss
continue
if backup:
# Only run this if backup is true. All losses must have dropped for a valid backup
backup = self.check_loss_drop(side, loss)
logger.debug("Lowest historical save iteration loss average: %s",
self.state.lowest_avg_loss)
if backup: # Update lowest loss values to the state
for side, avg_loss in save_averages.items():
logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
self.state.lowest_avg_loss[side] = avg_loss
logger.debug("Backing up: %s", backup)
return backup
def check_loss_drop(self, side, avg):
""" Check whether total loss has dropped since lowest loss """
if avg < self.state.lowest_avg_loss[side]:
logger.debug("Loss for '%s' has dropped", side)
return True
logger.debug("Loss for '%s' has not dropped", side)
return False
def rename_legacy(self):
""" Legacy Original, LowMem and IAE models had inconsistent naming conventions
Rename them if they are found and update """
legacy_mapping = {"iae": [("IAE_decoder.h5", "iae_decoder.h5"),
("IAE_encoder.h5", "iae_encoder.h5"),
("IAE_inter_A.h5", "iae_intermediate_A.h5"),
("IAE_inter_B.h5", "iae_intermediate_B.h5"),
("IAE_inter_both.h5", "iae_inter.h5")],
"original": [("encoder.h5", "original_encoder.h5"),
("decoder_A.h5", "original_decoder_A.h5"),
("decoder_B.h5", "original_decoder_B.h5"),
("lowmem_encoder.h5", "original_encoder.h5"),
("lowmem_decoder_A.h5", "original_decoder_A.h5"),
("lowmem_decoder_B.h5", "original_decoder_B.h5")]}
if self.name not in legacy_mapping.keys():
return
logger.debug("Renaming legacy files")
set_lowmem = False
updated = False
for old_name, new_name in legacy_mapping[self.name]:
old_path = os.path.join(str(self.model_dir), old_name)
new_path = os.path.join(str(self.model_dir), new_name)
if os.path.exists(old_path) and not os.path.exists(new_path):
logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name)
os.rename(old_path, new_path)
if old_name.startswith("lowmem"):
set_lowmem = True
updated = True
if not updated:
logger.debug("No legacy files to rename")
return
self.is_legacy = True
logger.debug("Creating state file for legacy model")
self.state.inputs = {"face:0": [64, 64, 3]}
self.state.training_size = 256
self.state.config["coverage"] = 62.5
self.state.config["subpixel_upscaling"] = False
self.state.config["reflect_padding"] = False
self.state.config["mask_type"] = None
self.state.config["lowmem"] = False
self.encoder_dim = 1024
if set_lowmem:
logger.debug("Setting encoder_dim and lowmem flag for legacy lowmem model")
self.encoder_dim = 512
self.state.config["lowmem"] = True
self.state.replace_config(self.config_changeable_items)
self.state.save()
class VRAMSavings():
""" VRAM Saving training methods """
def __init__(self, pingpong, optimizer_savings, memory_saving_gradients):
logger.debug("Initializing %s: (pingpong: %s, optimizer_savings: %s, "
"memory_saving_gradients: %s)", self.__class__.__name__,
pingpong, optimizer_savings, memory_saving_gradients)
self.is_plaidml = keras.backend.backend() == "plaidml.keras.backend"
self.pingpong = self.set_pingpong(pingpong)
self.optimizer_savings = self.set_optimizer_savings(optimizer_savings)
self.memory_saving_gradients = self.set_gradient_type(memory_saving_gradients)
logger.debug("Initialized: %s", self.__class__.__name__)
def set_pingpong(self, pingpong):
""" Disable pingpong for plaidML users """
if pingpong and self.is_plaidml:
logger.warning("Pingpong training not supported on plaidML. Disabling")
pingpong = False
logger.debug("pingpong: %s", pingpong)
if pingpong:
logger.info("Using Pingpong Training")
return pingpong
def set_optimizer_savings(self, optimizer_savings):
""" Disable optimizer savings for plaidML users """
if optimizer_savings and self.is_plaidml == "plaidml.keras.backend":
logger.warning("Optimizer Savings not supported on plaidML. Disabling")
optimizer_savings = False
logger.debug("optimizer_savings: %s", optimizer_savings)
if optimizer_savings:
logger.info("Using Optimizer Savings")
return optimizer_savings
def set_gradient_type(self, memory_saving_gradients):
""" Monkeypatch Memory Saving Gradients if requested """
if memory_saving_gradients and self.is_plaidml:
logger.warning("Memory Saving Gradients not supported on plaidML. Disabling")
memory_saving_gradients = False
logger.debug("memory_saving_gradients: %s", memory_saving_gradients)
if memory_saving_gradients:
logger.info("Using Memory Saving Gradients")
from lib.model import memory_saving_gradients
K.__dict__["gradients"] = memory_saving_gradients.gradients_memory
return memory_saving_gradients
class Loss():
""" Holds loss names and functions for an Autoencoder """
def __init__(self, inputs, outputs):
logger.debug("Initializing %s: (inputs: %s, outputs: %s)",
self.__class__.__name__, inputs, outputs)
self.inputs = inputs
self.outputs = outputs
self.names = self.get_loss_names()
self.funcs = self.get_loss_functions()
if len(self.names) > 1:
self.names.insert(0, "total_loss")
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def loss_dict(self):
""" Return the loss dict """
loss_dict = dict(mae=losses.mean_absolute_error,
mse=losses.mean_squared_error,
logcosh=losses.logcosh,
smooth_loss=generalized_loss,
l_inf_norm=l_inf_norm,
ssim=DSSIMObjective(),
gmsd=gmsd_loss,
pixel_gradient_diff=gradient_loss)
return loss_dict
@property
def config(self):
""" Return the global _CONFIG variable """
return _CONFIG
@property
def mask_preprocessing_func(self):
""" The selected pre-processing function for the mask """
retval = None
if self.config.get("mask_blur", False):
retval = gaussian_blur(max(1, self.mask_shape[1] // 32))
logger.debug(retval)
return retval
@property
def selected_loss(self):
""" Return the selected loss function """
retval = self.loss_dict[self.config.get("loss_function", "mae")]
logger.debug(retval)
return retval
@property
def selected_mask_loss(self):
""" Return the selected mask loss function. Currently returns mse
If a processing function has been requested wrap the loss function
in loss wrapper """
loss_func = self.loss_dict["mse"]
func = self.mask_preprocessing_func
logger.debug("loss_func: %s, func: %s", loss_func, func)
retval = mask_loss_wrapper(loss_func, preprocessing_func=func)
return retval
@property
def output_shapes(self):
""" The shapes of the output nodes """
return [K.int_shape(output)[1:] for output in self.outputs]
@property
def mask_input(self):
""" Return the mask input or None """
mask_inputs = [inp for inp in self.inputs if inp.name.startswith("mask")]
if not mask_inputs:
return None
return mask_inputs[0]
@property
def mask_shape(self):
""" Return the mask shape """
if self.mask_input is None:
return None
return K.int_shape(self.mask_input)[1:]
def get_loss_names(self):
""" Return the loss names based on model output """
output_names = [output.name for output in self.outputs]
logger.debug("Model output names: %s", output_names)
loss_names = [name[name.find("/") + 1:name.rfind("/")].replace("_out", "")
for name in output_names]
if not all(name.startswith("face") or name.startswith("mask") for name in loss_names):
# Handle incorrectly named/legacy outputs
logger.debug("Renaming loss names from: %s", loss_names)
loss_names = self.update_loss_names()
loss_names = ["{}_loss".format(name) for name in loss_names]
logger.debug(loss_names)
return loss_names
def update_loss_names(self):
""" Update loss names if named incorrectly or legacy model """
output_types = ["mask" if shape[-1] == 1 else "face" for shape in self.output_shapes]
loss_names = ["{}{}".format(name,
"" if output_types.count(name) == 1 else "_{}".format(idx))
for idx, name in enumerate(output_types)]
logger.debug("Renamed loss names to: %s", loss_names)
return loss_names
def get_loss_functions(self):
""" Set the loss function """
loss_funcs = []
for idx, loss_name in enumerate(self.names):
if loss_name.startswith("mask"):
loss_funcs.append(self.selected_mask_loss)
elif self.mask_input is not None and self.config.get("penalized_mask_loss", False):
face_size = self.output_shapes[idx][1]
mask_size = self.mask_shape[1]
scaling = face_size / mask_size
logger.debug("face_size: %s mask_size: %s, mask_scaling: %s",
face_size, mask_size, scaling)
loss_funcs.append(PenalizedLoss(self.mask_input, self.selected_loss,
mask_scaling=scaling,
preprocessing_func=self.mask_preprocessing_func))
else:
loss_funcs.append(self.selected_loss)
logger.debug("%s: %s", loss_name, loss_funcs[-1])
logger.debug(loss_funcs)
return loss_funcs
class NNMeta():
""" Class to hold a neural network and it's meta data
filename: The full path and filename of the model file for this network.
type: The type of network. For networks that can be swapped
The type should be identical for the corresponding
A and B networks, and should be unique for every A/B pair.
Otherwise the type should be completely unique.
side: A, B or None. Used to identify which networks can
be swapped.
network: The Keras network to store against this meta object.
is_output: Set to True to indicate that this network is an output to the Autoencoder
"""
def __init__(self, filename, network_type, side, network, is_output):
logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
"network: %s, is_output: %s", self.__class__.__name__, filename,
network_type, side, network, is_output)
self.filename = filename
self.type = network_type.lower()
self.side = side
self.name = self.set_name()
self.network = network
self.is_output = is_output
self.network.name = self.name
self.config = network.get_config() # For pingpong restore
self.weights = network.get_weights() # For pingpong restore
logger.debug("Initialized %s", self.__class__.__name__)
@property
def output_shapes(self):
""" Return the output shapes from the stored network """
return [K.int_shape(output) for output in self.network.outputs]
def set_name(self):
""" Set the network name """
name = self.type
if self.side:
name += "_{}".format(self.side)
return name
@property
def output_names(self):
""" Return output node names """
output_names = [output.name for output in self.network.outputs]
if self.is_output and not any(name.startswith("face_out") for name in output_names):
# Saved models break if their layer names are changed, so dummy
# in correct output names for legacy models
output_names = self.get_output_names()
return output_names
def get_output_names(self):
""" Return the output names based on number of channels and instances """
output_types = ["mask_out" if K.int_shape(output)[-1] == 1 else "face_out"
for output in self.network.outputs]
output_names = ["{}{}".format(name,
"" if output_types.count(name) == 1 else "_{}".format(idx))
for idx, name in enumerate(output_types)]
logger.debug("Overridden output_names: %s", output_names)
return output_names
def load(self, fullpath=None):
""" Load model """
fullpath = fullpath if fullpath else self.filename
logger.debug("Loading model: '%s'", fullpath)
try:
network = load_model(fullpath, custom_objects=get_custom_objects())
except ValueError as err:
if str(err).lower().startswith("cannot create group in read only mode"):
self.convert_legacy_weights()
return True
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
except OSError as err: # pylint: disable=broad-except
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
self.config = network.get_config()
self.network = network # Update network with saved model
self.network.name = self.name
return True
def save(self, fullpath=None, backup_func=None):
""" Save model """
fullpath = fullpath if fullpath else self.filename
if backup_func:
backup_func(fullpath)
logger.debug("Saving model: '%s'", fullpath)
self.weights = self.network.get_weights()
self.network.save(fullpath)
def convert_legacy_weights(self):
""" Convert legacy weights files to hold the model topology """
logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
self.network.load_weights(self.filename)
self.save(backup_func=None)
self.network.name = self.type
class State():
""" Class to hold the model's current state and autoencoder structure """
def __init__(self, model_dir, model_name, config_changeable_items,
no_logs, pingpong, training_image_size):
logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', "
"config_changeable_items: '%s', no_logs: %s, pingpong: %s, "
"training_image_size: '%s'", self.__class__.__name__, model_dir, model_name,
config_changeable_items, no_logs, pingpong, training_image_size)
self.serializer = Serializer.get_serializer("json")
filename = "{}_state.{}".format(model_name, self.serializer.ext)
self.filename = str(model_dir / filename)
self.name = model_name
self.iterations = 0
self.session_iterations = 0
self.training_size = training_image_size
self.sessions = dict()
self.lowest_avg_loss = dict()
self.inputs = dict()
self.config = dict()
self.load(config_changeable_items)
self.session_id = self.new_session_id()
self.create_new_session(no_logs, pingpong)
logger.debug("Initialized %s:", self.__class__.__name__)
@property
def face_shapes(self):
""" Return a list of stored face shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]
@property
def mask_shapes(self):
""" Return a list of stored mask shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]
@property
def loss_names(self):
""" Return the loss names for this session """
return self.sessions[self.session_id]["loss_names"]
@property
def current_session(self):
""" Return the current session dict """
return self.sessions[self.session_id]
@property
def first_run(self):
""" Return True if this is the first run else False """
return self.session_id == 1
def new_session_id(self):
""" Return new session_id """
if not self.sessions:
session_id = 1
else:
session_id = max(int(key) for key in self.sessions.keys()) + 1
logger.debug(session_id)
return session_id
def create_new_session(self, no_logs, pingpong):
""" Create a new session """
logger.debug("Creating new session. id: %s", self.session_id)
self.sessions[self.session_id] = {"timestamp": time.time(),
"no_logs": no_logs,
"pingpong": pingpong,
"loss_names": dict(),
"batchsize": 0,
"iterations": 0}
def add_session_loss_names(self, side, loss_names):
""" Add the session loss names to the sessions dictionary """
logger.debug("Adding session loss_names. (side: '%s', loss_names: %s", side, loss_names)
self.sessions[self.session_id]["loss_names"][side] = loss_names
def add_session_batchsize(self, batchsize):
""" Add the session batchsize to the sessions dictionary """
logger.debug("Adding session batchsize: %s", batchsize)
self.sessions[self.session_id]["batchsize"] = batchsize
def increment_iterations(self):
""" Increment total and session iterations """
self.iterations += 1
self.sessions[self.session_id]["iterations"] += 1
def load(self, config_changeable_items):
""" Load state file """
logger.debug("Loading State")
try:
with open(self.filename, "rb") as inp:
state = self.serializer.unmarshal(inp.read().decode("utf-8"))
self.name = state.get("name", self.name)
self.sessions = state.get("sessions", dict())
self.lowest_avg_loss = state.get("lowest_avg_loss", dict())
self.iterations = state.get("iterations", 0)
self.training_size = state.get("training_size", 256)
self.inputs = state.get("inputs", dict())
self.config = state.get("config", dict())
logger.debug("Loaded state: %s", state)
self.replace_config(config_changeable_items)
except IOError as err:
logger.warning("No existing state file found. Generating.")
logger.debug("IOError: %s", str(err))
except JSONDecodeError as err:
logger.debug("JSONDecodeError: %s:", str(err))
def save(self, backup_func=None):
""" Save iteration number to state file """
logger.debug("Saving State")
if backup_func:
backup_func(self.filename)
try:
with open(self.filename, "wb") as out:
state = {"name": self.name,
"sessions": self.sessions,
"lowest_avg_loss": self.lowest_avg_loss,
"iterations": self.iterations,
"inputs": self.inputs,
"training_size": self.training_size,
"config": _CONFIG}
state_json = self.serializer.marshal(state)
out.write(state_json.encode("utf-8"))
except IOError as err:
logger.error("Unable to save model state: %s", str(err.strerror))
logger.debug("Saved State")
def replace_config(self, config_changeable_items):
""" Replace the loaded config with the one contained within the state file
Check for any fixed=False parameters changes and log info changes
"""
global _CONFIG # pylint: disable=global-statement
# Add any new items to state config for legacy purposes
for key, val in _CONFIG.items():
if key not in self.config.keys():
logger.info("Adding new config item to state file: '%s': '%s'", key, val)
self.config[key] = val
legacy_update = self.update_legacy_config()
self.update_changed_config_items(config_changeable_items)
logger.debug("Replacing config. Old config: %s", _CONFIG)
_CONFIG = self.config
if legacy_update:
self.save()
logger.debug("Replaced config. New config: %s", _CONFIG)
logger.info("Using configuration saved in state file")
def update_legacy_config(self):
""" Update legacy state config files with the new loss formating
"""
prior = "dssim_loss"
new = "loss_function"
if prior not in self.config:
return False
self.config[new] = "ssim" if self.config[prior] else "mae"
del self.config[prior]
logger.info("Updated config from older dssim format. New config loss function: %s",
self.config[new])
return True
def update_changed_config_items(self, config_changeable_items):
""" Update any parameters which are not fixed and have been changed """
if not config_changeable_items:
logger.debug("No changeable parameters have been updated")
return
for key, val in config_changeable_items.items():
old_val = self.config[key]
if old_val == val:
continue
self.config[key] = val
logger.info("Config item: '%s' has been updated from '%s' to '%s'", key, old_val, val)