#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class.

    When inheriting, model_data should be a list of NNMeta objects.
    See the class for details.
"""

import logging
import os
import sys
import time

from json import JSONDecodeError

from keras import losses
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils import get_custom_objects, multi_gpu_model

from lib import Serializer
from lib.model.losses import DSSIMObjective, PenalizedLoss
from lib.model.nn_blocks import NNBlocks
from lib.multithreading import MultiThread
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
_CONFIG = None


class ModelBase():
    """ Base class that all models should inherit from """
    def __init__(self,
                 model_dir,
                 gpus,
                 no_logs=False,
                 warp_to_landmarks=False,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 predict=False):
        logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
                     "training_image_size: %s, alignments_paths: %s, preview_scale: %s, "
                     "input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir,
                     gpus, training_image_size, alignments_paths, preview_scale, input_shape,
                     encoder_dim)
        self.predict = predict
        self.model_dir = model_dir
        self.gpus = gpus
        self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"],
                               use_icnr_init=self.config["icnr_init"],
                               use_reflect_padding=self.config["reflect_padding"])
        self.input_shape = input_shape
        self.output_shape = None  # set after model is compiled
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.state = State(self.model_dir, self.name, no_logs, training_image_size)
        self.load_state_info()

        self.networks = dict()  # Networks for the model
        self.predictors = dict()  # Predictors for model
        self.history = dict()  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {"alignments": alignments_paths,
                              "preview_scaling": preview_scale / 100,
                              "warp_to_landmarks": warp_to_landmarks,
                              "no_flip": no_flip}

        self.build()
        self.set_training_data()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config(self):
        """ Return config dict for current plugin """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = ".".join(self.__module__.split(".")[-2:])
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name).config_dict
        return _CONFIG

    @property
    def name(self):
        """ Set the model name based on the subclass """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    def set_training_data(self):
        """ Override to set model specific training data.

        super() this method for the defaults, otherwise be sure to set all of
        the required training_opts keys yourself. """
        logger.debug("Setting training data")
        self.training_opts["training_size"] = self.state.training_size
        self.training_opts["no_logs"] = self.state.current_session["no_logs"]
        self.training_opts["mask_type"] = self.config.get("mask_type", None)
        self.training_opts["coverage_ratio"] = self.config.get("coverage", 62.5) / 100
        self.training_opts["preview_images"] = 14
        logger.debug("Set training data: %s", self.training_opts)

    def build(self):
        """ Build the model. Override for custom build methods """
        self.add_networks()
        self.load_models(swapped=False)
        self.build_autoencoders()
        self.log_summary()
        self.compile_predictors()

    def build_autoencoders(self):
        """ Override for model specific autoencoder builds.

        NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names
        are expected:
            face (the input for the face image)
            mask (the input for the mask, if one is used)
        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError
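
    # A minimal sketch of the overrides a model plugin is expected to provide.
    # Illustrative only: the network names, shapes, layer choices and the
    # keras Input/Model imports are assumptions, not a real plugin.
    #
    #     class Model(ModelBase):
    #         def add_networks(self):
    #             self.add_network("encoder", None, self.encoder())
    #             self.add_network("decoder", "a", self.decoder())
    #             self.add_network("decoder", "b", self.decoder())
    #
    #         def build_autoencoders(self):
    #             inp = Input(shape=self.input_shape, name="face")
    #             for side in ("a", "b"):
    #                 decoder = self.networks["decoder_{}".format(side)].network
    #                 output = decoder(self.networks["encoder"].network(inp))
    #                 self.add_predictor(side, KerasModel(inp, output))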

    def load_state_info(self):
        """ Load the input shape from state file if it exists """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning("Input shapes stored in State file, but no matches for 'face'. "
                           "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network):
        """ Add a NNMeta object """
        logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network)
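
    # For example, add_network("decoder", "a", net) on the "original" model
    # registers the network as self.networks["decoder_a"] and targets
    # "<model_dir>/original_decoder_A.h5" on disk.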

    def add_predictor(self, side, model):
        """ Add a predictor to the predictors dictionary """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)
        if not self.output_shape:
            self.set_output_shape(model)

    def store_input_shapes(self, model):
        """ Store the input shapes to state """
        logger.debug("Adding input shapes to state for model")
        inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs}
        if not any(inp.startswith("face") for inp in inputs):
            raise ValueError("No input named 'face' was found. Check your input naming. "
                             "Current input names: {}".format(inputs))
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

    def set_output_shape(self, model):
        """ Set the output shape for use in training and convert """
        logger.debug("Setting output shape")
        out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs]
        if not out:
            raise ValueError("No outputs found! Check your model.")
        self.output_shape = tuple(out[0])
        logger.debug("Added output shape: %s", self.output_shape)

    def compile_predictors(self):
        """ Compile the predictors """
        logger.debug("Compiling Predictors")
        optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0)

        for side, model in self.predictors.items():
            loss_names = ["loss"]
            loss_funcs = [self.loss_function(side)]
            mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
            if mask:
                loss_names.insert(0, "mask_loss")
                loss_funcs.insert(0, self.mask_loss_function(mask[0], side))
            model.compile(optimizer=optimizer, loss=loss_funcs)

            if len(loss_names) > 1:
                loss_names.insert(0, "total_loss")
            self.state.add_session_loss_names(side, loss_names)
            self.history[side] = list()
        logger.debug("Compiled Predictors. Losses: %s", loss_names)

    def loss_function(self, side):
        """ Set the loss function """
        if self.config.get("dssim_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using DSSIM Loss")
            loss_func = DSSIMObjective()
        else:
            if side == "a" and not self.predict:
                logger.verbose("Using Mean Absolute Error Loss")
            loss_func = losses.mean_absolute_error
        logger.debug(loss_func)
        return loss_func

    def mask_loss_function(self, mask, side):
        """ Set the loss function for masks

        Side is passed in so that we only log once """
        if self.config.get("dssim_mask_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using DSSIM Loss for mask")
            mask_loss_func = DSSIMObjective()
        else:
            if side == "a" and not self.predict:
                logger.verbose("Using Mean Absolute Error Loss for mask")
            mask_loss_func = losses.mean_absolute_error

        if self.config.get("penalized_mask_loss", False):
            if side == "a" and not self.predict:
                logger.verbose("Using Penalized Loss for mask")
            mask_loss_func = PenalizedLoss(mask, mask_loss_func)
        logger.debug(mask_loss_func)
        return mask_loss_func
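
    # Note that penalized_mask_loss wraps whichever base loss was selected,
    # so with both dssim_mask_loss and penalized_mask_loss enabled the final
    # mask loss is PenalizedLoss(mask, DSSIMObjective()).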

    def converter(self, swap):
        """ Converter for autoencoder models """
        logger.debug("Getting Converter: (swap: %s)", swap)
        if swap:
            retval = self.predictors["a"].predict
        else:
            retval = self.predictors["b"].predict
        logger.debug("Got Converter: %s", retval)
        return retval
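
    # Usage sketch (variable names illustrative): converter(swap=False)
    # returns self.predictors["b"].predict, so faces fed in from the A set
    # come out converted to B:
    #     swap_face = model.converter(False)
    #     converted = swap_face(feed_faces)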

    @property
    def iterations(self):
        """ Get the current training iteration number """
        return self.state.iterations

    def map_models(self, swapped):
        """ Map the models for A/B side for swapping """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": dict(), "b": dict()}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Verbose log the model summaries """
        if self.predict:
            return
        for side in sorted(list(self.predictors.keys())):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x))
            for name, nnmeta in self.networks.items():
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x))

    def load_models(self, swapped):
        """ Load models from file """
        logger.debug("Load model: (swapped: %s)", swapped)
        model_mapping = self.map_models(swapped)
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load(predict=self.predict)
            else:
                is_loaded = network.load(fullpath=model_mapping[network.side][network.type],
                                         predict=self.predict)
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Backup and save the models """
        logger.debug("Backing up and saving models")
        should_backup = self.get_save_averages()
        save_threads = list()
        for network in self.networks.values():
            name = "save_{}".format(network.name)
            save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup))
        save_threads.append(MultiThread(self.state.save,
                                        name="save_state", should_backup=should_backup))
        for thread in save_threads:
            thread.start()
        for thread in save_threads:
            if thread.has_error:
                logger.error(thread.errors[0])
            thread.join()
        # Put in a line break to avoid jumbled console output
        print("\n")
        logger.info("Saved models")

    def get_save_averages(self):
        """ Return whether the models should be backed up, based on the average
        loss since the last save, and reset the historical losses.

        This protects against model corruption by only backing up the model
        if all of the loss values have fallen.
        TODO This is not a perfect system. If the model corrupts on save_iteration - 1
        then the model may still be backed up.
        """
        logger.debug("Getting Average loss since last save")
        avgs = dict()
        backup = True

        for side, loss in self.history.items():
            if not loss:
                backup = False
                break

            avgs[side] = sum(loss) / len(loss)
            self.history[side] = list()  # Reset historical loss

            if not self.state.lowest_avg_loss.get(side, None):
                logger.debug("Setting initial save iteration loss average for '%s': %s",
                             side, avgs[side])
                self.state.lowest_avg_loss[side] = avgs[side]
                continue

            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, avgs[side])

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)
        logger.debug("Average loss since last save: %s", avgs)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in avgs.items():
                logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Check whether total loss has dropped since lowest loss """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False
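
    # Worked example (illustrative numbers): with lowest_avg_loss of
    # {"a": 0.045, "b": 0.051} and save-iteration averages of
    # {"a": 0.042, "b": 0.049}, both sides have dropped, so the models are
    # backed up and lowest_avg_loss is updated to the new averages. If either
    # side had risen, no backup would be taken and the stored lows would stand.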


class NNMeta():
    """ Class to hold a neural network and its meta data.

    filename:   The full path and filename of the model file for this network.
    type:       The type of network. For networks that can be swapped, the
                type should be identical for the corresponding A and B
                networks, and should be unique for every A/B pair.
                Otherwise the type should be completely unique.
    side:       A, B or None. Used to identify which networks can be swapped.
    network:    The keras network this object holds.
    """

    def __init__(self, filename, network_type, side, network):
        logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
                     "network: %s)", self.__class__.__name__, filename, network_type,
                     side, network)
        self.filename = filename
        self.type = network_type.lower()
        self.side = side
        self.name = self.set_name()
        self.network = network
        self.network.name = self.name
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_name(self):
        """ Set the network name """
        name = self.type
        if self.side:
            name += "_{}".format(self.side)
        return name

    def load(self, fullpath=None, predict=False):
        """ Load model """
        fullpath = fullpath if fullpath else self.filename
        logger.debug("Loading model: '%s'", fullpath)
        try:
            network = load_model(fullpath, custom_objects=get_custom_objects())
        except ValueError as err:
            if str(err).lower().startswith("cannot create group in read only mode"):
                self.convert_legacy_weights()
                return True
            if predict:
                raise ValueError("Unable to load training data. Error: {}".format(str(err)))
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        except OSError as err:
            if predict:
                raise ValueError("Unable to load training data. Error: {}".format(str(err)))
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        self.network = network  # Update network with saved model
        self.network.name = self.type
        return True

    def save(self, fullpath=None, should_backup=False):
        """ Save model """
        fullpath = fullpath if fullpath else self.filename
        if should_backup:
            self.backup(fullpath=fullpath)
        logger.debug("Saving model: '%s'", fullpath)
        self.network.save(fullpath)

    def backup(self, fullpath=None):
        """ Backup model """
        origfile = fullpath if fullpath else self.filename
        backupfile = origfile + ".bk"
        logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
        if os.path.exists(backupfile):
            os.remove(backupfile)
        if os.path.exists(origfile):
            os.rename(origfile, backupfile)

    def convert_legacy_weights(self):
        """ Convert legacy weights files to hold the model topology """
        logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
        self.network.load_weights(self.filename)
        self.save(should_backup=False)
        self.network.name = self.type


class State():
    """ Class to hold the model's current state and autoencoder structure """
    def __init__(self, model_dir, model_name, no_logs, training_image_size):
        logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
                     "training_image_size: %s)", self.__class__.__name__, model_dir,
                     model_name, no_logs, training_image_size)
        self.serializer = Serializer.get_serializer("json")
        filename = "{}_state.{}".format(model_name, self.serializer.ext)
        self.filename = str(model_dir / filename)
        self.iterations = 0
        self.session_iterations = 0
        self.training_size = training_image_size
        self.sessions = dict()
        self.lowest_avg_loss = dict()
        self.inputs = dict()
        self.config = dict()
        self.load()
        self.session_id = self.new_session_id()
        self.create_new_session(no_logs)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def face_shapes(self):
        """ Return a list of stored face shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]

    @property
    def mask_shapes(self):
        """ Return a list of stored mask shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]

    @property
    def loss_names(self):
        """ Return the loss names for this session """
        return self.sessions[self.session_id]["loss_names"]

    @property
    def current_session(self):
        """ Return the current session dict """
        return self.sessions[self.session_id]

    def new_session_id(self):
        """ Return a new session_id """
        if not self.sessions:
            session_id = 1
        else:
            session_id = max(int(key) for key in self.sessions.keys()) + 1
        logger.debug(session_id)
        return session_id

    def create_new_session(self, no_logs):
        """ Create a new session """
        logger.debug("Creating new session. id: %s", self.session_id)
        self.sessions[self.session_id] = {"timestamp": time.time(),
                                          "no_logs": no_logs,
                                          "loss_names": dict(),
                                          "batchsize": 0,
                                          "iterations": 0}

    def add_session_loss_names(self, side, loss_names):
        """ Add the session loss names to the sessions dictionary """
        logger.debug("Adding session loss_names. (side: '%s', loss_names: %s)", side, loss_names)
        self.sessions[self.session_id]["loss_names"][side] = loss_names

    def add_session_batchsize(self, batchsize):
        """ Add the session batchsize to the sessions dictionary """
        logger.debug("Adding session batchsize: %s", batchsize)
        self.sessions[self.session_id]["batchsize"] = batchsize

    def increment_iterations(self):
        """ Increment total and session iterations """
        self.iterations += 1
        self.sessions[self.session_id]["iterations"] += 1

    def load(self):
        """ Load state file """
        logger.debug("Loading State")
        try:
            with open(self.filename, "rb") as inp:
                state = self.serializer.unmarshal(inp.read().decode("utf-8"))
                self.sessions = state.get("sessions", dict())
                self.lowest_avg_loss = state.get("lowest_avg_loss", dict())
                self.iterations = state.get("iterations", 0)
                self.training_size = state.get("training_size", 256)
                self.inputs = state.get("inputs", dict())
                self.config = state.get("config", dict())
                logger.debug("Loaded state: %s", state)
                self.replace_config()
        except IOError as err:
            logger.warning("No existing state file found. Generating.")
            logger.debug("IOError: %s", str(err))
        except JSONDecodeError as err:
            logger.debug("JSONDecodeError: %s", str(err))

    def save(self, should_backup=False):
        """ Save the current state to the state file """
        logger.debug("Saving State")
        if should_backup:
            self.backup()
        try:
            with open(self.filename, "wb") as out:
                state = {"sessions": self.sessions,
                         "lowest_avg_loss": self.lowest_avg_loss,
                         "iterations": self.iterations,
                         "inputs": self.inputs,
                         "training_size": self.training_size,
                         "config": _CONFIG}
                state_json = self.serializer.marshal(state)
                out.write(state_json.encode("utf-8"))
        except IOError as err:
            logger.error("Unable to save model state: %s", str(err.strerror))
        logger.debug("Saved State")

    def backup(self):
        """ Backup state file """
        origfile = self.filename
        backupfile = origfile + ".bk"
        logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
        if os.path.exists(backupfile):
            os.remove(backupfile)
        if os.path.exists(origfile):
            os.rename(origfile, backupfile)

    def replace_config(self):
        """ Replace the loaded config with the one contained within the state file """
        global _CONFIG  # pylint: disable=global-statement
        # Add any new items to state config for legacy purposes
        for key, val in _CONFIG.items():
            if key not in self.config.keys():
                logger.info("Adding new config item to state file: '%s': '%s'", key, val)
                self.config[key] = val
        logger.debug("Replacing config. Old config: %s", _CONFIG)
        _CONFIG = self.config
        logger.debug("Replaced config. New config: %s", _CONFIG)
        logger.info("Using configuration saved in state file")