#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class

    When inheriting model_data should be a list of NNMeta objects.
    See the class for details.
"""
import logging
import os
import sys
import time

from concurrent import futures

import keras
from keras import losses
from keras import backend as K
from keras.layers import Input
from keras.models import load_model, Model
from keras.utils import get_custom_objects, multi_gpu_model

from lib.serializer import get_serializer
from lib.model.backup_restore import Backup
from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask_loss_wrapper,
                              generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
from lib.model.nn_blocks import NNBlocks
from lib.model.optimizers import Adam
from lib.utils import deprecation_warning, FaceswapError
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
_CONFIG = None


class ModelBase():
    """ Base class that all models should inherit from.

    Subclasses must override :func:`add_networks` and :func:`build_autoencoders`, both of
    which raise ``NotImplementedError`` here. Instantiating the class loads the plugin
    config, loads (or creates) the state file, and then calls :func:`build`.
    """
    def __init__(self,
                 model_dir,
                 gpus=1,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 memory_saving_gradients=False,
                 optimizer_savings=False,
                 predict=False):
        logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
                     "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
                     "%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, "
                     "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
                     "pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, "
                     "predict: %s)",
                     self.__class__.__name__, model_dir, gpus, configfile, snapshot_interval,
                     no_logs, warp_to_landmarks, augment_color, no_flip, training_image_size,
                     alignments_paths, preview_scale, input_shape, encoder_dim, trainer,
                     pingpong, memory_saving_gradients, optimizer_savings, predict)

        self.predict = predict
        self.model_dir = model_dir
        self.vram_savings = VRAMSavings(pingpong, optimizer_savings, memory_saving_gradients)

        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        self.configfile = configfile
        self.input_shape = input_shape
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.load_config()  # Load config if plugin has not already referenced it

        self.state = State(self.model_dir,
                           self.name,
                           self.config_changeable_items,
                           no_logs,
                           self.vram_savings.pingpong,
                           training_image_size)

        self.blocks = NNBlocks(use_icnr_init=self.config["icnr_init"],
                               use_convaware_init=self.config["conv_aware_init"],
                               use_reflect_padding=self.config["reflect_padding"],
                               first_run=self.state.first_run)

        self.is_legacy = False
        self.rename_legacy()
        self.load_state_info()

        self.networks = dict()  # Networks for the model
        self.predictors = dict()  # Predictors for model
        self.history = dict()  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        has_mask = self.config["mask_type"] is not None
        self.training_opts = {"alignments": alignments_paths,
                              "preview_scaling": preview_scale / 100,
                              "warp_to_landmarks": warp_to_landmarks,
                              "augment_color": augment_color,
                              "no_flip": no_flip,
                              "pingpong": self.vram_savings.pingpong,
                              "snapshot_interval": snapshot_interval,
                              "training_size": self.state.training_size,
                              "no_logs": self.state.current_session["no_logs"],
                              "coverage_ratio": self.calculate_coverage_ratio(),
                              "mask_type": self.config["mask_type"],
                              "mask_blur_kernel": self.config["mask_blur_kernel"],
                              "mask_threshold": self.config["mask_threshold"],
                              "learn_mask": (self.config["learn_mask"] and has_mask),
                              "penalized_mask_loss": (self.config["penalized_mask_loss"] and
                                                      has_mask)}
        logger.debug("training_opts: %s", self.training_opts)

        if self.multiple_models_in_folder:
            deprecation_warning("Support for multiple model types within the same folder",
                                additional_info="Please split each model into separate folders to "
                                                "avoid issues in future.")

        self.build()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config_section(self):
        """ The section name for loading config (last two parts of the module path) """
        retval = ".".join(self.__module__.split(".")[-2:])
        logger.debug(retval)
        return retval

    @property
    def config(self):
        """ Return config dict for current plugin, loading it on first access """
        self.load_config()
        return _CONFIG

    @property
    def config_changeable_items(self):
        """ Return the dict of config items that can be updated after the model
            has been created """
        return Config(self.config_section, configfile=self.configfile).changeable_items

    @property
    def name(self):
        """ Set the model name based on the subclass (lower-cased module file name) """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    @property
    def models_exist(self):
        """ Return ``True`` if the file for every registered network exists on disk """
        retval = all(os.path.isfile(model.filename) for model in self.networks.values())
        logger.debug("Pre-existing models exist: %s", retval)
        return retval

    @property
    def multiple_models_in_folder(self):
        """ Return true if there are multiple model types in the same folder, else false.

        The ``.h5`` files of a single model type share a common file name prefix, so an
        empty common prefix is taken to mean more than one model type is present.
        """
        model_files = [fname for fname in os.listdir(str(self.model_dir))
                       if fname.endswith(".h5")]
        retval = False if not model_files else os.path.commonprefix(model_files) == ""
        logger.debug("model_files: %s, retval: %s", model_files, retval)
        return retval

    @property
    def output_shapes(self):
        """ Return the output shapes from the main AutoEncoder """
        # Only get output from one autoencoder. Shapes are the same
        predictor = next(iter(self.predictors.values()), None)
        if predictor is None:
            return []
        return [tuple(K.int_shape(output)[-3:]) for output in predictor.outputs]

    @property
    def output_shape(self):
        """ The output shape of the model (shape of largest face output) """
        return self.output_shapes[self.largest_face_index]

    def _largest_output_index(self, channels):
        """ Return the index of the first largest output with ``channels`` channels.

        Returns ``None`` when no output has the requested channel count.
        """
        shapes = self.output_shapes
        sizes = [shape[1] for shape in shapes if shape[2] == channels]
        if not sizes:
            return None
        largest = max(sizes)
        retval = next(idx for idx, shape in enumerate(shapes)
                      if shape[1] == largest and shape[2] == channels)
        logger.debug(retval)
        return retval

    @property
    def largest_face_index(self):
        """ Return the index from model.outputs of the largest face (3 channel output)
        Required for multi-output model prediction. ``None`` if there is no face output """
        return self._largest_output_index(3)

    @property
    def largest_mask_index(self):
        """ Return the index from model.outputs of the largest mask (1 channel output)
        Required for multi-output model prediction. ``None`` if there is no mask output """
        return self._largest_output_index(1)

    @property
    def feed_mask(self):
        """ bool: ``True`` if the model expects a mask to be fed into input otherwise ``False`` """
        return self.config["mask_type"] is not None and (self.config["learn_mask"] or
                                                         self.config["penalized_mask_loss"])

    def load_config(self):
        """ Load the global config for reference in self.config """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self.config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name, configfile=self.configfile).config_dict

    def calculate_coverage_ratio(self):
        """ Coverage must be a ratio, leading to a cropped shape divisible by 2 """
        coverage_ratio = self.config.get("coverage", 62.5) / 100
        logger.debug("Requested coverage_ratio: %s", coverage_ratio)
        # Round the cropped size down to the nearest even number
        cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2
        coverage_ratio = cropped_size / self.state.training_size
        logger.debug("Final coverage_ratio: %s", coverage_ratio)
        return coverage_ratio

    def build(self):
        """ Build the model. Override for custom build methods

        Raises
        ------
        FaceswapError
            If building the autoencoders fails with a graph mismatch or multi-gpu error
        """
        self.add_networks()
        self.load_models(swapped=False)
        inputs = self.get_inputs()
        try:
            self.build_autoencoders(inputs)
        except ValueError as err:
            if "must be from the same graph" in str(err).lower():
                msg = ("There was an error loading saved weights. This is most likely due to "
                       "model corruption during a previous save."
                       "\nYou should restore weights from a snapshot or from backup files. "
                       "You can use the 'Restore' Tool to restore from backup.")
                raise FaceswapError(msg) from err
            if "multi_gpu_model" in str(err).lower():
                raise FaceswapError(str(err)) from err
            raise err
        self.log_summary()
        self.compile_predictors(initialize=True)

    def get_inputs(self):
        """ Return the inputs for the model """
        logger.debug("Getting inputs")
        inputs = [Input(shape=self.input_shape, name="face_in")]
        output_network = [network for network in self.networks.values() if network.is_output][0]
        if self.feed_mask:
            # TODO penalized mask doesn't have a mask output, so we can't use output shapes
            # mask should always be last output..this needs to be a rule
            mask_shape = output_network.output_shapes[-1]
            inputs.append(Input(shape=(mask_shape[1:-1] + (1,)), name="mask_in"))
        logger.debug("Got inputs: %s", inputs)
        return inputs

    def build_autoencoders(self, inputs):
        """ Override for Model Specific autoencoder builds

            Inputs is defined in self.get_inputs() and is standardized for all models
                if will generally be in the order:
                [face (the input for image),
                 mask (the input for mask if it is used)]
        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError

    def load_state_info(self):
        """ Load the input shape from state file if it exists """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning("Input shapes stored in State file, but no matches for 'face'. "
                           "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network, is_output=False):
        """ Add a NNMeta object to ``self.networks``, keyed by ``<type>[_<side>]`` """
        logger.debug("network_type: '%s', side: '%s', network: '%s', is_output: %s",
                     network_type, side, network, is_output)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename),
                                     network_type,
                                     side,
                                     network,
                                     is_output)

    def add_predictor(self, side, model):
        """ Add a predictor to the predictors dictionary """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)

    def store_input_shapes(self, model):
        """ Store the input and output shapes to state

        Raises
        ------
        ValueError
            If none of the model's inputs has a name starting with 'face'
        """
        logger.debug("Adding input shapes to state for model")
        inputs = {tensor.name: K.int_shape(tensor)[-3:] for tensor in model.inputs}
        if not any(inp.startswith("face") for inp in inputs):
            raise ValueError("No input named 'face' was found. Check your input naming. "
                             "Current input names: {}".format(inputs))
        # Make sure they are all ints so that it can be json serialized
        inputs = {key: tuple(int(i) for i in val) for key, val in inputs.items()}
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

    def reset_pingpong(self):
        """ Reset the models for pingpong training """
        logger.debug("Resetting models")

        # Clear models and graph
        self.predictors = dict()
        K.clear_session()

        # Load Models for current training run
        for model in self.networks.values():
            model.network = Model.from_config(model.config)
            model.network.set_weights(model.weights)

        inputs = self.get_inputs()
        self.build_autoencoders(inputs)
        self.compile_predictors(initialize=False)
        logger.debug("Reset models")

    def compile_predictors(self, initialize=True):
        """ Compile the predictors

        When ``initialize`` is ``True`` the loss names are also recorded in the state and the
        per-side loss history is reset.
        """
        logger.debug("Compiling Predictors")
        learning_rate = self.config.get("learning_rate", 5e-5)
        optimizer = self.get_optimizer(lr=learning_rate, beta_1=0.5, beta_2=0.999)

        loss_names = None  # Stays None if there are no predictors to compile
        for side, model in self.predictors.items():
            loss = Loss(model.inputs, model.outputs)
            model.compile(optimizer=optimizer, loss=loss.funcs)
            loss_names = loss.names
            if initialize:
                self.state.add_session_loss_names(side, loss.names)
                self.history[side] = list()
        logger.debug("Compiled Predictors. Losses: %s", loss_names)

    def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999):  # pylint: disable=invalid-name
        """ Build and return Optimizer """
        opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2)
        if (self.config.get("clipnorm", False) and
                keras.backend.backend() != "plaidml.keras.backend"):
            # NB: Clip-norm is ballooning VRAM usage, which is not expected behavior
            # and may be a bug in Keras/Tensorflow.
            # PlaidML has a bug regarding the clip-norm parameter
            # See: https://github.com/plaidml/plaidml/issues/228
            # Workaround by simply removing it.
            # TODO: Remove this as soon it is fixed in PlaidML.
            opt_kwargs["clipnorm"] = 1.0
        logger.debug("Optimizer kwargs: %s", opt_kwargs)
        return Adam(**opt_kwargs, cpu_mode=self.vram_savings.optimizer_savings)

    def converter(self, swap):
        """ Converter for autoencoder models. Returns the ``predict`` function of the
        "a" predictor when ``swap`` is truthy, otherwise of the "b" predictor """
        logger.debug("Getting Converter: (swap: %s)", swap)
        side = "a" if swap else "b"
        model = self.predictors[side]
        if self.predict:
            # Must compile the model to be thread safe
            model._make_predict_function()  # pylint: disable=protected-access
        retval = model.predict
        logger.debug("Got Converter: %s", retval)
        return retval

    @property
    def iterations(self):
        """ Get current training iteration number """
        return self.state.iterations

    def map_models(self, swapped):
        """ Map the models for A/B side for swapping """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": dict(), "b": dict()}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Verbose log the model summaries. Does nothing when in predict mode """
        if self.predict:
            return
        for side in sorted(self.predictors.keys()):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(print_fn=lambda x: logger.verbose("%s", x))
            for name, nnmeta in self.networks.items():
                # Skip networks that belong exclusively to the other side
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(print_fn=lambda x: logger.verbose("%s", x))

    def do_snapshot(self):
        """ Perform a model snapshot """
        logger.debug("Performing snapshot")
        self.backup.snapshot_models(self.iterations)
        logger.debug("Performed snapshot")

    def load_models(self, swapped):
        """ Load models from file

        Returns
        -------
        bool or None
            ``None`` if no models exist and a new model is to be created, otherwise whether
            all of the networks were loaded successfully

        Raises
        ------
        SystemExit
            If the model files do not exist when in predict mode
        """
        logger.debug("Load model: (swapped: %s)", swapped)
        models_exist = self.models_exist

        if not models_exist and not self.predict:
            logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir)
            return None
        if not models_exist and self.predict:
            logger.error("Model could not be found in folder '%s'. Exiting", self.model_dir)
            # A missing model is a failure, so exit with a non-zero status
            sys.exit(1)

        if not self.is_legacy or not self.predict:
            K.clear_session()
        model_mapping = self.map_models(swapped)
        is_loaded = False  # Guards against an empty networks dict
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load()
            else:
                is_loaded = network.load(fullpath=model_mapping[network.side][network.type])
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Backup and save the models """
        logger.debug("Backing up and saving models")
        # Insert a new line to avoid spamming the same row as loss output
        print("")
        save_averages = self.get_save_averages()
        backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
        if backup_func:
            logger.info("Backing up models...")
        # The context manager shuts the pool down once all saves have completed
        with futures.ThreadPoolExecutor() as executor:
            save_threads = [executor.submit(network.save, backup_func=backup_func)
                            for network in self.networks.values()]
            save_threads.append(executor.submit(self.state.save, backup_func=backup_func))
            futures.wait(save_threads)
        # call result() to capture errors
        _ = [thread.result() for thread in save_threads]
        msg = "[Saved models]"
        if save_averages:
            lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],
                                              side.capitalize(),
                                              save_averages[side])
                       for side in sorted(save_averages.keys())]
            msg += " - Average since last save: {}".format(", ".join(lossmsg))
        logger.info(msg)

    def get_save_averages(self):
        """ Return the average loss since the last save iteration and reset historical loss.

        Processing stops at the first side that has no recorded loss.
        """
        logger.debug("Getting save averages")
        avgs = dict()
        for side, loss in self.history.items():
            if not loss:
                logger.debug("No loss in self.history: %s", side)
                break
            avgs[side] = sum(loss) / len(loss)
            self.history[side] = list()  # Reset historical loss
        logger.debug("Average losses since last save: %s", avgs)
        return avgs

    def should_backup(self, save_averages):
        """ Check whether the loss averages for all losses is the lowest that has been seen.

            This protects against model corruption by only backing up the model
            if all of the loss values that have a recorded lowest average have fallen.
            Sides with no recorded lowest average are initialized and do not block a backup.
            TODO This is not a perfect system. If the model corrupts on save_iteration - 1
            then model may still backup
        """
        backup = True

        if not save_averages:
            logger.debug("No save averages. Not backing up")
            return False

        for side, loss in save_averages.items():
            if not self.state.lowest_avg_loss.get(side, None):
                logger.debug("Setting initial save iteration loss average for '%s': %s",
                             side, loss)
                self.state.lowest_avg_loss[side] = loss
                continue
            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, loss)

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in save_averages.items():
                logger.debug("Updating lowest save iteration average for '%s': %s",
                             side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Check whether total loss has dropped since lowest loss """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False

    def rename_legacy(self):
        """ Legacy Original, LowMem and IAE models had inconsistent naming conventions
            Rename them if they are found and update """
        legacy_mapping = {"iae": [("IAE_decoder.h5", "iae_decoder.h5"),
                                  ("IAE_encoder.h5", "iae_encoder.h5"),
                                  ("IAE_inter_A.h5", "iae_intermediate_A.h5"),
                                  ("IAE_inter_B.h5", "iae_intermediate_B.h5"),
                                  ("IAE_inter_both.h5", "iae_inter.h5")],
                          "original": [("encoder.h5", "original_encoder.h5"),
                                       ("decoder_A.h5", "original_decoder_A.h5"),
                                       ("decoder_B.h5", "original_decoder_B.h5"),
                                       ("lowmem_encoder.h5", "original_encoder.h5"),
                                       ("lowmem_decoder_A.h5", "original_decoder_A.h5"),
                                       ("lowmem_decoder_B.h5", "original_decoder_B.h5")]}
        if self.name not in legacy_mapping:
            return
        logger.debug("Renaming legacy files")

        set_lowmem = False
        updated = False
        for old_name, new_name in legacy_mapping[self.name]:
            old_path = os.path.join(str(self.model_dir), old_name)
            new_path = os.path.join(str(self.model_dir), new_name)
            if os.path.exists(old_path) and not os.path.exists(new_path):
                logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name)
                os.rename(old_path, new_path)
                if old_name.startswith("lowmem"):
                    set_lowmem = True
                updated = True

        if not updated:
            logger.debug("No legacy files to rename")
            return

        self.is_legacy = True
        logger.debug("Creating state file for legacy model")
        self.state.inputs = {"face:0": [64, 64, 3]}
        self.state.training_size = 256
        self.state.config["coverage"] = 62.5
        self.state.config["reflect_padding"] = False
        self.state.config["mask_type"] = None
        self.state.config["mask_blur_kernel"] = 3
        self.state.config["mask_threshold"] = 4
        self.state.config["learn_mask"] = False
        self.state.config["lowmem"] = False
        self.encoder_dim = 1024

        if set_lowmem:
            logger.debug("Setting encoder_dim and lowmem flag for legacy lowmem model")
            self.encoder_dim = 512
            self.state.config["lowmem"] = True

        self.state.replace_config(self.config_changeable_items)
        self.state.save()


