1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 20:13:52 -04:00
faceswap/lib/model/session.py
kvrooman b49c352e8f # pylint:disable=no-member cleanup (#927)
# pylint:disable=no-member cleanup
2019-11-13 12:17:59 +00:00

192 lines
8.3 KiB
Python

#!/usr/bin python3
""" Settings manager for Keras Backend """
import logging
import tensorflow as tf
from keras.layers import Activation
from tensorflow.python import errors_impl as tf_error # pylint:disable=no-name-in-module
from keras.models import load_model as k_load_model, Model
import numpy as np
from lib.utils import get_backend, FaceswapError
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
class KSession():
""" Handles the settings of backend sessions.
This class acts as a wrapper for various :class:`keras.Model()` functions, ensuring that
actions performed on a model are handled consistently within the correct graph.
This is an early implementation of this class, and should be expanded out over time
with relevant `AMD`, `CPU` and `NVIDIA` backend methods.
Parameters
----------
name: str
The name of the model that is to be loaded
model_path: str
The path to the keras model file
model_kwargs: dict, optional
Any kwargs that need to be passed to :func:`keras.models.load_models()`. Default: None
allow_growth: bool, optional
Enable the Tensorflow GPU allow_growth configuration option. This option prevents "
Tensorflow from allocating all of the GPU VRAM, but can lead to higher fragmentation and "
slower performance. Default: False
"""
def __init__(self, name, model_path, model_kwargs=None, allow_growth=False):
logger.trace("Initializing: %s (name: %s, model_path: %s, model_kwargs: %s, "
"allow_growth: %s)",
self.__class__.__name__, name, model_path, model_kwargs, allow_growth)
self._name = name
self._session = self._set_session(allow_growth)
self._model_path = model_path
self._model_kwargs = model_kwargs
self._model = None
logger.trace("Initialized: %s", self.__class__.__name__,)
def predict(self, feed, batch_size=None):
""" Get predictions from the model in the correct session.
This method is a wrapper for :func:`keras.predict()` function.
Parameters
----------
feed: numpy.ndarray or list
The feed to be provided to the model as input. This should be a ``numpy.ndarray``
for single inputs or a ``list`` of ``numpy.ndarrays`` for multiple inputs.
"""
if self._session is None:
if batch_size is None:
return self._model.predict(feed)
return self._amd_predict_with_optimized_batchsizes(feed, batch_size)
with self._session.as_default(): # pylint: disable=not-context-manager
with self._session.graph.as_default():
return self._model.predict(feed, batch_size=batch_size)
def _amd_predict_with_optimized_batchsizes(self, feed, batch_size):
""" Minimizes the amount of kernels to be compiled when using
the ``Amd`` backend with varying batchsizes while trying to keep
the batchsize as high as possible.
Parameters
----------
feed: numpy.ndarray or list
The feed to be provided to the model as input. This should be a ``numpy.ndarray``
for single inputs or a ``list`` of ``numpy.ndarrays`` for multiple inputs.
batch_size: int
The upper batchsize to use.
"""
if isinstance(feed, np.ndarray):
feed = [feed]
items = feed[0].shape[0]
done_items = 0
results = list()
while done_items < items:
if batch_size < 4: # Not much difference in BS < 4
batch_size = 1
batch_items = ((items - done_items) // batch_size) * batch_size
if batch_items:
pred_data = [x[done_items:done_items + batch_items] for x in feed]
pred = self._model.predict(pred_data, batch_size=batch_size)
done_items += batch_items
results.append(pred)
batch_size //= 2
if isinstance(results[0], np.ndarray):
return np.concatenate(results)
return [np.concatenate(x) for x in zip(*results)]
def _set_session(self, allow_growth):
""" Sets the session and graph.
If the backend is AMD then this does nothing and the global ``Keras`` ``Session``
is used
"""
if get_backend() == "amd":
return None
self.graph = tf.Graph()
config = tf.ConfigProto()
if allow_growth and get_backend() == "nvidia":
config.gpu_options.allow_growth = True
try:
session = tf.Session(graph=tf.Graph(), config=config)
except tf_error.InternalError as err:
if "driver version is insufficient" in str(err):
msg = ("Your Nvidia Graphics Driver is insufficient for running Faceswap. "
"Please upgrade to the latest version.")
raise FaceswapError(msg) from err
raise err
logger.debug("Created tf.session: (graph: %s, session: %s, config: %s)",
session.graph, session, config)
return session
def load_model(self):
""" Loads a model within the correct session.
This method is a wrapper for :func:`keras.models.load_model()`. Loads a model and its
weights from :attr:`model_path`. Any additional ``kwargs`` to be passed to
:func:`keras.models.load_model()` should also be defined during initialization of the
class.
"""
logger.verbose("Initializing plugin model: %s", self._name)
if self._session is None:
self._model = k_load_model(self._model_path, **self._model_kwargs)
else:
with self._session.as_default(): # pylint: disable=not-context-manager
with self._session.graph.as_default():
self._model = k_load_model(self._model_path, **self._model_kwargs)
def define_model(self, function):
""" Defines a given model in the correct session.
This method acts as a wrapper for :class:`keras.models.Model()` to ensure that the model
is defined within it's own graph.
Parameters
----------
function: function
A function that defines a :class:`keras.Model` and returns it's ``inputs`` and
``outputs``. The function that generates these results should be passed in, NOT the
results themselves, as the function needs to be executed within the correct context.
"""
if self._session is None:
self._model = Model(*function())
else:
with self._session.as_default(): # pylint: disable=not-context-manager
with self._session.graph.as_default():
self._model = Model(*function())
def load_model_weights(self):
""" Load model weights for a defined model inside the correct session.
This method is a wrapper for :class:`keras.load_weights()`. Once a model has been defined
in :func:`define_model()` this method can be called to load its weights in the correct
graph from the :attr:`model_path` defined during initialization of this class.
"""
logger.verbose("Initializing plugin model: %s", self._name)
if self._session is None:
self._model.load_weights(self._model_path)
else:
with self._session.as_default(): # pylint: disable=not-context-manager
with self._session.graph.as_default():
self._model.load_weights(self._model_path)
def append_softmax_activation(self, layer_index=-1):
""" Append a softmax activation layer to a model
Occasionally a softmax activation layer needs to be added to a model's output.
This is a convenience fuction to append this layer to the loaded model.
Parameters
----------
layer_index: int, optional
The layer index of the model to select the output from to use as an input to the
softmax activation layer. Default: -1 (The final layer of the model)
"""
logger.debug("Appending Softmax Activation to model: (layer_index: %s)", layer_index)
softmax = Activation("softmax", name="softmax")(self._model.layers[layer_index].output)
self._model = Model(inputs=self._model.input, outputs=[softmax])