1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00

Centralize model storage

This commit is contained in:
torzdf 2022-06-18 19:54:02 +01:00
parent ef79a3d8cb
commit f2e6f24651
8 changed files with 22 additions and 52 deletions

1
.gitignore vendored
View file

@ -36,6 +36,7 @@
!tests/**/*.py !tests/**/*.py
# Core files # Core files
!.fs_cache
!lib/ !lib/
!lib/**/ !lib/**/
!lib/**/*.py !lib/**/*.py

View file

@ -370,7 +370,7 @@ class FaceswapError(Exception):
class GetModel(): # pylint:disable=too-few-public-methods class GetModel(): # pylint:disable=too-few-public-methods
""" Check for models in their cache path. """ Check for models in the cache path.
If available, return the path, if not available, get, unzip and install model If available, return the path, if not available, get, unzip and install model
@ -378,9 +378,6 @@ class GetModel(): # pylint:disable=too-few-public-methods
---------- ----------
model_filename: str or list model_filename: str or list
The name of the model to be loaded (see notes below) The name of the model to be loaded (see notes below)
cache_dir: str
The model cache folder of the current plugin calling this class. IE: The folder that holds
the model to be loaded.
git_model_id: int git_model_id: int
The second digit in the github tag that identifies this model. See The second digit in the github tag that identifies this model. See
https://github.com/deepfakes-models/faceswap-models for more information https://github.com/deepfakes-models/faceswap-models for more information
@ -397,12 +394,12 @@ class GetModel(): # pylint:disable=too-few-public-methods
,"resnet_ssd_v1.prototext"]` ,"resnet_ssd_v1.prototext"]`
""" """
def __init__(self, model_filename, cache_dir, git_model_id): def __init__(self, model_filename, git_model_id):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
if not isinstance(model_filename, list): if not isinstance(model_filename, list):
model_filename = [model_filename] model_filename = [model_filename]
self._model_filename = model_filename self._model_filename = model_filename
self._cache_dir = cache_dir self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
self._git_model_id = git_model_id self._git_model_id = git_model_id
self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download" self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
self._chunk_size = 1024 # Chunk size for downloading and unzipping self._chunk_size = 1024 # Chunk size for downloading and unzipping
@ -456,27 +453,10 @@ class GetModel(): # pylint:disable=too-few-public-methods
self.logger.trace(retval) self.logger.trace(retval)
return retval return retval
@property
def _plugin_section(self):
""" str: The plugin section from the config_dir """
path = os.path.normpath(self._cache_dir)
split = path.split(os.sep)
retval = split[split.index("plugins") + 1]
self.logger.trace(retval)
return retval
@property
def _url_section(self):
""" int: The section ID in github for this plugin type. """
sections = dict(extract=1, train=2, convert=3)
retval = sections[self._plugin_section]
self.logger.trace(retval)
return retval
@property @property
def _url_download(self): def _url_download(self):
""" strL Base download URL for models. """ """ strL Base download URL for models. """
tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}" tag = f"v{self._git_model_id}.{self._model_version}"
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip" retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
self.logger.trace("Download url: %s", retval) self.logger.trace("Download url: %s", retval)
return retval return retval

View file

@ -7,8 +7,6 @@ https://creativecommons.org/licenses/by-nc/4.0/
""" """
import logging import logging
import sys
import os
import cv2 import cv2
import numpy as np import numpy as np
@ -37,9 +35,7 @@ class VGGFace():
# <<< GET MODEL >>> # # <<< GET MODEL >>> #
def get_model(self, git_model_id, model_filename, backend): def get_model(self, git_model_id, model_filename, backend):
""" Check if model is available, if not, download and unzip it """ """ Check if model is available, if not, download and unzip it """
root_path = os.path.abspath(os.path.dirname(sys.argv[0])) model = GetModel(model_filename, git_model_id).model_path
cache_path = os.path.join(root_path, "plugins", "extract", "recognition", ".cache")
model = GetModel(model_filename, cache_path, git_model_id).model_path
model = cv2.dnn.readNetFromCaffe(model[1], model[0]) model = cv2.dnn.readNetFromCaffe(model[1], model[0])
model.setPreferableTarget(self.get_backend(backend)) model.setPreferableTarget(self.get_backend(backend))
return model return model
@ -50,7 +46,7 @@ class VGGFace():
if backend == "OPENCL": if backend == "OPENCL":
logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of " logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of "
"the failure messages.") "the failure messages.")
retval = getattr(cv2.dnn, "DNN_TARGET_{}".format(backend)) retval = getattr(cv2.dnn, f"DNN_TARGET_{backend}")
return retval return retval
def predict(self, face): def predict(self, face):