class VRAMSavings():
    """ VRAM Saving training methods

    Each requested option is disabled, with a warning, when the Keras backend is plaidML.
    """
    def __init__(self, pingpong, optimizer_savings, memory_saving_gradients):
        logger.debug("Initializing %s: (pingpong: %s, optimizer_savings: %s, "
                     "memory_saving_gradients: %s)", self.__class__.__name__,
                     pingpong, optimizer_savings, memory_saving_gradients)
        self.is_plaidml = keras.backend.backend() == "plaidml.keras.backend"
        self.pingpong = self.set_pingpong(pingpong)
        self.optimizer_savings = self.set_optimizer_savings(optimizer_savings)
        self.memory_saving_gradients = self.set_gradient_type(memory_saving_gradients)
        logger.debug("Initialized: %s", self.__class__.__name__)

    def set_pingpong(self, pingpong):
        """ Disable pingpong for plaidML users """
        if pingpong and self.is_plaidml:
            logger.warning("Pingpong training not supported on plaidML. Disabling")
            pingpong = False
        logger.debug("pingpong: %s", pingpong)
        if pingpong:
            logger.info("Using Pingpong Training")
        return pingpong

    def set_optimizer_savings(self, optimizer_savings):
        """ Disable optimizer savings for plaidML users """
        # is_plaidml is already a bool; comparing it to the backend name never matched
        if optimizer_savings and self.is_plaidml:
            logger.warning("Optimizer Savings not supported on plaidML. Disabling")
            optimizer_savings = False
        logger.debug("optimizer_savings: %s", optimizer_savings)
        if optimizer_savings:
            logger.info("Using Optimizer Savings")
        return optimizer_savings

    def set_gradient_type(self, memory_saving_gradients):
        """ Monkey-patch Memory Saving Gradients if requested

        Returns the (possibly disabled) ``memory_saving_gradients`` flag.
        """
        if memory_saving_gradients and self.is_plaidml:
            logger.warning("Memory Saving Gradients not supported on plaidML. Disabling")
            memory_saving_gradients = False
        logger.debug("memory_saving_gradients: %s", memory_saving_gradients)
        if memory_saving_gradients:
            logger.info("Using Memory Saving Gradients")
            # Import under an alias so the module does not rebind the flag being returned
            from lib.model import memory_saving_gradients as msg_module
            K.__dict__["gradients"] = msg_module.gradients_memory
        return memory_saving_gradients


