mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
Centralize model storage
This commit is contained in:
parent
ef79a3d8cb
commit
f2e6f24651
8 changed files with 22 additions and 52 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -36,6 +36,7 @@
|
||||||
!tests/**/*.py
|
!tests/**/*.py
|
||||||
|
|
||||||
# Core files
|
# Core files
|
||||||
|
!.fs_cache
|
||||||
!lib/
|
!lib/
|
||||||
!lib/**/
|
!lib/**/
|
||||||
!lib/**/*.py
|
!lib/**/*.py
|
||||||
|
|
28
lib/utils.py
28
lib/utils.py
|
@ -370,7 +370,7 @@ class FaceswapError(Exception):
|
||||||
|
|
||||||
|
|
||||||
class GetModel(): # pylint:disable=too-few-public-methods
|
class GetModel(): # pylint:disable=too-few-public-methods
|
||||||
""" Check for models in their cache path.
|
""" Check for models in the cache path.
|
||||||
|
|
||||||
If available, return the path, if not available, get, unzip and install model
|
If available, return the path, if not available, get, unzip and install model
|
||||||
|
|
||||||
|
@ -378,9 +378,6 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
||||||
----------
|
----------
|
||||||
model_filename: str or list
|
model_filename: str or list
|
||||||
The name of the model to be loaded (see notes below)
|
The name of the model to be loaded (see notes below)
|
||||||
cache_dir: str
|
|
||||||
The model cache folder of the current plugin calling this class. IE: The folder that holds
|
|
||||||
the model to be loaded.
|
|
||||||
git_model_id: int
|
git_model_id: int
|
||||||
The second digit in the github tag that identifies this model. See
|
The second digit in the github tag that identifies this model. See
|
||||||
https://github.com/deepfakes-models/faceswap-models for more information
|
https://github.com/deepfakes-models/faceswap-models for more information
|
||||||
|
@ -397,12 +394,12 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
||||||
,"resnet_ssd_v1.prototext"]`
|
,"resnet_ssd_v1.prototext"]`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_filename, cache_dir, git_model_id):
|
def __init__(self, model_filename, git_model_id):
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
if not isinstance(model_filename, list):
|
if not isinstance(model_filename, list):
|
||||||
model_filename = [model_filename]
|
model_filename = [model_filename]
|
||||||
self._model_filename = model_filename
|
self._model_filename = model_filename
|
||||||
self._cache_dir = cache_dir
|
self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
|
||||||
self._git_model_id = git_model_id
|
self._git_model_id = git_model_id
|
||||||
self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
|
self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
|
||||||
self._chunk_size = 1024 # Chunk size for downloading and unzipping
|
self._chunk_size = 1024 # Chunk size for downloading and unzipping
|
||||||
|
@ -456,27 +453,10 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
||||||
self.logger.trace(retval)
|
self.logger.trace(retval)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
@property
|
|
||||||
def _plugin_section(self):
|
|
||||||
""" str: The plugin section from the config_dir """
|
|
||||||
path = os.path.normpath(self._cache_dir)
|
|
||||||
split = path.split(os.sep)
|
|
||||||
retval = split[split.index("plugins") + 1]
|
|
||||||
self.logger.trace(retval)
|
|
||||||
return retval
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _url_section(self):
|
|
||||||
""" int: The section ID in github for this plugin type. """
|
|
||||||
sections = dict(extract=1, train=2, convert=3)
|
|
||||||
retval = sections[self._plugin_section]
|
|
||||||
self.logger.trace(retval)
|
|
||||||
return retval
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _url_download(self):
|
def _url_download(self):
|
||||||
""" strL Base download URL for models. """
|
""" strL Base download URL for models. """
|
||||||
tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}"
|
tag = f"v{self._git_model_id}.{self._model_version}"
|
||||||
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
|
retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
|
||||||
self.logger.trace("Download url: %s", retval)
|
self.logger.trace("Download url: %s", retval)
|
||||||
return retval
|
return retval
|
||||||
|
|
|
@ -7,8 +7,6 @@ https://creativecommons.org/licenses/by-nc/4.0/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -37,9 +35,7 @@ class VGGFace():
|
||||||
# <<< GET MODEL >>> #
|
# <<< GET MODEL >>> #
|
||||||
def get_model(self, git_model_id, model_filename, backend):
|
def get_model(self, git_model_id, model_filename, backend):
|
||||||
""" Check if model is available, if not, download and unzip it """
|
""" Check if model is available, if not, download and unzip it """
|
||||||
root_path = os.path.abspath(os.path.dirname(sys.argv[0]))
|
model = GetModel(model_filename, git_model_id).model_path
|
||||||
cache_path = os.path.join(root_path, "plugins", "extract", "recognition", ".cache")
|
|
||||||
model = GetModel(model_filename, cache_path, git_model_id).model_path
|
|
||||||
model = cv2.dnn.readNetFromCaffe(model[1], model[0])
|
model = cv2.dnn.readNetFromCaffe(model[1], model[0])
|
||||||
model.setPreferableTarget(self.get_backend(backend))
|
model.setPreferableTarget(self.get_backend(backend))
|
||||||
return model
|
return model
|
||||||
|
@ -50,7 +46,7 @@ class VGGFace():
|
||||||
if backend == "OPENCL":
|
if backend == "OPENCL":
|
||||||
logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of "
|
logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of "
|
||||||
"the failure messages.")
|
"the failure messages.")
|
||||||
retval = getattr(cv2.dnn, "DNN_TARGET_{}".format(backend))
|
retval = getattr(cv2.dnn, f"DNN_TARGET_{backend}")
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
def predict(self, face):
|
def predict(self, face):
|
||||||
|
|
|
@ -3,8 +3,6 @@
|
||||||
:mod:`~plugins.extract.mask` Plugins
|
:mod:`~plugins.extract.mask` Plugins
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from tensorflow.python.framework import errors_impl as tf_errors
|
from tensorflow.python.framework import errors_impl as tf_errors
|
||||||
|
|
||||||
|
@ -139,13 +137,13 @@ class Extractor():
|
||||||
""" int: Batchsize for feeding this model. The number of images the model should
|
""" int: Batchsize for feeding this model. The number of images the model should
|
||||||
feed through at once. """
|
feed through at once. """
|
||||||
|
|
||||||
self._queues = dict()
|
self._queues = {}
|
||||||
""" dict: in + out queues and internal queues for this plugin, """
|
""" dict: in + out queues and internal queues for this plugin, """
|
||||||
|
|
||||||
self._threads = []
|
self._threads = []
|
||||||
""" list: Internal threads for this plugin """
|
""" list: Internal threads for this plugin """
|
||||||
|
|
||||||
self._extract_media = dict()
|
self._extract_media = {}
|
||||||
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
|
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
|
||||||
processed. Stored at input for pairing back up on output of extractor process """
|
processed. Stored at input for pairing back up on output of extractor process """
|
||||||
|
|
||||||
|
@ -352,7 +350,8 @@ class Extractor():
|
||||||
|
|
||||||
# <<< PROTECTED ACCESS METHODS >>> #
|
# <<< PROTECTED ACCESS METHODS >>> #
|
||||||
# <<< INIT METHODS >>> #
|
# <<< INIT METHODS >>> #
|
||||||
def _get_model(self, git_model_id, model_filename):
|
@classmethod
|
||||||
|
def _get_model(cls, git_model_id, model_filename):
|
||||||
""" Check if model is available, if not, download and unzip it """
|
""" Check if model is available, if not, download and unzip it """
|
||||||
if model_filename is None:
|
if model_filename is None:
|
||||||
logger.debug("No model_filename specified. Returning None")
|
logger.debug("No model_filename specified. Returning None")
|
||||||
|
@ -360,13 +359,7 @@ class Extractor():
|
||||||
if git_model_id is None:
|
if git_model_id is None:
|
||||||
logger.debug("No git_model_id specified. Returning None")
|
logger.debug("No git_model_id specified. Returning None")
|
||||||
return None
|
return None
|
||||||
plugin_path = os.path.join(*self.__module__.split(".")[:-1])
|
model = GetModel(model_filename, git_model_id)
|
||||||
if os.path.basename(plugin_path) in ("detect", "align", "mask", "recognition"):
|
|
||||||
base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
|
|
||||||
cache_path = os.path.join(base_path, plugin_path, ".cache")
|
|
||||||
else:
|
|
||||||
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
|
|
||||||
model = GetModel(model_filename, cache_path, git_model_id)
|
|
||||||
return model.model_path
|
return model.model_path
|
||||||
|
|
||||||
# <<< PLUGIN INITIALIZATION >>> #
|
# <<< PLUGIN INITIALIZATION >>> #
|
||||||
|
@ -382,7 +375,7 @@ class Extractor():
|
||||||
name = self.name.replace(" ", "_").lower()
|
name = self.name.replace(" ", "_").lower()
|
||||||
self._add_queues(kwargs["in_queue"],
|
self._add_queues(kwargs["in_queue"],
|
||||||
kwargs["out_queue"],
|
kwargs["out_queue"],
|
||||||
["predict_{}".format(name), "post_{}".format(name)])
|
[f"predict_{name}", f"post_{name}"])
|
||||||
self._compile_threads()
|
self._compile_threads()
|
||||||
try:
|
try:
|
||||||
self.init_model()
|
self.init_model()
|
||||||
|
@ -409,7 +402,7 @@ class Extractor():
|
||||||
self._queues["out"] = out_queue
|
self._queues["out"] = out_queue
|
||||||
for q_name in queues:
|
for q_name in queues:
|
||||||
self._queues[q_name] = queue_manager.get_queue(
|
self._queues[q_name] = queue_manager.get_queue(
|
||||||
name="{}{}_{}".format(self._plugin_type, self._instance, q_name),
|
name=f"{self._plugin_type}{self._instance}_{q_name}",
|
||||||
maxsize=self.queue_size)
|
maxsize=self.queue_size)
|
||||||
|
|
||||||
# <<< THREAD METHODS >>> #
|
# <<< THREAD METHODS >>> #
|
||||||
|
@ -417,18 +410,18 @@ class Extractor():
|
||||||
""" Compile the threads into self._threads list """
|
""" Compile the threads into self._threads list """
|
||||||
logger.debug("Compiling %s threads", self._plugin_type)
|
logger.debug("Compiling %s threads", self._plugin_type)
|
||||||
name = self.name.replace(" ", "_").lower()
|
name = self.name.replace(" ", "_").lower()
|
||||||
base_name = "{}_{}".format(self._plugin_type, name)
|
base_name = f"{self._plugin_type}_{name}"
|
||||||
self._add_thread("{}_input".format(base_name),
|
self._add_thread(f"{base_name}_input",
|
||||||
self._process_input,
|
self._process_input,
|
||||||
self._queues["in"],
|
self._queues["in"],
|
||||||
self._queues["predict_{}".format(name)])
|
self._queues[f"predict_{name}"])
|
||||||
self._add_thread("{}_predict".format(base_name),
|
self._add_thread(f"{base_name}_predict",
|
||||||
self._predict,
|
self._predict,
|
||||||
self._queues["predict_{}".format(name)],
|
self._queues[f"predict_{name}"],
|
||||||
self._queues["post_{}".format(name)])
|
self._queues[f"post_{name}"])
|
||||||
self._add_thread("{}_output".format(base_name),
|
self._add_thread(f"{base_name}_output",
|
||||||
self._process_output,
|
self._process_output,
|
||||||
self._queues["post_{}".format(name)],
|
self._queues[f"post_{name}"],
|
||||||
self._queues["out"])
|
self._queues["out"])
|
||||||
logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads)
|
logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue