1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-06 17:45:56 -04:00

Centralize model storage

This commit is contained in:
torzdf 2022-06-18 19:54:02 +01:00
parent ef79a3d8cb
commit f2e6f24651
8 changed files with 22 additions and 52 deletions

1
.gitignore vendored
View file

@ -36,6 +36,7 @@
!tests/**/*.py
# Core files
!.fs_cache
!lib/
!lib/**/
!lib/**/*.py

View file

@ -370,7 +370,7 @@ class FaceswapError(Exception):
class GetModel(): # pylint:disable=too-few-public-methods
""" Check for models in their cache path.
""" Check for models in the cache path.
If available, return the path, if not available, get, unzip and install model
@ -378,9 +378,6 @@ class GetModel(): # pylint:disable=too-few-public-methods
----------
model_filename: str or list
The name of the model to be loaded (see notes below)
cache_dir: str
The model cache folder of the current plugin calling this class. IE: The folder that holds
the model to be loaded.
git_model_id: int
The second digit in the github tag that identifies this model. See
https://github.com/deepfakes-models/faceswap-models for more information
@ -397,12 +394,12 @@ class GetModel(): # pylint:disable=too-few-public-methods
,"resnet_ssd_v1.prototext"]`
"""
def __init__(self, model_filename, cache_dir, git_model_id):
def __init__(self, model_filename, git_model_id):
self.logger = logging.getLogger(__name__)
if not isinstance(model_filename, list):
model_filename = [model_filename]
self._model_filename = model_filename
self._cache_dir = cache_dir
self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
self._git_model_id = git_model_id
self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
self._chunk_size = 1024 # Chunk size for downloading and unzipping
@ -456,27 +453,10 @@ class GetModel(): # pylint:disable=too-few-public-methods
self.logger.trace(retval)
return retval
@property
def _plugin_section(self):
""" str: The plugin section from the config_dir """
path = os.path.normpath(self._cache_dir)
split = path.split(os.sep)
retval = split[split.index("plugins") + 1]
self.logger.trace(retval)
return retval
@property
def _url_section(self):
""" int: The section ID in github for this plugin type. """
sections = dict(extract=1, train=2, convert=3)
retval = sections[self._plugin_section]
self.logger.trace(retval)
return retval
@property
def _url_download(self):
""" strL Base download URL for models. """
tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}"
tag = f"v{self._git_model_id}.{self._model_version}"
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
self.logger.trace("Download url: %s", retval)
return retval

View file

@ -7,8 +7,6 @@ https://creativecommons.org/licenses/by-nc/4.0/
"""
import logging
import sys
import os
import cv2
import numpy as np
@ -37,9 +35,7 @@ class VGGFace():
# <<< GET MODEL >>> #
def get_model(self, git_model_id, model_filename, backend):
""" Check if model is available, if not, download and unzip it """
root_path = os.path.abspath(os.path.dirname(sys.argv[0]))
cache_path = os.path.join(root_path, "plugins", "extract", "recognition", ".cache")
model = GetModel(model_filename, cache_path, git_model_id).model_path
model = GetModel(model_filename, git_model_id).model_path
model = cv2.dnn.readNetFromCaffe(model[1], model[0])
model.setPreferableTarget(self.get_backend(backend))
return model
@ -50,7 +46,7 @@ class VGGFace():
if backend == "OPENCL":
logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of "
"the failure messages.")
retval = getattr(cv2.dnn, "DNN_TARGET_{}".format(backend))
retval = getattr(cv2.dnn, f"DNN_TARGET_{backend}")
return retval
def predict(self, face):

View file

@ -3,8 +3,6 @@
:mod:`~plugins.extract.mask` Plugins
"""
import logging
import os
import sys
from tensorflow.python.framework import errors_impl as tf_errors
@ -139,13 +137,13 @@ class Extractor():
""" int: Batchsize for feeding this model. The number of images the model should
feed through at once. """
self._queues = dict()
self._queues = {}
""" dict: in + out queues and internal queues for this plugin, """
self._threads = []
""" list: Internal threads for this plugin """
self._extract_media = dict()
self._extract_media = {}
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
processed. Stored at input for pairing back up on output of extractor process """
@ -352,7 +350,8 @@ class Extractor():
# <<< PROTECTED ACCESS METHODS >>> #
# <<< INIT METHODS >>> #
def _get_model(self, git_model_id, model_filename):
@classmethod
def _get_model(cls, git_model_id, model_filename):
""" Check if model is available, if not, download and unzip it """
if model_filename is None:
logger.debug("No model_filename specified. Returning None")
@ -360,13 +359,7 @@ class Extractor():
if git_model_id is None:
logger.debug("No git_model_id specified. Returning None")
return None
plugin_path = os.path.join(*self.__module__.split(".")[:-1])
if os.path.basename(plugin_path) in ("detect", "align", "mask", "recognition"):
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
cache_path = os.path.join(base_path, plugin_path, ".cache")
else:
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
model = GetModel(model_filename, cache_path, git_model_id)
model = GetModel(model_filename, git_model_id)
return model.model_path
# <<< PLUGIN INITIALIZATION >>> #
@ -382,7 +375,7 @@ class Extractor():
name = self.name.replace(" ", "_").lower()
self._add_queues(kwargs["in_queue"],
kwargs["out_queue"],
["predict_{}".format(name), "post_{}".format(name)])
[f"predict_{name}", f"post_{name}"])
self._compile_threads()
try:
self.init_model()
@ -409,7 +402,7 @@ class Extractor():
self._queues["out"] = out_queue
for q_name in queues:
self._queues[q_name] = queue_manager.get_queue(
name="{}{}_{}".format(self._plugin_type, self._instance, q_name),
name=f"{self._plugin_type}{self._instance}_{q_name}",
maxsize=self.queue_size)
# <<< THREAD METHODS >>> #
@ -417,18 +410,18 @@ class Extractor():
""" Compile the threads into self._threads list """
logger.debug("Compiling %s threads", self._plugin_type)
name = self.name.replace(" ", "_").lower()
base_name = "{}_{}".format(self._plugin_type, name)
self._add_thread("{}_input".format(base_name),
base_name = f"{self._plugin_type}_{name}"
self._add_thread(f"{base_name}_input",
self._process_input,
self._queues["in"],
self._queues["predict_{}".format(name)])
self._add_thread("{}_predict".format(base_name),
self._queues[f"predict_{name}"])
self._add_thread(f"{base_name}_predict",
self._predict,
self._queues["predict_{}".format(name)],
self._queues["post_{}".format(name)])
self._add_thread("{}_output".format(base_name),
self._queues[f"predict_{name}"],
self._queues[f"post_{name}"])
self._add_thread(f"{base_name}_output",
self._process_output,
self._queues["post_{}".format(name)],
self._queues[f"post_{name}"],
self._queues["out"])
logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads)