Mirror of https://github.com/deepfakes/faceswap
faceswap/plugins/train/model/_base/model.py
torzdf 6a3b674bef
Rebase code (#1326)
* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* setup.py: stdout to utf-8

* Add more OSes to GitHub Actions

* Suppress macOS end-to-end test
2023-06-27 11:27:47 +01:00


#!/usr/bin/env python3
""" Base class for Models. ALL Models should at least inherit from this class.

See :mod:`~plugins.train.model.original` for an annotated example of how to create model plugins.
"""
from __future__ import annotations
import logging
import os
import sys
import time
import typing as T

from collections import OrderedDict

import numpy as np
import tensorflow as tf

# Ignore linting errors from Tensorflow's thoroughly broken import system
from tensorflow.keras import backend as K  # pylint:disable=import-error
from tensorflow.keras.layers import Input  # pylint:disable=import-error
from tensorflow.keras.models import load_model, Model as KModel  # noqa:E501  # pylint:disable=import-error

from lib.serializer import get_serializer
from lib.model.nn_blocks import set_config as set_nnblock_config
from lib.utils import FaceswapError
from plugins.train._config import Config

from .io import IO, get_all_sub_models, Weights
from .settings import Loss, Optimizer, Settings

if T.TYPE_CHECKING:
    import argparse
    from lib.config import ConfigValueType

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
_CONFIG: dict[str, ConfigValueType] = {}


class ModelBase():
    """ Base class that all model plugins should inherit from.

    Parameters
    ----------
    model_dir: str
        The full path to the model save location
    arguments: :class:`argparse.Namespace`
        The arguments that were passed to the train or convert process as generated from
        Faceswap's command line arguments
    predict: bool, optional
        ``True`` if the model is being loaded for inference, ``False`` if the model is being
        loaded for training. Default: ``False``

    Attributes
    ----------
    input_shape: tuple or list
        A `tuple` of `ints` defining the shape of the faces that the model takes as input. This
        should be overridden by model plugins in their :func:`__init__` function. If the input
        size is the same for both sides of the model, then this can be a single 3-dimensional
        `tuple`. If the inputs have different sizes for `"A"` and `"B"` this should be a `list`
        of two 3-dimensional shape `tuples`, one for each side.
    trainer: str
        Currently there is only one trainer available (`"original"`), so at present this
        attribute can be ignored. If/when more trainers are added, then this attribute should
        be overridden with the trainer name that a model requires in the model plugin's
        :func:`__init__` function.
    """
    def __init__(self,
                 model_dir: str,
                 arguments: argparse.Namespace,
                 predict: bool = False) -> None:
        logger.debug("Initializing ModelBase (%s): (model_dir: '%s', arguments: %s, predict: %s)",
                     self.__class__.__name__, model_dir, arguments, predict)

        # Input shape must be set within the plugin after initializing
        self.input_shape: tuple[int, ...] = ()
        self.trainer = "original"  # Override for plugin specific trainer
        self.color_order: T.Literal["bgr", "rgb"] = "bgr"  # Override for image color channel order

        self._args = arguments
        self._is_predict = predict
        self._model: tf.keras.models.Model | None = None

        self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None
        self._load_config()

        if self.config["penalized_mask_loss"] and self.config["mask_type"] is None:
            raise FaceswapError("Penalized Mask Loss has been selected but you have not chosen a "
                                "Mask to use. Please select a mask or disable Penalized Mask "
                                "Loss.")

        if self.config["learn_mask"] and self.config["mask_type"] is None:
            raise FaceswapError("'Learn Mask' has been selected but you have not chosen a Mask "
                                "to use. Please select a mask or disable 'Learn Mask'.")

        self._mixed_precision = self.config["mixed_precision"]
        self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
        self._check_multiple_models()

        self._state = State(model_dir,
                            self.name,
                            self._config_changeable_items,
                            False if self._is_predict else self._args.no_logs)
        self._settings = Settings(self._args,
                                  self._mixed_precision,
                                  self.config["allow_growth"],
                                  self._is_predict)
        self._loss = Loss(self.config, self.color_order)

        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def model(self) -> tf.keras.models.Model:
        """ :class:`Keras.models.Model`: The compiled model for this plugin. """
        return self._model

    @property
    def command_line_arguments(self) -> argparse.Namespace:
        """ :class:`argparse.Namespace`: The command line arguments passed to the model plugin
        from either the train or convert script """
        return self._args

    @property
    def coverage_ratio(self) -> float:
        """ float: The ratio of the training image to crop out and train on as defined in user
        configuration options.

        NB: The coverage ratio is a raw float, but will be applied to integer pixel images.
        To ensure consistent rounding and a guaranteed even image size, the calculation for
        coverage should always be: :math:`(original_size * coverage_ratio // 2) * 2`
        """
        return self.config.get("coverage", 62.5) / 100
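    # Worked example of the even-size coverage calculation above (a sketch, not part of the
    # plugin API): with a 256px training image and the default coverage of 62.5%,
    # (256 * 0.625 // 2) * 2 == 160, so a 160px crop is trained on. A value such as 62.9%
    # also floors to an even 160px, keeping crop sizes consistent and even.
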
    @property
    def model_dir(self) -> str:
        """ str: The full path to the model folder location. """
        return self._io._model_dir  # pylint:disable=protected-access

    @property
    def config(self) -> dict:
        """ dict: The configuration dictionary for current plugin, as set by the user's
        configuration settings. """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self._config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name, configfile=self._configfile).config_dict
        return _CONFIG

    @property
    def name(self) -> str:
        """ str: The name of this model based on the plugin name. """
        _name = sys.modules[self.__module__].__file__
        assert isinstance(_name, str)
        return os.path.splitext(os.path.basename(_name))[0].lower()

    @property
    def model_name(self) -> str:
        """ str: The name of the keras model. Generally this will be the same as :attr:`name`
        but some plugins will override this when they contain multiple architectures """
        return self.name

    @property
    def input_shapes(self) -> list[tuple[None, int, int, int]]:
        """ list: A flattened list corresponding to all of the inputs to the model. """
        shapes = [T.cast(tuple[None, int, int, int], K.int_shape(inputs))
                  for inputs in self.model.inputs]
        return shapes

    @property
    def output_shapes(self) -> list[tuple[None, int, int, int]]:
        """ list: A flattened list corresponding to all of the outputs of the model. """
        shapes = [T.cast(tuple[None, int, int, int], K.int_shape(output))
                  for output in self.model.outputs]
        return shapes

    @property
    def iterations(self) -> int:
        """ int: The total number of iterations that the model has trained. """
        return self._state.iterations

    # Private properties
    @property
    def _config_section(self) -> str:
        """ str: The section name for the current plugin for loading configuration options from
        the config file. """
        return ".".join(self.__module__.split(".")[-2:])
    @property
    def _config_changeable_items(self) -> dict:
        """ dict: The configuration options that can be updated after the model has already been
        created. """
        return Config(self._config_section, configfile=self._configfile).changeable_items

    @property
    def state(self) -> "State":
        """ :class:`State`: The state settings for the current plugin. """
        return self._state

    def _load_config(self) -> None:
        """ Load the global config for reference in :attr:`config` and set the faceswap blocks
        configuration options in `lib.model.nn_blocks` """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self._config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name, configfile=self._configfile).config_dict

        nn_block_keys = ['icnr_init', 'conv_aware_init', 'reflect_padding']
        set_nnblock_config({key: _CONFIG.pop(key)
                            for key in nn_block_keys})

    def _check_multiple_models(self) -> None:
        """ Check whether multiple models exist in the model folder, and that no models exist
        that were trained with a different plugin than the requested plugin.

        Raises
        ------
        FaceswapError
            If multiple model files, or models for a different plugin from that requested,
            exist within the model folder
        """
        multiple_models = self._io.multiple_models_in_folder
        if multiple_models is None:
            logger.debug("Contents of model folder are valid")
            return

        if len(multiple_models) == 1:
            msg = (f"You have requested to train with the '{self.name}' plugin, but a model file "
                   f"for the '{multiple_models[0]}' plugin already exists in the folder "
                   f"'{self.model_dir}'.\nPlease select a different model folder.")
        else:
            ptypes = "', '".join(multiple_models)
            msg = (f"There are multiple plugin types ('{ptypes}') stored in the model folder '"
                   f"{self.model_dir}'. This is not supported.\nPlease split the model files into "
                   "their own folders before proceeding")
        raise FaceswapError(msg)

    def build(self) -> None:
        """ Build the model and assign to :attr:`model`.

        Within the defined strategy scope, either builds the model from scratch or loads an
        existing model if one exists.

        If running inference, then the model is built only for the required side to perform the
        swap function, otherwise the model is then compiled with the optimizer and chosen
        loss function(s).

        Finally, a model summary is output to the logger at verbose level.
        """
        self._update_legacy_models()
        is_summary = hasattr(self._args, "summary") and self._args.summary
        with self._settings.strategy_scope():
            if self._io.model_exists:
                model = self._io._load()  # pylint:disable=protected-access
                if self._is_predict:
                    inference = _Inference(model, self._args.swap_model)
                    self._model = inference.model
                else:
                    self._model = model
            else:
                self._validate_input_shape()
                inputs = self._get_inputs()
                if not self._settings.use_mixed_precision and not is_summary:
                    # Store layer names which can be switched to mixed precision
                    self._state.add_mixed_precision_layers(
                        self._settings.get_mixed_precision_layers(self.build_model, inputs))
                self._model = self.build_model(inputs)
            if not is_summary and not self._is_predict:
                self._compile_model()
            self._output_summary()

    def _update_legacy_models(self) -> None:
        """ Load weights from legacy split models into the new unified model, archiving the old
        model files to a new folder. """
        legacy_mapping = self._legacy_mapping()  # pylint:disable=assignment-from-none
        if legacy_mapping is None:
            return

        if not all(os.path.isfile(os.path.join(self.model_dir, fname))
                   for fname in legacy_mapping):
            return

        archive_dir = f"{self.model_dir}_TF1_Archived"
        if os.path.exists(archive_dir):
            raise FaceswapError("We need to update your model files for use with Tensorflow 2.x, "
                                "but the archive folder already exists. Please remove the "
                                f"following folder to continue: '{archive_dir}'")

        logger.info("Updating legacy models for Tensorflow 2.x")
        logger.info("Your Tensorflow 1.x models will be archived in the following location: "
                    "'%s'", archive_dir)
        os.rename(self.model_dir, archive_dir)
        os.mkdir(self.model_dir)
        new_model = self.build_model(self._get_inputs())
        for model_name, layer_name in legacy_mapping.items():
            old_model: tf.keras.models.Model = load_model(os.path.join(archive_dir, model_name),
                                                          compile=False)
            layer = [layer for layer in new_model.layers if layer.name == layer_name]
            if not layer:
                logger.warning("Skipping legacy weights from '%s'...", model_name)
                continue
            klayer: tf.keras.layers.Layer = layer[0]
            logger.info("Updating legacy weights from '%s'...", model_name)
            klayer.set_weights(old_model.get_weights())
        filename = self._io._filename  # pylint:disable=protected-access
        logger.info("Saving Tensorflow 2.x model to '%s'", filename)
        new_model.save(filename)
        # Penalized Loss and Learn Mask used to be disabled automatically if a mask wasn't
        # selected, so disable them if they are enabled but mask_type is None
        if self.config["mask_type"] is None:
            self.config["penalized_mask_loss"] = False
            self.config["learn_mask"] = False
            self.config["eye_multiplier"] = 1
            self.config["mouth_multiplier"] = 1
        self._state.save()

    def _validate_input_shape(self) -> None:
        """ Validate that the input shape is either a single shape tuple of 3 dimensions or
        a list of 2 shape tuples of 3 dimensions. """
        assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple"

    def _get_inputs(self) -> list[tf.keras.layers.Input]:
        """ Obtain the standardized inputs for the model.

        The inputs will be returned for the "A" and "B" sides in the shape as defined by
        :attr:`input_shape`.

        Returns
        -------
        list
            A list of :class:`keras.layers.Input` tensors. This will be a list of 2 tensors (one
            for each side) each of shapes :attr:`input_shape`.
        """
        logger.debug("Getting inputs")
        input_shapes = [self.input_shape, self.input_shape]
        inputs = [Input(shape=shape, name=f"face_in_{side}")
                  for side, shape in zip(("a", "b"), input_shapes)]
        logger.debug("inputs: %s", inputs)
        return inputs

    def build_model(self, inputs: list[tf.keras.layers.Input]) -> tf.keras.models.Model:
        """ Override for model specific autoencoder builds.

        Parameters
        ----------
        inputs: list
            A list of :class:`keras.layers.Input` tensors. This will be a list of 2 tensors (one
            for each side) each of shapes :attr:`input_shape`.

        Returns
        -------
        :class:`keras.models.Model`
            See Keras documentation for the correct structure, but note that parameter
            :attr:`name` is a required rather than an optional argument in Faceswap. You should
            assign this to the attribute ``self.name`` that is automatically generated from the
            plugin's filename.
        """
        raise NotImplementedError
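    # A minimal sketch of a plugin override (the ``encoder``/``decoder`` helpers here are
    # hypothetical; real plugins build them from lib.model.nn_blocks -- see
    # plugins.train.model.original for the annotated reference implementation):
    #
    #     class Model(ModelBase):
    #         def __init__(self, *args, **kwargs):
    #             super().__init__(*args, **kwargs)
    #             self.input_shape = (64, 64, 3)
    #
    #         def build_model(self, inputs):
    #             encoder = self.encoder()  # hypothetical shared encoder sub-model
    #             outputs = [self.decoder("a")(encoder(inputs[0])),
    #                        self.decoder("b")(encoder(inputs[1]))]
    #             return KModel(inputs, outputs, name=self.model_name)
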
    def _output_summary(self) -> None:
        """ Output the summary of the model and all sub-models to the verbose logger. """
        if hasattr(self._args, "summary") and self._args.summary:
            print_fn = None  # Print straight to stdout
        else:
            # print to logger
            print_fn = lambda x: logger.verbose("%s", x)  # type:ignore[attr-defined] # noqa[E731] # pylint:disable=C3001
        for idx, model in enumerate(get_all_sub_models(self.model)):
            if idx == 0:
                parent = model
                continue
            model.summary(line_length=100, print_fn=print_fn)
        parent.summary(line_length=100, print_fn=print_fn)

    def save(self, is_exit: bool = False) -> None:
        """ Save the model to disk.

        Saves the serialized model, with weights, to the folder location specified when
        initializing the plugin. If loss has dropped on both sides of the model, then
        a backup is taken.

        Parameters
        ----------
        is_exit: bool, optional
            ``True`` if the save request has come from an exit process request, otherwise
            ``False``. Default: ``False``
        """
        self._io.save(is_exit=is_exit)

    def snapshot(self) -> None:
        """ Create a snapshot of the model folder in the model's parent folder, with the number
        of iterations completed appended to the end of the model name. """
        self._io.snapshot()

    def _compile_model(self) -> None:
        """ Compile the model to include the Optimizer and Loss Function(s). """
        logger.debug("Compiling Model")

        if self.state.model_needs_rebuild:
            self._model = self._settings.check_model_precision(self._model, self._state)

        optimizer = Optimizer(self.config["optimizer"],
                              self.config["learning_rate"],
                              self.config["autoclip"],
                              10 ** int(self.config["epsilon_exponent"])).optimizer
        if self._settings.use_mixed_precision:
            optimizer = self._settings.loss_scale_optimizer(optimizer)

        weights = Weights(self)
        weights.load(self._io.model_exists)
        weights.freeze()

        self._loss.configure(self.model)
        self.model.compile(optimizer=optimizer, loss=self._loss.functions)
        self._state.add_session_loss_names(self._loss.names)
        logger.debug("Compiled Model: %s", self.model)

    def _legacy_mapping(self) -> dict | None:
        """ The mapping of separate model files to single model layers for transferring legacy
        weights.

        Returns
        -------
        dict or ``None``
            Dictionary of original H5 filenames for legacy models mapped to new layer names or
            ``None`` if the model did not exist in Faceswap prior to Tensorflow 2
        """
        return None

    def add_history(self, loss: list[float]) -> None:
        """ Add the current iteration's loss history to :attr:`_io.history`.

        Called from the trainer after each iteration, for tracking loss drop over time between
        save iterations.

        Parameters
        ----------
        loss: list
            The loss values for the A and B side for the current iteration. This should be the
            collated loss values for each side.
        """
        self._io.history[0].append(loss[0])
        self._io.history[1].append(loss[1])

class State():
    """ Holds state information relating to the plugin's saved model.

    Parameters
    ----------
    model_dir: str
        The full path to the model save location
    model_name: str
        The name of the model plugin
    config_changeable_items: dict
        Configuration options that can be altered when resuming a model, and their current values
    no_logs: bool
        ``True`` if Tensorboard logs should not be generated, otherwise ``False``
    """
    def __init__(self,
                 model_dir: str,
                 model_name: str,
                 config_changeable_items: dict,
                 no_logs: bool) -> None:
        logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', "
                     "config_changeable_items: '%s', no_logs: %s)", self.__class__.__name__,
                     model_dir, model_name, config_changeable_items, no_logs)
        self._serializer = get_serializer("json")
        filename = f"{model_name}_state.{self._serializer.file_extension}"
        self._filename = os.path.join(model_dir, filename)
        self._name = model_name
        self._iterations = 0
        self._mixed_precision_layers: list[str] = []
        self._rebuild_model = False
        self._sessions: dict[int, dict] = {}
        self._lowest_avg_loss: dict[str, float] = {}
        self._config: dict[str, ConfigValueType] = {}
        self._load(config_changeable_items)
        self._session_id = self._new_session_id()
        self._create_new_session(no_logs, config_changeable_items)
        logger.debug("Initialized %s:", self.__class__.__name__)

    @property
    def loss_names(self) -> list[str]:
        """ list: The loss names for the current session """
        return self._sessions[self._session_id]["loss_names"]

    @property
    def current_session(self) -> dict:
        """ dict: The state dictionary for the current :attr:`session_id`. """
        return self._sessions[self._session_id]

    @property
    def iterations(self) -> int:
        """ int: The total number of iterations that the model has trained. """
        return self._iterations

    @property
    def lowest_avg_loss(self) -> dict:
        """ dict: The lowest average save interval loss seen for each side. """
        return self._lowest_avg_loss

    @property
    def session_id(self) -> int:
        """ int: The current training session id. """
        return self._session_id

    @property
    def mixed_precision_layers(self) -> list[str]:
        """ list: Layers that can be switched between mixed-float16 and float32. """
        return self._mixed_precision_layers

    @property
    def model_needs_rebuild(self) -> bool:
        """ bool: ``True`` if the mixed precision policy has changed so the model needs to be
        rebuilt, otherwise ``False`` """
        return self._rebuild_model

    def _new_session_id(self) -> int:
        """ Generate a new session id. Returns 1 if this is a new model, or the last session id
        plus 1 if it is a pre-existing model.

        Returns
        -------
        int
            The newly generated session id
        """
        if not self._sessions:
            session_id = 1
        else:
            session_id = max(int(key) for key in self._sessions.keys()) + 1
        logger.debug(session_id)
        return session_id
    def _create_new_session(self, no_logs: bool, config_changeable_items: dict) -> None:
        """ Initialize a new session, creating the dictionary entry for the session in
        :attr:`_sessions`.

        Parameters
        ----------
        no_logs: bool
            ``True`` if Tensorboard logs should not be generated, otherwise ``False``
        config_changeable_items: dict
            Configuration options that can be altered when resuming a model, and their current
            values
        """
        logger.debug("Creating new session. id: %s", self._session_id)
        self._sessions[self._session_id] = {"timestamp": time.time(),
                                            "no_logs": no_logs,
                                            "loss_names": [],
                                            "batchsize": 0,
                                            "iterations": 0,
                                            "config": config_changeable_items}

    def add_session_loss_names(self, loss_names: list[str]) -> None:
        """ Add the session loss names to the sessions dictionary.

        The loss names are used for Tensorboard logging.

        Parameters
        ----------
        loss_names: list
            The list of loss names for this session.
        """
        logger.debug("Adding session loss_names: %s", loss_names)
        self._sessions[self._session_id]["loss_names"] = loss_names

    def add_session_batchsize(self, batch_size: int) -> None:
        """ Add the session batch size to the sessions dictionary.

        Parameters
        ----------
        batch_size: int
            The batch size for the current training session
        """
        logger.debug("Adding session batch size: %s", batch_size)
        self._sessions[self._session_id]["batchsize"] = batch_size

    def increment_iterations(self) -> None:
        """ Increment :attr:`iterations` and session iterations by 1. """
        self._iterations += 1
        self._sessions[self._session_id]["iterations"] += 1

    def add_mixed_precision_layers(self, layers: list[str]) -> None:
        """ Add the list of the model's layers that are compatible with mixed precision to the
        state dictionary. """
        logger.debug("Storing mixed precision layers: %s", layers)
        self._mixed_precision_layers = layers

    def _load(self, config_changeable_items: dict) -> None:
        """ Load a state file and set the serialized values to the class instance.

        Updates the model's config with the values stored in the state file.

        Parameters
        ----------
        config_changeable_items: dict
            Configuration options that can be altered when resuming a model, and their current
            values
        """
        logger.debug("Loading State")
        if not os.path.exists(self._filename):
            logger.info("No existing state file found. Generating.")
            return
        state = self._serializer.load(self._filename)
        self._name = state.get("name", self._name)
        self._sessions = state.get("sessions", {})
        self._lowest_avg_loss = state.get("lowest_avg_loss", {})
        self._iterations = state.get("iterations", 0)
        self._mixed_precision_layers = state.get("mixed_precision_layers", [])
        self._config = state.get("config", {})
        logger.debug("Loaded state: %s", state)
        self._replace_config(config_changeable_items)

    def save(self) -> None:
        """ Save the state values to the serialized state file. """
        logger.debug("Saving State")
        state = {"name": self._name,
                 "sessions": self._sessions,
                 "lowest_avg_loss": self._lowest_avg_loss,
                 "iterations": self._iterations,
                 "mixed_precision_layers": self._mixed_precision_layers,
                 "config": _CONFIG}
        self._serializer.save(self._filename, state)
        logger.debug("Saved State")
    def _replace_config(self, config_changeable_items) -> None:
        """ Replace the loaded config with the one contained within the state file.

        Checks for any `fixed`=``False`` parameter changes and logs info on changes.

        Updates any legacy config items to their current versions.

        Parameters
        ----------
        config_changeable_items: dict
            Configuration options that can be altered when resuming a model, and their current
            values
        """
        global _CONFIG  # pylint: disable=global-statement
        if _CONFIG is None:
            return
        legacy_update = self._update_legacy_config()
        # Add any new items to state config for legacy purposes where the new default may be
        # detrimental to an existing model.
        legacy_defaults: dict[str, str | int | bool] = {"centering": "legacy",
                                                        "mask_loss_function": "mse",
                                                        "l2_reg_term": 100,
                                                        "optimizer": "adam",
                                                        "mixed_precision": False}
        for key, val in _CONFIG.items():
            if key not in self._config.keys():
                setting: ConfigValueType = legacy_defaults.get(key, val)
                logger.info("Adding new config item to state file: '%s': '%s'", key, setting)
                self._config[key] = setting
        self._update_changed_config_items(config_changeable_items)
        logger.debug("Replacing config. Old config: %s", _CONFIG)
        _CONFIG = self._config
        if legacy_update:
            self.save()
        logger.debug("Replaced config. New config: %s", _CONFIG)
        logger.info("Using configuration saved in state file")

    def _update_legacy_config(self) -> bool:
        """ Legacy updates for new config additions.

        When new config items are added to the Faceswap code, existing model state files need
        to be updated to handle these new items.

        Current existing legacy update items:

            * loss - If old `dssim_loss` is ``true`` set new `loss_function` to `ssim` otherwise
              set it to `mae`. Remove old `dssim_loss` item

            * l2_reg_term - If this exists, set loss_function_2 to ``mse`` and loss_weight_2 to
              the value held in the old ``l2_reg_term`` item

            * masks - If `learn_mask` does not exist then it is set to ``True`` if `mask_type`
              is not ``None`` otherwise it is set to ``False``.

            * masks type - Replace removed masks 'dfl_full' and 'facehull' with `components`
              mask

            * clipnorm - Only existed in 2 models (DFL-SAE + Unbalanced). Replaced with global
              option autoclip

        Returns
        -------
        bool
            ``True`` if legacy items exist and the state file has been updated, otherwise
            ``False``
        """
        logger.debug("Checking for legacy state file update")
        priors = ["dssim_loss", "mask_type", "mask_type", "l2_reg_term", "clipnorm"]
        new_items = ["loss_function", "learn_mask", "mask_type", "loss_function_2",
                     "autoclip"]
        updated = False
        for old, new in zip(priors, new_items):
            if old not in self._config:
                logger.debug("Legacy item '%s' not in config. Skipping update", old)
                continue

            # dssim_loss > loss_function
            if old == "dssim_loss":
                self._config[new] = "ssim" if self._config[old] else "mae"
                del self._config[old]
                updated = True
                logger.info("Updated config from legacy dssim format. New config loss "
                            "function: '%s'", self._config[new])
                continue

            # Add learn mask option and set to True if model has "penalized_mask_loss" specified
            if old == "mask_type" and new == "learn_mask" and new not in self._config:
                self._config[new] = self._config["mask_type"] is not None
                updated = True
                logger.info("Added new 'learn_mask' config item for this model. Value set to: "
                            "%s", self._config[new])
                continue

            # Replace removed masks with the most similar equivalent
            if old == "mask_type" and new == "mask_type" and self._config[old] in ("facehull",
                                                                                   "dfl_full"):
                old_mask = self._config[old]
                self._config[new] = "components"
                updated = True
                logger.info("Updated 'mask_type' from '%s' to '%s' for this model",
                            old_mask, self._config[new])

            # Replace l2_reg_term with the correct loss_function_2 and update the value of
            # loss_weight_2
            if old == "l2_reg_term":
                self._config[new] = "mse"
                self._config["loss_weight_2"] = self._config[old]
                del self._config[old]
                updated = True
                logger.info("Updated config from legacy 'l2_reg_term' to 'loss_function_2'")

            # Replace clipnorm with the correct gradient clipping type and value
            if old == "clipnorm":
                self._config[new] = self._config[old]
                del self._config[old]
                updated = True
                logger.info("Updated config from legacy '%s' to '%s'", old, new)

        logger.debug("State file updated for legacy config: %s", updated)
        return updated
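    # Example (sketch) of the legacy transform above: a state file holding
    # {"dssim_loss": True, "l2_reg_term": 100} would be rewritten to
    # {"loss_function": "ssim", "loss_function_2": "mse", "loss_weight_2": 100},
    # with the old keys removed and the updated state file saved back to disk
    # by :func:`_replace_config`.
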
    def _update_changed_config_items(self, config_changeable_items: dict) -> None:
        """ Update any parameters which are not fixed and have been changed.

        Sets :attr:`model_needs_rebuild` to ``True`` if the mixed precision state has changed.

        Parameters
        ----------
        config_changeable_items: dict
            Configuration options that can be altered when resuming a model, and their current
            values
        """
        rebuild_tasks = ["mixed_precision"]
        if not config_changeable_items:
            logger.debug("No changeable parameters have been updated")
            return
        for key, val in config_changeable_items.items():
            old_val = self._config[key]
            if old_val == val:
                continue
            self._config[key] = val
            logger.info("Config item: '%s' has been updated from '%s' to '%s'", key, old_val, val)
            self._rebuild_model = not self._rebuild_model and key in rebuild_tasks

class _Inference():  # pylint:disable=too-few-public-methods
    """ Calculates required layers and compiles a saved model for inference.

    Parameters
    ----------
    saved_model: :class:`keras.models.Model`
        The saved trained Faceswap model
    switch_sides: bool
        ``True`` if the swap should be performed "B" > "A", ``False`` if the swap should be
        "A" > "B"
    """
    def __init__(self, saved_model: tf.keras.models.Model, switch_sides: bool) -> None:
        logger.debug("Initializing: %s (saved_model: %s, switch_sides: %s)",
                     self.__class__.__name__, saved_model, switch_sides)
        self._config = saved_model.get_config()

        self._input_idx = 1 if switch_sides else 0
        self._output_idx = 0 if switch_sides else 1

        self._input_names = [inp[0] for inp in self._config["input_layers"]]
        self._model = self._make_inference_model(saved_model)
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def model(self) -> tf.keras.models.Model:
        """ :class:`keras.models.Model`: The Faceswap model, compiled for inference. """
        return self._model

    def _get_nodes(self, nodes: np.ndarray) -> list[tuple[str, int]]:
        """ Given an input list of nodes from a :attr:`keras.models.Model.get_config`
        dictionary, filters the layer name(s) and output index of the node, splitting to the
        correct output index in the event of multiple inputs.

        Parameters
        ----------
        nodes: list
            A node entry from the :attr:`keras.models.Model.get_config` dictionary

        Returns
        -------
        list
            The (node name, output index) for each node passed in
        """
        anodes = np.array(nodes, dtype="object")[..., :3]
        num_layers = anodes.shape[0]
        anodes = anodes[self._output_idx] if num_layers == 2 else anodes[0]
        # Probably better checks for this, but this occurs when the DNY preset is used and learn
        # mask is enabled (i.e. the mask is created in fully connected layers)
        anodes = anodes.squeeze() if anodes.ndim == 3 else anodes
        retval = [(node[0], node[2]) for node in anodes]
        return retval
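    # Example (sketch, hypothetical layer name): a Keras functional-API config stores inbound
    # nodes as [[["layer_name", node_index, tensor_index, kwargs], ...]]; for an entry such as
    # [[["encoder", 0, 0, {}]]] this method returns [("encoder", 0)], i.e. the inbound layer
    # name and the output (tensor) index that this layer consumes.
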
    def _make_inference_model(self, saved_model: tf.keras.models.Model) -> tf.keras.models.Model:
        """ Extract the sub-models from the saved model that are required for inference.

        Parameters
        ----------
        saved_model: :class:`keras.models.Model`
            The saved trained Faceswap model

        Returns
        -------
        :class:`keras.models.Model`
            The model compiled for inference
        """
        logger.debug("Compiling inference model. saved_model: %s", saved_model)
        struct = self._get_filtered_structure()
        model_inputs = self._get_inputs(saved_model.inputs)
        compiled_layers: dict[str, tf.keras.layers.Layer] = {}
        for layer in saved_model.layers:
            if layer.name not in struct:
                logger.debug("Skipping unused layer: '%s'", layer.name)
                continue
            inbound = struct[layer.name]
            logger.debug("Processing layer '%s': (layer: %s, inbound_nodes: %s)",
                         layer.name, layer, inbound)
            if not inbound:
                model = model_inputs
                logger.debug("Adding model inputs %s: %s", layer.name, model)
            else:
                layer_inputs = []
                for inp in inbound:
                    inbound_layer = compiled_layers[inp[0]]
                    if isinstance(inbound_layer, list) and len(inbound_layer) > 1:
                        # Multi output inputs
                        inbound_output_idx = inp[1]
                        next_input = inbound_layer[inbound_output_idx]
                        logger.debug("Selecting output index %s from multi output inbound "
                                     "layer: %s (using: %s)", inbound_output_idx, inbound_layer,
                                     next_input)
                    else:
                        next_input = inbound_layer
                    layer_inputs.append(next_input)

                logger.debug("Compiling layer '%s': layer inputs: %s", layer.name, layer_inputs)
                model = layer(layer_inputs)
            compiled_layers[layer.name] = model

        retval = KModel(model_inputs, model, name=f"{saved_model.name}_inference")
        logger.debug("Compiled inference model '%s': %s", retval.name, retval)
        return retval

    def _get_filtered_structure(self) -> OrderedDict:
        """ Obtain the structure of the inference model.

        This parses the model config (in reverse) to obtain the required layers for an
        inference model.

        Returns
        -------
        :class:`collections.OrderedDict`
            The layer name as key with the input name and output index as value.
        """
        # Filter output layer
        out = np.array(self._config["output_layers"], dtype="object")
        if out.ndim == 2:
            out = np.expand_dims(out, axis=1)  # Needs to be expanded for _get_nodes
        outputs = self._get_nodes(out)

        # Iterate backwards from the required output to get the reversed model structure
        current_layers = [outputs[0]]
        next_layers = []
        struct = OrderedDict()
        drop_input = self._input_names[abs(self._input_idx - 1)]
        switch_input = self._input_names[self._input_idx]
        while True:
            layer_info = current_layers.pop(0)
            current_layer = next(lyr for lyr in self._config["layers"]
                                 if lyr["name"] == layer_info[0])
            inbound = current_layer["inbound_nodes"]

            if not inbound:
                break

            inbound_info = self._get_nodes(inbound)

            if any(inb[0] == drop_input for inb in inbound_info):  # Switch inputs
                inbound_info = [(switch_input if inb[0] == drop_input else inb[0], inb[1])
                                for inb in inbound_info]
            struct[layer_info[0]] = inbound_info
            next_layers.extend(inbound_info)

            if not current_layers:
                current_layers = next_layers
                next_layers = []

        struct[switch_input] = []  # Add the input layer
        logger.debug("Model structure: %s", struct)
        return struct
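    # Example (sketch, hypothetical sub-model names): for an "A" > "B" swap the reverse walk
    # from the "B" output might yield an OrderedDict such as
    # {"face_out_b": [("decoder_b", 0)], "decoder_b": [("encoder", 0)],
    #  "encoder": [("face_in_a", 0)], "face_in_a": []},
    # i.e. only the layers on the path from the kept input to the required output survive,
    # with the unused "B" input swapped out for the "A" input.
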
    def _get_inputs(self, inputs: list) -> list:
        """ Obtain the inputs for the requested swap direction.

        Parameters
        ----------
        inputs: list
            The full list of input tensors to the saved faceswap training model

        Returns
        -------
        list
            List of input tensors to feed the model for the requested swap direction
        """
        input_split = len(inputs) // 2
        start_idx = input_split * self._input_idx
        retval = inputs[start_idx: start_idx + input_split]
        logger.debug("model inputs: %s, input_split: %s, start_idx: %s, inference_inputs: %s",
                     inputs, input_split, start_idx, retval)
        return retval
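
# Example (sketch) of the input split above, assuming the "A" side inputs precede the "B"
# side inputs: a training model with inputs [face_in_a, face_in_b] gives input_split == 1,
# so _input_idx == 0 selects [face_in_a] and _input_idx == 1 selects [face_in_b]. With a
# hypothetical extra mask input per side ([face_in_a, mask_in_a, face_in_b, mask_in_b]) the
# same arithmetic keeps both "A" tensors or both "B" tensors together.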