class Loss():
    """ Holds loss names and functions for an Autoencoder

    ``names`` and ``funcs`` are built with one entry per output. When there is more than
    one output, "total_loss" is inserted at the front of ``names`` after ``funcs`` is built.
    """
    def __init__(self, inputs, outputs):
        logger.debug("Initializing %s: (inputs: %s, outputs: %s)",
                     self.__class__.__name__, inputs, outputs)
        self.inputs = inputs
        self.outputs = outputs
        self.names = self.get_loss_names()
        self.funcs = self.get_loss_functions()
        if len(self.names) > 1:
            self.names.insert(0, "total_loss")
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def loss_dict(self):
        """ Return the loss dict """
        loss_dict = dict(mae=losses.mean_absolute_error,
                         mse=losses.mean_squared_error,
                         logcosh=losses.logcosh,
                         smooth_loss=generalized_loss,
                         l_inf_norm=l_inf_norm,
                         ssim=DSSIMObjective(),
                         gmsd=gmsd_loss,
                         pixel_gradient_diff=gradient_loss)
        return loss_dict

    @property
    def config(self):
        """ Return the global _CONFIG variable """
        return _CONFIG

    @property
    def mask_preprocessing_func(self):
        """ The selected pre-processing function for the mask """
        retval = None
        if self.config.get("mask_blur", False):
            retval = gaussian_blur(max(1, self.mask_shape[1] // 32))
        logger.debug(retval)
        return retval

    @property
    def selected_loss(self):
        """ Return the selected loss function (defaults to "mae") """
        retval = self.loss_dict[self.config.get("loss_function", "mae")]
        logger.debug(retval)
        return retval

    @property
    def selected_mask_loss(self):
        """ Return the selected mask loss function. Currently returns mse
            If a processing function has been requested wrap the loss function
            in loss wrapper """
        loss_func = self.loss_dict["mse"]
        func = self.mask_preprocessing_func
        logger.debug("loss_func: %s, func: %s", loss_func, func)
        retval = mask_loss_wrapper(loss_func, preprocessing_func=func)
        return retval

    @property
    def output_shapes(self):
        """ The shapes of the output nodes """
        return [K.int_shape(output)[1:] for output in self.outputs]

    @property
    def mask_input(self):
        """ Return the mask input or None """
        mask_inputs = [inp for inp in self.inputs if inp.name.startswith("mask")]
        if not mask_inputs:
            return None
        return mask_inputs[0]

    @property
    def mask_shape(self):
        """ Return the mask shape """
        if self.mask_input is None:
            return None
        return K.int_shape(self.mask_input)[1:]

    def get_loss_names(self):
        """ Return the loss names based on model output """
        output_names = [output.name for output in self.outputs]
        logger.debug("Model output names: %s", output_names)
        # Keep the text between the first and last "/" of each output name
        loss_names = [name[name.find("/") + 1:name.rfind("/")].replace("_out", "")
                      for name in output_names]
        if not all(name.startswith(("face", "mask")) for name in loss_names):
            # Handle incorrectly named/legacy outputs
            logger.debug("Renaming loss names from: %s", loss_names)
            loss_names = self.update_loss_names()
        loss_names = ["{}_loss".format(name) for name in loss_names]
        logger.debug(loss_names)
        return loss_names

    def update_loss_names(self):
        """ Update loss names if named incorrectly or legacy model """
        # Single channel outputs are treated as masks, everything else as faces
        output_types = ["mask" if shape[-1] == 1 else "face" for shape in self.output_shapes]
        loss_names = ["{}{}".format(name,
                                    "" if output_types.count(name) == 1 else "_{}".format(idx))
                      for idx, name in enumerate(output_types)]
        logger.debug("Renamed loss names to: %s", loss_names)
        return loss_names

    def get_loss_functions(self):
        """ Set the loss function """
        loss_funcs = []
        for idx, loss_name in enumerate(self.names):
            if loss_name.startswith("mask"):
                loss_funcs.append(self.selected_mask_loss)
            elif self.config["penalized_mask_loss"] and self.config["mask_type"] is not None:
                face_size = self.output_shapes[idx][1]
                mask_size = self.mask_shape[1]
                scaling = face_size / mask_size
                logger.debug("face_size: %s mask_size: %s, mask_scaling: %s",
                             face_size, mask_size, scaling)
                loss_funcs.append(PenalizedLoss(self.mask_input, self.selected_loss,
                                                mask_scaling=scaling,
                                                preprocessing_func=self.mask_preprocessing_func))
            else:
                loss_funcs.append(self.selected_loss)
            logger.debug("%s: %s", loss_name, loss_funcs[-1])
        logger.debug(loss_funcs)
        return loss_funcs


class NNMeta():
    """ Class to hold a neural network and it's meta data

    filename:   The full path and filename of the model file for this network.
    type:       The type of network. For networks that can be swapped
                The type should be identical for the corresponding
                A and B networks, and should be unique for every A/B pair.
                Otherwise the type should be completely unique.
    side:       A, B or None. Used to identify which networks can
                be swapped.
    network:    Define network to this.
    is_output:  Set to True to indicate that this network is an output to the Autoencoder
    """

    def __init__(self, filename, network_type, side, network, is_output):
        logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
                     "network: %s, is_output: %s", self.__class__.__name__, filename,
                     network_type, side, network, is_output)
        self.filename = filename
        self.type = network_type.lower()
        self.side = side
        self.name = self.set_name()
        self.network = network
        self.is_output = is_output
        self.network.name = self.name
        self.config = network.get_config()  # For pingpong restore
        self.weights = network.get_weights()  # For pingpong restore
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def output_shapes(self):
        """ Return the output shapes from the stored network """
        return [K.int_shape(output) for output in self.network.outputs]

    def set_name(self):
        """ Set the network name """
        name = self.type
        if self.side:
            name += "_{}".format(self.side)
        return name

    @property
    def output_names(self):
        """ Return output node names """
        output_names = [output.name for output in self.network.outputs]
        if self.is_output and not any(name.startswith("face_out") for name in output_names):
            # Saved models break if their layer names are changed, so dummy
            # in correct output names for legacy models
            output_names = self.get_output_names()
        return output_names

    def get_output_names(self):
        """ Return the output names based on number of channels and instances """
        output_types = ["mask_out" if K.int_shape(output)[-1] == 1 else "face_out"
                        for output in self.network.outputs]
        output_names = ["{}{}".format(name,
                                      "" if output_types.count(name) == 1 else "_{}".format(idx))
                        for idx, name in enumerate(output_types)]
        logger.debug("Overridden output_names: %s", output_names)
        return output_names

    def load(self, fullpath=None):
        """ Load model from ``fullpath``, or from ``self.filename`` if it is not given.

        Returns ``True`` if the network was loaded (or legacy weights were converted),
        ``False`` if loading failed.
        """
        fullpath = fullpath if fullpath else self.filename
        logger.debug("Loading model: '%s'", fullpath)
        try:
            # Load from the resolved path so that an explicit fullpath is honoured
            network = load_model(fullpath, custom_objects=get_custom_objects())
        except ValueError as err:
            if str(err).lower().startswith("cannot create group in read only mode"):
                self.convert_legacy_weights()
                return True
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        except OSError as err:  # pylint: disable=broad-except
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        self.config = network.get_config()
        self.network = network  # Update network with saved model
        self.network.name = self.name
        return True

    def save(self, fullpath=None, backup_func=None):
        """ Save model, calling ``backup_func`` on the destination path first if provided """
        fullpath = fullpath if fullpath else self.filename
        if backup_func:
            backup_func(fullpath)
        logger.debug("Saving model: '%s'", fullpath)
        self.weights = self.network.get_weights()
        self.network.save(fullpath)

    def convert_legacy_weights(self):
        """ Convert legacy weights files to hold the model topology """
        logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
        self.network.load_weights(self.filename)
        self.save(backup_func=None)
        self.network.name = self.type


class State():
    """ Class to hold the model's current state and autoencoder structure """
    def __init__(self, model_dir, model_name, config_changeable_items,
                 no_logs, pingpong, training_image_size):
        logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', "
                     "config_changeable_items: '%s', no_logs: %s, pingpong: %s, "
                     "training_image_size: '%s'", self.__class__.__name__, model_dir, model_name,
                     config_changeable_items, no_logs, pingpong, training_image_size)
        self.serializer = get_serializer("json")
        filename = "{}_state.{}".format(model_name, self.serializer.file_extension)
        self.filename = str(model_dir / filename)
        self.name = model_name
        self.iterations = 0
        self.session_iterations = 0
        self.training_size = training_image_size
        self.sessions = dict()
        self.lowest_avg_loss = dict()
        self.inputs = dict()
self.config = dict() self.load(config_changeable_items) self.session_id = self.new_session_id() self.create_new_session(no_logs, pingpong, config_changeable_items) logger.debug("Initialized %s:", self.__class__.__name__) @property def face_shapes(self): """ Return a list of stored face shape inputs """ return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")] @property def mask_shapes(self): """ Return a list of stored mask shape inputs """ return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")] @property def loss_names(self): """ Return the loss names for this session """ return self.sessions[self.session_id]["loss_names"] @property def current_session(self): """ Return the current session dict """ return self.sessions[self.session_id] @property def first_run(self): """ Return True if this is the first run else False """ return self.session_id == 1 def new_session_id(self): """ Return new session_id """ if not self.sessions: session_id = 1 else: session_id = max(int(key) for key in self.sessions.keys()) + 1 logger.debug(session_id) return session_id def create_new_session(self, no_logs, pingpong, config_changeable_items): """ Create a new session """ logger.debug("Creating new session. id: %s", self.session_id) self.sessions[self.session_id] = {"timestamp": time.time(), "no_logs": no_logs, "pingpong": pingpong, "loss_names": dict(), "batchsize": 0, "iterations": 0, "config": config_changeable_items} def add_session_loss_names(self, side, loss_names): """ Add the session loss names to the sessions dictionary """ logger.debug("Adding session loss_names. 
(side: '%s', loss_names: %s", side, loss_names) self.sessions[self.session_id]["loss_names"][side] = loss_names def add_session_batchsize(self, batchsize): """ Add the session batchsize to the sessions dictionary """ logger.debug("Adding session batchsize: %s", batchsize) self.sessions[self.session_id]["batchsize"] = batchsize def increment_iterations(self): """ Increment total and session iterations """ self.iterations += 1 self.sessions[self.session_id]["iterations"] += 1 def load(self, config_changeable_items): """ Load state file """ logger.debug("Loading State") if not os.path.exists(self.filename): logger.info("No existing state file found. Generating.") return state = self.serializer.load(self.filename) self.name = state.get("name", self.name) self.sessions = state.get("sessions", dict()) self.lowest_avg_loss = state.get("lowest_avg_loss", dict()) self.iterations = state.get("iterations", 0) self.training_size = state.get("training_size", 256) self.inputs = state.get("inputs", dict()) self.config = state.get("config", dict()) logger.debug("Loaded state: %s", state) self.replace_config(config_changeable_items) def save(self, backup_func=None): """ Save iteration number to state file """ logger.debug("Saving State") if backup_func: backup_func(self.filename) state = {"name": self.name, "sessions": self.sessions, "lowest_avg_loss": self.lowest_avg_loss, "iterations": self.iterations, "inputs": self.inputs, "training_size": self.training_size, "config": _CONFIG} self.serializer.save(self.filename, state) logger.debug("Saved State") def replace_config(self, config_changeable_items): """ Replace the loaded config with the one contained within the state file Check for any fixed=False parameters changes and log info changes """ global _CONFIG # pylint: disable=global-statement legacy_update = self._update_legacy_config() # Add any new items to state config for legacy purposes for key, val in _CONFIG.items(): if key not in self.config.keys(): logger.info("Adding new 
config item to state file: '%s': '%s'", key, val) self.config[key] = val self.update_changed_config_items(config_changeable_items) logger.debug("Replacing config. Old config: %s", _CONFIG) _CONFIG = self.config if legacy_update: self.save() logger.debug("Replaced config. New config: %s", _CONFIG) logger.info("Using configuration saved in state file") def _update_legacy_config(self): """ Legacy updates for new config additions. When new config items are added to the Faceswap code, existing model state files need to be updated to handle these new items. Current existing legacy update items: * loss - If old `dssim_loss` is ``true`` set new `loss_function` to `ssim` otherwise set it to `mae`. Remove old `dssim_loss` item * masks - If `learn_mask` does not exist then it is set to ``True`` if `mask_type` is not ``None`` otherwise it is set to ``False``. * masks type - Replace removed masks 'dfl_full' and 'facehull' with `components` mask Returns ------- bool ``True`` if legacy items exist and state file has been updated, otherwise ``False`` """ logger.debug("Checking for legacy state file update") priors = ["dssim_loss", "mask_type", "mask_type"] new_items = ["loss_function", "learn_mask", "mask_type"] updated = False for old, new in zip(priors, new_items): if old not in self.config: logger.debug("Legacy item '%s' not in config. Skipping update", old) continue # dssim_loss > loss_function if old == "dssim_loss": self.config[new] = "ssim" if self.config[old] else "mae" del self.config[old] updated = True logger.info("Updated config from legacy dssim format. New config loss " "function: '%s'", self.config[new]) continue # Add learn mask option and set to True if model has "penalized_mask_loss" specified if old == "mask_type" and new == "learn_mask" and new not in self.config: self.config[new] = self.config["mask_type"] is not None updated = True logger.info("Added new 'learn_mask' config item for this model. 
Value set to: %s", self.config[new]) continue # Replace removed masks with most similar equivalent if old == "mask_type" and new == "mask_type" and self.config[old] in ("facehull", "dfl_full"): old_mask = self.config[old] self.config[new] = "components" updated = True logger.info("Updated 'mask_type' from '%s' to '%s' for this model", old_mask, self.config[new]) logger.debug("State file updated for legacy config: %s", updated) return updated def update_changed_config_items(self, config_changeable_items): """ Update any parameters which are not fixed and have been changed """ if not config_changeable_items: logger.debug("No changeable parameters have been updated") return for key, val in config_changeable_items.items(): old_val = self.config[key] if old_val == val: continue self.config[key] = val logger.info("Config item: '%s' has been updated from '%s' to '%s'", key, old_val, val)