faceswap/plugins/train/model/_base.py
torzdf cd00859c40
model_refactor (#571) (#572)
* model_refactor (#571)

* original model to new structure

* IAE model to new structure

* OriginalHiRes to new structure

* Fix trainer for different resolutions

* Initial config implementation

* Configparse library added

* improved training data loader

* dfaker model working

* Add logging to training functions

* Non-blocking input for cli training

* Add error handling to threads. Add non-mp queues to queue_handler

* Improved Model Building and NNMeta

* refactor lib/models

* training refactor. DFL H128 model Implementation

* Dfaker - use hashes

* Move timelapse. Remove perceptual loss arg

* Update INSTALL.md. Add logger formatting. Update Dfaker training

* DFL h128 partially ported

* Add mask to dfaker (#573)

* Remove old models. Add mask to dfaker

* dfl mask. Make masks selectable in config (#575)

* DFL H128 Mask. Mask type selectable in config.

* remove gan_v2_2

* Creating Input Size config for models

Will be used downstream in converters.

Also renamed image_shape to input_shape for clarity (future models may have different output_shapes)

* Add mask loss options to config

* MTCNN options to config.ini. Remove GAN config. Update USAGE.md

* Add sliders for numerical values in GUI

* Add config plugins menu to gui. Validate config

* Only backup model if loss has dropped. Get training working again

* bugfixes

* Standardise loss printing

* GUI idle cpu fixes. Graph loss fix.

* multi-gpu logging bugfix

* Merge branch 'staging' into train_refactor

* backup state file

* Crash protection: Only backup if both total losses have dropped

* Port OriginalHiRes_RC4 to train_refactor (OriginalHiRes)

* Load and save model structure with weights

* Slight code update

* Improve config loader. Add subpixel opt to all models. Config to state

* Show samples... wrong input

* Remove AE topology. Add input/output shapes to State

* Port original_villain (birb/VillainGuy) model to faceswap

* Add plugin info to GUI config pages

* Load input shape from state. IAE Config options.

* Fix transform_kwargs.
Coverage to ratio.
Bugfix mask detection

* Suppress keras userwarnings.
Automate zoom.
Coverage_ratio to model def.

* Consolidation of converters & refactor (#574)

* Consolidation of converters & refactor

Initial Upload of alpha

Items
- consolidate convert_masked & convert_adjust into one converter
- add average color adjust to convert_masked
- allow mask transition blur size to be a fixed integer of pixels or a fraction of the facial mask size (see the sketch after the TODO list below)
- allow erosion/dilation size to be a fixed integer of pixels or a fraction of the facial mask size
- eliminate redundant type conversions to avoid multiple round-off errors
- refactor loops for vectorization/speed
- reorganize for clarity & style changes

TODO
- bug/issues with warping the new face onto a transparent old image... use a cleanup mask for now
- issues with mask border giving black ring at zero erosion... investigate
- remove GAN?
- test enlargement factors of umeyama standard face... match to coverage factor
- make enlargement factor a model parameter
- remove convert_adjusted and referencing code when finished
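
As a rough illustration of the pixels-or-fraction behaviour above (a hedged
sketch only; the helper name and exact formula are assumptions, not the
converter's actual code):

    import numpy as np

    def kernel_size(setting, mask_area):
        """Interpret a blur/erosion setting as a fixed pixel size when
        >= 1, or as a fraction of the facial mask size when < 1."""
        if setting >= 1.0:
            return max(1, int(setting))           # fixed number of pixels
        # fraction of the mask's characteristic dimension
        return max(1, int(setting * np.sqrt(mask_area)))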

* Update Convert_Masked.py

default blur size of 2 to match original...
description of enlargement tests
break out matrix scaling into a def

* Enlargement scale as a cli parameter

* Update cli.py

* dynamic interpolation algorithm

Compute x & y scale factors from the affine matrix on the fly by QR decomposition.
Choose the interpolation algorithm for the affine warp based on whether each image is upsampled or downsampled.
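
A minimal sketch of the idea (the helper name is assumed; the scale factors
are the R diagonal of the QR decomposition of the affine matrix's 2x2 part):

    import cv2
    import numpy as np

    def choose_interpolation(matrix):
        """Pick a cv2 interpolation flag depending on whether the
        affine warp upsamples or downsamples the image."""
        # R[0, 0] is the first column norm; R[1, 1] = det / R[0, 0]
        x_scale = np.sqrt(matrix[0, 0] ** 2 + matrix[1, 0] ** 2)
        y_scale = (matrix[0, 0] * matrix[1, 1]
                   - matrix[0, 1] * matrix[1, 0]) / x_scale
        if (x_scale + y_scale) / 2.0 >= 1.0:
            return cv2.INTER_CUBIC   # upsampling
        return cv2.INTER_AREA        # downsampling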

* Input size from config

* fix issues with <1.0 erosion

* Update convert.py

* Update Convert_Adjust.py

more work on the way to merging

* Clean up help note on sharpen

* cleanup seamless

* Delete Convert_Adjust.py

* Update umeyama.py

* Update training_data.py

* swapping

* segmentation stub

* changes to convert.str

* Update masked.py

* Backwards compatibility fix for models
Get converter running

* Convert:
Move masks to class.
bugfix blur_size
some linting

* mask fix

* convert fixes

- missing facehull_rect re-added
- coverage to %
- corrected coverage logic
- cleanup of gui option ordering

* Update cli.py

* default for blur

* Update masked.py

* added preliminary low_mem version of OriginalHighRes model plugin

* Code cleanup, minor fixes

* Update masked.py

* Update masked.py

* Add dfl mask to convert

* histogram fix & seamless location

* update

* revert

* bugfix: Load actual configuration in gui

* Standardize nn_blocks

* Update cli.py

* Minor code amends

* Fix Original HiRes model

* Add masks to preview output for mask trainers
refactor trainer.__base.py

* Masked trainers converter support

* convert bugfix

* Bugfix: Converter for masked (dfl/dfaker) trainers

* Additional Losses (#592)

* initial upload

* Delete blur.py

* default initializer = He instead of Glorot (#588)

* Allow kernel_initializer to be overridable

* Add ICNR Initializer option for upscale on all models.
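
ICNR initialises an upscale (sub-pixel) convolution so that every pixel in
each upscaled block starts out identical, avoiding checkerboard artefacts.
A minimal NumPy sketch of the technique (illustrative only; channel ordering
depends on the pixel-shuffle convention in use):

    import numpy as np

    def icnr_kernel(shape, scale=2, init=np.random.standard_normal):
        """Build an ICNR-initialised kernel of shape (kh, kw, c_in, c_out)."""
        kh, kw, c_in, c_out = shape
        # draw one kernel per upscaled block...
        sub = init((kh, kw, c_in, c_out // (scale ** 2)))
        # ...then duplicate it across each block of scale**2 output filters
        return np.repeat(sub, scale ** 2, axis=-1)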

* Hopefully fixes RSoDs with original-highres model plugin

* remove debug line

* Original-HighRes model plugin Red Screen of Death fix, take #2

* Move global options to _base. Rename Villain model

* clipnorm and res block biases

* scale the end of res block

* res block

* dfaker pre-activation res

* OHRES pre-activation

* villain pre-activation

* tabs/space in nn_blocks

* fix for histogram with mask all set to zero

* fix to prevent two networks with same name

* GUI: Wider tooltips. Improve TQDM capture

* Fix regex bug

* Convert padding=48 to ratio of image size

* Add size option to alignments tool extract

* Pass through training image size to convert from model

* Convert: Pull training coverage from model

* convert: coverage, blur and erode to percent

* simplify matrix scaling

* ordering of sliders in train

* Add matrix scaling to utils. Use interpolation in lib.aligner transform

* masked.py Import get_matrix_scaling from utils

* fix circular import

* Update masked.py

* quick fix for matrix scaling

* testing this for now

* tqdm regex capture bugfix

* Minor amends

* blur size cleanup

* Remove coverage option from convert (Now cascades from model)

* Implement convert for all model types

* Add mask option and coverage option to all existing models

* bugfix for model loading on convert

* debug print removal

* Bugfix for masks in dfl_h128 and iae

* Update preview display. Add preview scaling to cli

* mask notes

* Delete training_data_v2.py

errant file

* training data variables

* Fix timelapse function

* Add new config items to state file for legacy purposes

* Slight GUI tweak

* Raise exception if problem with loaded model

* Add Tensorboard support (Logs stored in model directory)

* ICNR fix

* loss bugfix

* convert bugfix

* Move ini files to config folder. Make TensorBoard optional

* Fix training data for unbalanced inputs/outputs

* Fix config "none" test

* Keep helptext in .ini files when saving config from GUI

* Remove frame_dims from alignments

* Add no-flip and warp-to-landmarks cli options

* Revert OHR to RC4_fix version

* Fix lowmem mode on OHR model

* padding to variable

* Save models in parallel threads

* Speed-up of res_block stability

* Automated Reflection Padding

* Reflect Padding as a training option

Includes auto-calculation of proper padding shapes, input_shapes, output_shapes

Flag included in config now
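
A sketch of the auto-calculation (assumed helper, not the faceswap code): pad
with (kernel_size - 1) // 2 reflected pixels so that a following stride-1
'valid' convolution preserves the spatial dimensions:

    import tensorflow as tf
    from keras.layers import Lambda

    def reflect_pad(tensor, kernel_size):
        pad = (kernel_size - 1) // 2
        return Lambda(lambda x: tf.pad(
            x, [[0, 0], [pad, pad], [pad, pad], [0, 0]],
            mode="REFLECT"))(tensor)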

* rest of reflect padding

* Move TB logging to cli. Session info to state file

* Add session iterations to state file

* Add recent files to menu. GUI code tidy up

* [GUI] Fix recent file list update issue

* Add correct loss names to TensorBoard logs

* Update live graph to use TensorBoard and remove animation

* Fix analysis tab. GUI optimizations

* Analysis Graph popup to Tensorboard Logs

* [GUI] Bug fix for graphing for models with hyphens in name

* [GUI] Correctly split loss to tabs during training

* [GUI] Add loss type selection to analysis graph

* Fix store command name in recent files. Switch to correct tab on open

* [GUI] Disable training graph when 'no-logs' is selected

* Fix graphing race condition

* rename original_hires model to unbalanced
2019-02-09 18:35:12 +00:00


#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class
When inheriting model_data should be a list of NNMeta objects.
See the class for details.
"""
import logging
import os
import sys
import time
from json import JSONDecodeError
from keras import losses
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils import get_custom_objects, multi_gpu_model
from lib import Serializer
from lib.model.losses import DSSIMObjective, PenalizedLoss
from lib.model.nn_blocks import NNBlocks
from lib.multithreading import MultiThread
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG = None
class ModelBase():
""" Base class that all models should inherit from """
def __init__(self,
model_dir,
gpus,
no_logs=False,
warp_to_landmarks=False,
no_flip=False,
training_image_size=256,
alignments_paths=None,
preview_scale=100,
input_shape=None,
encoder_dim=None,
trainer="original",
predict=False):
logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, "
"training_image_size, %s, alignments_paths: %s, preview_scale: %s, "
"input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus,
training_image_size, alignments_paths, preview_scale, input_shape,
encoder_dim)
self.predict = predict
self.model_dir = model_dir
self.gpus = gpus
self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"],
use_icnr_init=self.config["icnr_init"],
use_reflect_padding=self.config["reflect_padding"])
self.input_shape = input_shape
self.output_shape = None # set after model is compiled
self.encoder_dim = encoder_dim
self.trainer = trainer
self.state = State(self.model_dir, self.name, no_logs, training_image_size)
self.load_state_info()
self.networks = dict() # Networks for the model
self.predictors = dict() # Predictors for model
        self.history = dict()  # Loss history per save iteration
# Training information specific to the model should be placed in this
# dict for reference by the trainer.
self.training_opts = {"alignments": alignments_paths,
"preview_scaling": preview_scale / 100,
"warp_to_landmarks": warp_to_landmarks,
"no_flip": no_flip}
self.build()
self.set_training_data()
logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
@property
def config(self):
""" Return config dict for current plugin """
global _CONFIG # pylint: disable=global-statement
if not _CONFIG:
model_name = ".".join(self.__module__.split(".")[-2:])
logger.debug("Loading config for: %s", model_name)
_CONFIG = Config(model_name).config_dict
return _CONFIG
@property
def name(self):
""" Set the model name based on the subclass """
basename = os.path.basename(sys.modules[self.__module__].__file__)
retval = os.path.splitext(basename)[0].lower()
logger.debug("model name: '%s'", retval)
return retval
def set_training_data(self):
""" Override to set model specific training data.
super() this method for defaults otherwise be sure to add """
logger.debug("Setting training data")
self.training_opts["training_size"] = self.state.training_size
self.training_opts["no_logs"] = self.state.current_session["no_logs"]
self.training_opts["mask_type"] = self.config.get("mask_type", None)
self.training_opts["coverage_ratio"] = self.config.get("coverage", 62.5) / 100
self.training_opts["preview_images"] = 14
logger.debug("Set training data: %s", self.training_opts)
def build(self):
""" Build the model. Override for custom build methods """
self.add_networks()
self.load_models(swapped=False)
self.build_autoencoders()
self.log_summary()
self.compile_predictors()
def build_autoencoders(self):
""" Override for Model Specific autoencoder builds
NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names
are expected:
face (the input for image)
mask (the input for mask if it is used)
"""
raise NotImplementedError
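    # Illustrative only (an assumption, not part of this module): a subclass
    # would typically wire its autoencoders together along these lines, making
    # sure the face Input is named "face" so store_input_shapes() can find it:
    #
    #     def build_autoencoders(self):
    #         inp = Input(shape=self.input_shape, name="face")
    #         for side in ("a", "b"):
    #             decoder = self.networks["decoder_{}".format(side)].network
    #             encoded = self.networks["encoder"].network(inp)
    #             self.add_predictor(side, KerasModel(inp, decoder(encoded)))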
def add_networks(self):
""" Override to add neural networks """
raise NotImplementedError
def load_state_info(self):
""" Load the input shape from state file if it exists """
logger.debug("Loading Input Shape from State file")
if not self.state.inputs:
logger.debug("No input shapes saved. Using model config")
return
if not self.state.face_shapes:
logger.warning("Input shapes stored in State file, but no matches for 'face'."
"Using model config")
return
input_shape = self.state.face_shapes[0]
logger.debug("Setting input shape from state file: %s", input_shape)
self.input_shape = input_shape
def add_network(self, network_type, side, network):
""" Add a NNMeta object """
logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network)
filename = "{}_{}".format(self.name, network_type.lower())
name = network_type.lower()
if side:
side = side.lower()
filename += "_{}".format(side.upper())
name += "_{}".format(side)
filename += ".h5"
logger.debug("name: '%s', filename: '%s'", name, filename)
self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network)
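    # Example usage (inferred from the naming logic above): for a model named
    # "original",
    #     self.add_network("encoder", None, net)   # -> "original_encoder.h5"
    #     self.add_network("decoder", "a", net_a)  # -> "original_decoder_A.h5"
    #     self.add_network("decoder", "b", net_b)  # -> "original_decoder_B.h5"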
def add_predictor(self, side, model):
""" Add a predictor to the predictors dictionary """
logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
if self.gpus > 1:
logger.debug("Converting to multi-gpu: side %s", side)
model = multi_gpu_model(model, self.gpus)
self.predictors[side] = model
if not self.state.inputs:
self.store_input_shapes(model)
if not self.output_shape:
self.set_output_shape(model)
def store_input_shapes(self, model):
""" Store the input and output shapes to state """
logger.debug("Adding input shapes to state for model")
inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs}
if not any(inp for inp in inputs.keys() if inp.startswith("face")):
raise ValueError("No input named 'face' was found. Check your input naming. "
"Current input names: {}".format(inputs))
self.state.inputs = inputs
logger.debug("Added input shapes: %s", self.state.inputs)
def set_output_shape(self, model):
""" Set the output shape for use in training and convert """
logger.debug("Setting output shape")
out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs]
if not out:
raise ValueError("No outputs found! Check your model.")
self.output_shape = tuple(out[0])
logger.debug("Added output shape: %s", self.output_shape)
def compile_predictors(self):
""" Compile the predictors """
logger.debug("Compiling Predictors")
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0)
for side, model in self.predictors.items():
loss_names = ["loss"]
loss_funcs = [self.loss_function(side)]
mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
if mask:
loss_names.insert(0, "mask_loss")
loss_funcs.insert(0, self.mask_loss_function(mask[0], side))
model.compile(optimizer=optimizer, loss=loss_funcs)
if len(loss_names) > 1:
loss_names.insert(0, "total_loss")
self.state.add_session_loss_names(side, loss_names)
self.history[side] = list()
logger.debug("Compiled Predictors. Losses: %s", loss_names)
def loss_function(self, side):
""" Set the loss function """
if self.config.get("dssim_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using DSSIM Loss")
loss_func = DSSIMObjective()
else:
if side == "a" and not self.predict:
logger.verbose("Using Mean Absolute Error Loss")
loss_func = losses.mean_absolute_error
logger.debug(loss_func)
return loss_func
def mask_loss_function(self, mask, side):
""" Set the loss function for masks
Side is input so we only log once """
if self.config.get("dssim_mask_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using DSSIM Loss for mask")
mask_loss_func = DSSIMObjective()
else:
if side == "a" and not self.predict:
logger.verbose("Using Mean Absolute Error Loss for mask")
mask_loss_func = losses.mean_absolute_error
if self.config.get("penalized_mask_loss", False):
if side == "a" and not self.predict:
logger.verbose("Using Penalized Loss for mask")
mask_loss_func = PenalizedLoss(mask, mask_loss_func)
logger.debug(mask_loss_func)
return mask_loss_func
def converter(self, swap):
""" Converter for autoencoder models """
logger.debug("Getting Converter: (swap: %s)", swap)
if swap:
retval = self.predictors["a"].predict
else:
retval = self.predictors["b"].predict
logger.debug("Got Converter: %s", retval)
return retval
@property
def iterations(self):
"Get current training iteration number"
return self.state.iterations
def map_models(self, swapped):
""" Map the models for A/B side for swapping """
logger.debug("Map models: (swapped: %s)", swapped)
models_map = {"a": dict(), "b": dict()}
sides = ("a", "b") if not swapped else ("b", "a")
for network in self.networks.values():
if network.side == sides[0]:
models_map["a"][network.type] = network.filename
if network.side == sides[1]:
models_map["b"][network.type] = network.filename
logger.debug("Mapped models: (models_map: %s)", models_map)
return models_map
def log_summary(self):
""" Verbose log the model summaries """
if self.predict:
return
for side in sorted(list(self.predictors.keys())):
logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x))
for name, nnmeta in self.networks.items():
if nnmeta.side is not None and nnmeta.side != side:
continue
logger.verbose("%s:", name.title())
nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x))
def load_models(self, swapped):
""" Load models from file """
logger.debug("Load model: (swapped: %s)", swapped)
model_mapping = self.map_models(swapped)
for network in self.networks.values():
if not network.side:
is_loaded = network.load(predict=self.predict)
else:
is_loaded = network.load(fullpath=model_mapping[network.side][network.type],
predict=self.predict)
if not is_loaded:
break
if is_loaded:
logger.info("Loaded model from disk: '%s'", self.model_dir)
return is_loaded
def save_models(self):
""" Backup and save the models """
logger.debug("Backing up and saving models")
should_backup = self.get_save_averages()
save_threads = list()
for network in self.networks.values():
name = "save_{}".format(network.name)
save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup))
save_threads.append(MultiThread(self.state.save,
name="save_state", should_backup=should_backup))
for thread in save_threads:
thread.start()
for thread in save_threads:
if thread.has_error:
logger.error(thread.errors[0])
thread.join()
# Put in a line break to avoid jumbled console
print("\n")
logger.info("saved models")
def get_save_averages(self):
""" Return the loss averages since last save and reset historical losses
This protects against model corruption by only backing up the model
if any of the loss values have fallen.
        TODO This is not a perfect system. If the model corrupts on save_iteration - 1
        then the model may still back up
"""
logger.debug("Getting Average loss since last save")
avgs = dict()
backup = True
for side, loss in self.history.items():
if not loss:
backup = False
break
avgs[side] = sum(loss) / len(loss)
self.history[side] = list() # Reset historical loss
if not self.state.lowest_avg_loss.get(side, None):
logger.debug("Setting initial save iteration loss average for '%s': %s",
side, avgs[side])
self.state.lowest_avg_loss[side] = avgs[side]
continue
if backup:
# Only run this if backup is true. All losses must have dropped for a valid backup
backup = self.check_loss_drop(side, avgs[side])
logger.debug("Lowest historical save iteration loss average: %s",
self.state.lowest_avg_loss)
logger.debug("Average loss since last save: %s", avgs)
if backup: # Update lowest loss values to the state
for side, avg_loss in avgs.items():
logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
self.state.lowest_avg_loss[side] = avg_loss
logger.debug("Backing up: %s", backup)
return backup
def check_loss_drop(self, side, avg):
""" Check whether total loss has dropped since lowest loss """
if avg < self.state.lowest_avg_loss[side]:
logger.debug("Loss for '%s' has dropped", side)
return True
logger.debug("Loss for '%s' has not dropped", side)
return False
class NNMeta():
""" Class to hold a neural network and it's meta data
filename: The full path and filename of the model file for this network.
type: The type of network. For networks that can be swapped
The type should be identical for the corresponding
A and B networks, and should be unique for every A/B pair.
Otherwise the type should be completely unique.
side: A, B or None. Used to identify which networks can
be swapped.
    network: The Keras network assigned to this object.
"""
def __init__(self, filename, network_type, side, network):
logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
"network: %s", self.__class__.__name__, filename, network_type,
side, network)
self.filename = filename
self.type = network_type.lower()
self.side = side
self.name = self.set_name()
self.network = network
self.network.name = self.name
logger.debug("Initialized %s", self.__class__.__name__)
def set_name(self):
""" Set the network name """
name = self.type
if self.side:
name += "_{}".format(self.side)
return name
def load(self, fullpath=None, predict=False):
""" Load model """
fullpath = fullpath if fullpath else self.filename
logger.debug("Loading model: '%s'", fullpath)
try:
            network = load_model(fullpath, custom_objects=get_custom_objects())
except ValueError as err:
if str(err).lower().startswith("cannot create group in read only mode"):
self.convert_legacy_weights()
return True
if predict:
raise ValueError("Unable to load training data. Error: {}".format(str(err)))
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
        except OSError as err:
if predict:
raise ValueError("Unable to load training data. Error: {}".format(str(err)))
logger.warning("Failed loading existing training data. Generating new models")
logger.debug("Exception: %s", str(err))
return False
self.network = network # Update network with saved model
self.network.name = self.type
return True
def save(self, fullpath=None, should_backup=False):
""" Save model """
fullpath = fullpath if fullpath else self.filename
if should_backup:
self.backup(fullpath=fullpath)
logger.debug("Saving model: '%s'", fullpath)
self.network.save(fullpath)
def backup(self, fullpath=None):
""" Backup Model """
origfile = fullpath if fullpath else self.filename
backupfile = origfile + ".bk"
logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
if os.path.exists(backupfile):
os.remove(backupfile)
if os.path.exists(origfile):
os.rename(origfile, backupfile)
def convert_legacy_weights(self):
""" Convert legacy weights files to hold the model topology """
logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
self.network.load_weights(self.filename)
self.save(should_backup=False)
self.network.name = self.type
class State():
""" Class to hold the model's current state and autoencoder structure """
def __init__(self, model_dir, model_name, no_logs, training_image_size):
logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, "
"training_image_size: '%s'", self.__class__.__name__, model_dir,
model_name, no_logs, training_image_size)
self.serializer = Serializer.get_serializer("json")
filename = "{}_state.{}".format(model_name, self.serializer.ext)
self.filename = str(model_dir / filename)
self.iterations = 0
self.session_iterations = 0
self.training_size = training_image_size
self.sessions = dict()
self.lowest_avg_loss = dict()
self.inputs = dict()
self.config = dict()
self.load()
self.session_id = self.new_session_id()
self.create_new_session(no_logs)
logger.debug("Initialized %s:", self.__class__.__name__)
@property
def face_shapes(self):
""" Return a list of stored face shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]
@property
def mask_shapes(self):
""" Return a list of stored mask shape inputs """
return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]
@property
def loss_names(self):
""" Return the loss names for this session """
return self.sessions[self.session_id]["loss_names"]
@property
def current_session(self):
""" Return the current session dict """
return self.sessions[self.session_id]
def new_session_id(self):
""" Return new session_id """
if not self.sessions:
session_id = 1
else:
session_id = max(int(key) for key in self.sessions.keys()) + 1
logger.debug(session_id)
return session_id
def create_new_session(self, no_logs):
""" Create a new session """
logger.debug("Creating new session. id: %s", self.session_id)
self.sessions[self.session_id] = {"timestamp": time.time(),
"no_logs": no_logs,
"loss_names": dict(),
"batchsize": 0,
"iterations": 0}
def add_session_loss_names(self, side, loss_names):
""" Add the session loss names to the sessions dictionary """
logger.debug("Adding session loss_names. (side: '%s', loss_names: %s", side, loss_names)
self.sessions[self.session_id]["loss_names"][side] = loss_names
def add_session_batchsize(self, batchsize):
""" Add the session batchsize to the sessions dictionary """
logger.debug("Adding session batchsize: %s", batchsize)
self.sessions[self.session_id]["batchsize"] = batchsize
def increment_iterations(self):
""" Increment total and session iterations """
self.iterations += 1
self.sessions[self.session_id]["iterations"] += 1
def load(self):
""" Load state file """
logger.debug("Loading State")
try:
with open(self.filename, "rb") as inp:
state = self.serializer.unmarshal(inp.read().decode("utf-8"))
self.sessions = state.get("sessions", dict())
self.lowest_avg_loss = state.get("lowest_avg_loss", dict())
self.iterations = state.get("iterations", 0)
self.training_size = state.get("training_size", 256)
self.inputs = state.get("inputs", dict())
self.config = state.get("config", dict())
logger.debug("Loaded state: %s", state)
self.replace_config()
except IOError as err:
logger.warning("No existing state file found. Generating.")
logger.debug("IOError: %s", str(err))
except JSONDecodeError as err:
logger.debug("JSONDecodeError: %s:", str(err))
def save(self, should_backup=False):
""" Save iteration number to state file """
logger.debug("Saving State")
if should_backup:
self.backup()
try:
with open(self.filename, "wb") as out:
state = {"sessions": self.sessions,
"lowest_avg_loss": self.lowest_avg_loss,
"iterations": self.iterations,
"inputs": self.inputs,
"training_size": self.training_size,
"config": _CONFIG}
state_json = self.serializer.marshal(state)
out.write(state_json.encode("utf-8"))
except IOError as err:
logger.error("Unable to save model state: %s", str(err.strerror))
logger.debug("Saved State")
def backup(self):
""" Backup state file """
origfile = self.filename
backupfile = origfile + ".bk"
logger.debug("Backing up: '%s' to '%s'", origfile, backupfile)
if os.path.exists(backupfile):
os.remove(backupfile)
if os.path.exists(origfile):
os.rename(origfile, backupfile)
def replace_config(self):
""" Replace the loaded config with the one contained within the state file """
global _CONFIG # pylint: disable=global-statement
# Add any new items to state config for legacy purposes
for key, val in _CONFIG.items():
if key not in self.config.keys():
logger.info("Adding new config item to state file: '%s': '%s'", key, val)
self.config[key] = val
logger.debug("Replacing config. Old config: %s", _CONFIG)
_CONFIG = self.config
logger.debug("Replaced config. New config: %s", _CONFIG)
logger.info("Using configuration saved in state file")