View file

@ -3,8 +3,6 @@
:mod:`~plugins.extract.mask` Plugins :mod:`~plugins.extract.mask` Plugins
""" """
import logging import logging
import os
import sys
from tensorflow.python.framework import errors_impl as tf_errors from tensorflow.python.framework import errors_impl as tf_errors
@ -139,13 +137,13 @@ class Extractor():
""" int: Batchsize for feeding this model. The number of images the model should """ int: Batchsize for feeding this model. The number of images the model should
feed through at once. """ feed through at once. """
self._queues = dict() self._queues = {}
""" dict: in + out queues and internal queues for this plugin, """ """ dict: in + out queues and internal queues for this plugin, """
self._threads = [] self._threads = []
""" list: Internal threads for this plugin """ """ list: Internal threads for this plugin """
self._extract_media = dict() self._extract_media = {}
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being """ dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
processed. Stored at input for pairing back up on output of extractor process """ processed. Stored at input for pairing back up on output of extractor process """
@ -352,7 +350,8 @@ class Extractor():
# <<< PROTECTED ACCESS METHODS >>> # # <<< PROTECTED ACCESS METHODS >>> #
# <<< INIT METHODS >>> # # <<< INIT METHODS >>> #
def _get_model(self, git_model_id, model_filename): @classmethod
def _get_model(cls, git_model_id, model_filename):
""" Check if model is available, if not, download and unzip it """ """ Check if model is available, if not, download and unzip it """
if model_filename is None: if model_filename is None:
logger.debug("No model_filename specified. Returning None") logger.debug("No model_filename specified. Returning None")
@ -360,13 +359,7 @@ class Extractor():
if git_model_id is None: if git_model_id is None:
logger.debug("No git_model_id specified. Returning None") logger.debug("No git_model_id specified. Returning None")
return None return None
plugin_path = os.path.join(*self.__module__.split(".")[:-1]) model = GetModel(model_filename, git_model_id)
if os.path.basename(plugin_path) in ("detect", "align", "mask", "recognition"):
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
cache_path = os.path.join(base_path, plugin_path, ".cache")
else:
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
model = GetModel(model_filename, cache_path, git_model_id)
return model.model_path return model.model_path
# <<< PLUGIN INITIALIZATION >>> # # <<< PLUGIN INITIALIZATION >>> #
@ -382,7 +375,7 @@ class Extractor():
name = self.name.replace(" ", "_").lower() name = self.name.replace(" ", "_").lower()
self._add_queues(kwargs["in_queue"], self._add_queues(kwargs["in_queue"],
kwargs["out_queue"], kwargs["out_queue"],
["predict_{}".format(name), "post_{}".format(name)]) [f"predict_{name}", f"post_{name}"])
self._compile_threads() self._compile_threads()
try: try:
self.init_model() self.init_model()
@ -409,7 +402,7 @@ class Extractor():
self._queues["out"] = out_queue self._queues["out"] = out_queue
for q_name in queues: for q_name in queues:
self._queues[q_name] = queue_manager.get_queue( self._queues[q_name] = queue_manager.get_queue(
name="{}{}_{}".format(self._plugin_type, self._instance, q_name), name=f"{self._plugin_type}{self._instance}_{q_name}",
maxsize=self.queue_size) maxsize=self.queue_size)
# <<< THREAD METHODS >>> # # <<< THREAD METHODS >>> #
@ -417,18 +410,18 @@ class Extractor():
""" Compile the threads into self._threads list """ """ Compile the threads into self._threads list """
logger.debug("Compiling %s threads", self._plugin_type) logger.debug("Compiling %s threads", self._plugin_type)
name = self.name.replace(" ", "_").lower() name = self.name.replace(" ", "_").lower()
base_name = "{}_{}".format(self._plugin_type, name) base_name = f"{self._plugin_type}_{name}"
self._add_thread("{}_input".format(base_name), self._add_thread(f"{base_name}_input",
self._process_input, self._process_input,
self._queues["in"], self._queues["in"],
self._queues["predict_{}".format(name)]) self._queues[f"predict_{name}"])
self._add_thread("{}_predict".format(base_name), self._add_thread(f"{base_name}_predict",
self._predict, self._predict,
self._queues["predict_{}".format(name)], self._queues[f"predict_{name}"],
self._queues["post_{}".format(name)]) self._queues[f"post_{name}"])
self._add_thread("{}_output".format(base_name), self._add_thread(f"{base_name}_output",
self._process_output, self._process_output,
self._queues["post_{}".format(name)], self._queues[f"post_{name}"],
self._queues["out"]) self._queues["out"])
logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads) logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads)