diff --git a/plugins/extract/align/.cache/.keep b/.fs_cache/.keep similarity index 100% rename from plugins/extract/align/.cache/.keep rename to .fs_cache/.keep diff --git a/.gitignore b/.gitignore index fbf21eba..6b5149d5 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ !tests/**/*.py # Core files +!.fs_cache !lib/ !lib/**/ !lib/**/*.py diff --git a/lib/utils.py b/lib/utils.py index 7f1f4745..143f5f25 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -370,7 +370,7 @@ class FaceswapError(Exception): class GetModel(): # pylint:disable=too-few-public-methods - """ Check for models in their cache path. + """ Check for models in the cache path. If available, return the path, if not available, get, unzip and install model @@ -378,9 +378,6 @@ class GetModel(): # pylint:disable=too-few-public-methods ---------- model_filename: str or list The name of the model to be loaded (see notes below) - cache_dir: str - The model cache folder of the current plugin calling this class. IE: The folder that holds - the model to be loaded. git_model_id: int The second digit in the github tag that identifies this model. See https://github.com/deepfakes-models/faceswap-models for more information @@ -397,12 +394,12 @@ class GetModel(): # pylint:disable=too-few-public-methods ,"resnet_ssd_v1.prototext"]` """ - def __init__(self, model_filename, cache_dir, git_model_id): + def __init__(self, model_filename, git_model_id): self.logger = logging.getLogger(__name__) if not isinstance(model_filename, list): model_filename = [model_filename] self._model_filename = model_filename - self._cache_dir = cache_dir + self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache") self._git_model_id = git_model_id self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download" self._chunk_size = 1024 # Chunk size for downloading and unzipping @@ -456,27 +453,10 @@ class GetModel(): # pylint:disable=too-few-public-methods self.logger.trace(retval) return retval - @property - def _plugin_section(self): - """ str: The plugin section from the config_dir """ - path = os.path.normpath(self._cache_dir) - split = path.split(os.sep) - retval = split[split.index("plugins") + 1] - self.logger.trace(retval) - return retval - - @property - def _url_section(self): - """ int: The section ID in github for this plugin type. """ - sections = dict(extract=1, train=2, convert=3) - retval = sections[self._plugin_section] - self.logger.trace(retval) - return retval - @property def _url_download(self): """ strL Base download URL for models. """ - tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}" + tag = f"v{self._git_model_id}.{self._model_version}" retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip" self.logger.trace("Download url: %s", retval) return retval diff --git a/lib/vgg_face.py b/lib/vgg_face.py index be917dba..9917fbac 100644 --- a/lib/vgg_face.py +++ b/lib/vgg_face.py @@ -7,8 +7,6 @@ https://creativecommons.org/licenses/by-nc/4.0/ """ import logging -import sys -import os import cv2 import numpy as np @@ -37,9 +35,7 @@ class VGGFace(): # <<< GET MODEL >>> # def get_model(self, git_model_id, model_filename, backend): """ Check if model is available, if not, download and unzip it """ - root_path = os.path.abspath(os.path.dirname(sys.argv[0])) - cache_path = os.path.join(root_path, "plugins", "extract", "recognition", ".cache") - model = GetModel(model_filename, cache_path, git_model_id).model_path + model = GetModel(model_filename, git_model_id).model_path model = cv2.dnn.readNetFromCaffe(model[1], model[0]) model.setPreferableTarget(self.get_backend(backend)) return model @@ -50,7 +46,7 @@ class VGGFace(): if backend == "OPENCL": logger.info("Using OpenCL backend. If the process runs, you can safely ignore any of " "the failure messages.") - retval = getattr(cv2.dnn, "DNN_TARGET_{}".format(backend)) + retval = getattr(cv2.dnn, f"DNN_TARGET_{backend}") return retval def predict(self, face): diff --git a/plugins/extract/_base.py b/plugins/extract/_base.py index 51c84639..0d27e371 100644 --- a/plugins/extract/_base.py +++ b/plugins/extract/_base.py @@ -3,8 +3,6 @@ :mod:`~plugins.extract.mask` Plugins """ import logging -import os -import sys from tensorflow.python.framework import errors_impl as tf_errors @@ -139,13 +137,13 @@ class Extractor(): """ int: Batchsize for feeding this model. The number of images the model should feed through at once. """ - self._queues = dict() + self._queues = {} """ dict: in + out queues and internal queues for this plugin, """ self._threads = [] """ list: Internal threads for this plugin """ - self._extract_media = dict() + self._extract_media = {} """ dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being processed. Stored at input for pairing back up on output of extractor process """ @@ -352,7 +350,8 @@ class Extractor(): # <<< PROTECTED ACCESS METHODS >>> # # <<< INIT METHODS >>> # - def _get_model(self, git_model_id, model_filename): + @classmethod + def _get_model(cls, git_model_id, model_filename): """ Check if model is available, if not, download and unzip it """ if model_filename is None: logger.debug("No model_filename specified. Returning None") @@ -360,13 +359,7 @@ class Extractor(): if git_model_id is None: logger.debug("No git_model_id specified. Returning None") return None - plugin_path = os.path.join(*self.__module__.split(".")[:-1]) - if os.path.basename(plugin_path) in ("detect", "align", "mask", "recognition"): - base_path = os.path.dirname(os.path.realpath(sys.argv[0])) - cache_path = os.path.join(base_path, plugin_path, ".cache") - else: - cache_path = os.path.join(os.path.dirname(__file__), ".cache") - model = GetModel(model_filename, cache_path, git_model_id) + model = GetModel(model_filename, git_model_id) return model.model_path # <<< PLUGIN INITIALIZATION >>> # @@ -382,7 +375,7 @@ class Extractor(): name = self.name.replace(" ", "_").lower() self._add_queues(kwargs["in_queue"], kwargs["out_queue"], - ["predict_{}".format(name), "post_{}".format(name)]) + [f"predict_{name}", f"post_{name}"]) self._compile_threads() try: self.init_model() @@ -409,7 +402,7 @@ class Extractor(): self._queues["out"] = out_queue for q_name in queues: self._queues[q_name] = queue_manager.get_queue( - name="{}{}_{}".format(self._plugin_type, self._instance, q_name), + name=f"{self._plugin_type}{self._instance}_{q_name}", maxsize=self.queue_size) # <<< THREAD METHODS >>> # @@ -417,18 +410,18 @@ class Extractor(): """ Compile the threads into self._threads list """ logger.debug("Compiling %s threads", self._plugin_type) name = self.name.replace(" ", "_").lower() - base_name = "{}_{}".format(self._plugin_type, name) - self._add_thread("{}_input".format(base_name), + base_name = f"{self._plugin_type}_{name}" + self._add_thread(f"{base_name}_input", self._process_input, self._queues["in"], - self._queues["predict_{}".format(name)]) - self._add_thread("{}_predict".format(base_name), + self._queues[f"predict_{name}"]) + self._add_thread(f"{base_name}_predict", self._predict, - self._queues["predict_{}".format(name)], - self._queues["post_{}".format(name)]) - self._add_thread("{}_output".format(base_name), + self._queues[f"predict_{name}"], + self._queues[f"post_{name}"]) + self._add_thread(f"{base_name}_output", self._process_output, - self._queues["post_{}".format(name)], + self._queues[f"post_{name}"], self._queues["out"]) logger.debug("Compiled %s threads: %s", self._plugin_type, self._threads) diff --git a/plugins/extract/detect/.cache/.keep b/plugins/extract/detect/.cache/.keep deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/extract/mask/.cache/.keep b/plugins/extract/mask/.cache/.keep deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/extract/recognition/.cache/.keep b/plugins/extract/recognition/.cache/.keep deleted file mode 100644 index e69de29b..00000000