mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* Remove custom keras importer * first round keras imports fix * launcher.py: Remove KerasFinder references * 2nd round keras imports update (lib and extract) * 3rd round keras imports update (train) * remove KerasFinder from tests * 4th round keras imports update (tests)
599 lines
21 KiB
Python
599 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
""" Utilities available across all scripts """
|
|
|
|
import json
import logging
import os
import sys
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile

from re import finditer
from multiprocessing import current_process
from socket import timeout as socket_timeout, error as socket_error

from tqdm import tqdm
|
|
|
|
# Global variables
# Lower-case file suffixes (including the leading dot) that get_image_paths() matches against
# when the caller does not request a specific extension.
_image_extensions = [ # pylint:disable=invalid-name
    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
# Lower-case file suffixes (including the leading dot) recognised as video containers.
_video_extensions = [ # pylint:disable=invalid-name
    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
    ".ts", ".vob"]
# Cache for the installed Tensorflow version. Stays ``None`` until get_tf_version() is first
# called, so that Tensorflow is only imported on demand.
_TF_VERS = None
|
|
|
|
|
|
class _Backend():  # pylint:disable=too-few-public-methods
    """ Resolve the backend that Faceswap should run on.

    The backend is taken from the `FACESWAP_BACKEND` Environment Variable when it is set,
    otherwise from the :file:`config/.faceswap` configuration file.

    If the file doesn't exist (or cannot be used) and the variable hasn't been set, the user is
    prompted for a backend and the configuration file is written. """
    def __init__(self):
        # Menu selection -> backend name. The values double as the list of valid backends.
        self._backends = {"1": "amd", "2": "cpu", "3": "nvidia"}
        self._config_file = self._get_config_file()
        self.backend = self._get_backend()

    @classmethod
    def _get_config_file(cls):
        """ Obtain the location of the main Faceswap configuration file.

        The location is derived from the folder holding the launched script (``sys.argv[0]``).

        Returns
        -------
        str
            The path to the Faceswap configuration file
        """
        launch_folder = os.path.dirname(os.path.realpath(sys.argv[0]))
        return os.path.join(launch_folder, "config", ".faceswap")

    def _get_backend(self):
        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
        the :file:`config/.faceswap` configuration file. If neither of these exist, prompt the user
        to select a backend.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        # The environment variable always wins and is used as given (lower-cased, unvalidated)
        from_env = os.environ.get("FACESWAP_BACKEND")
        if from_env is not None:
            from_env = from_env.lower()
            print(f"Setting Faceswap backend from environment variable to {from_env.upper()}")
            return from_env

        # Intercept for sphinx docs build
        if sys.argv[0].endswith("sphinx-build"):
            return "nvidia"

        if not os.path.isfile(self._config_file):
            self._configure_backend()

        # Keep re-prompting (which rewrites the file) until the file parses as json
        while True:
            try:
                with open(self._config_file, "r", encoding="utf8") as cnf:
                    config = json.load(cnf)
            except json.decoder.JSONDecodeError:
                self._configure_backend()
            else:
                break

        configured = config.get("backend", None)
        is_valid = configured is not None and configured.lower() in self._backends.values()
        if not is_valid:
            configured = self._configure_backend()

        # Only announce from the main process so spawned workers stay quiet
        if current_process().name == "MainProcess":
            print(f"Setting Faceswap backend to {configured.upper()}")
        return configured.lower()

    def _configure_backend(self):
        """ Get user input to select the backend that Faceswap should use and store the choice in
        the configuration file.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        print("First time configuration. Please select the required backend")
        selection = input("1: AMD, 2: CPU, 3: NVIDIA: ")
        while selection not in self._backends:
            print(f"'{selection}' is not a valid selection. Please try again")
            selection = input("1: AMD, 2: CPU, 3: NVIDIA: ")

        chosen = self._backends[selection].lower()
        with open(self._config_file, "w", encoding="utf8") as cnf:
            json.dump({"backend": chosen}, cnf)
        print(f"Faceswap config written to: {self._config_file}")
        return chosen
|
|
|
|
|
|
# Resolved once, when this module is first imported. Instantiating _Backend reads the
# FACESWAP_BACKEND environment variable / config file and may prompt on stdin if neither
# yields a usable backend.
_FS_BACKEND = _Backend().backend
|
|
|
|
|
|
def get_backend():
    """ Report which backend Faceswap is currently set to use.

    This is the module level value resolved at import time, or whatever it was last overridden
    to with :func:`set_backend`.

    Returns
    -------
    str
        The backend configuration in use by Faceswap
    """
    return _FS_BACKEND
|
|
|
|
|
|
def set_backend(backend):
    """ Replace the currently configured backend with the one supplied.

    The value is stored lower-cased and is not checked against the list of known backends.

    Parameters
    ----------
    backend: ["amd", "cpu", "nvidia"]
        The backend to set faceswap to
    """
    global _FS_BACKEND  # pylint:disable=global-statement
    normalized = backend.lower()
    _FS_BACKEND = normalized
|
|
|
|
|
|
def get_tf_version():
    """ Obtain the major.minor version of currently installed Tensorflow.

    Tensorflow is imported on the first call only; the result is cached at module level and
    returned directly on every later call.

    Returns
    -------
    float
        The currently installed tensorflow version

    Notes
    -----
    NOTE(review): because the version is returned as a float, a minor version with a trailing
    zero loses it (``"2.10"`` becomes ``2.1``). Callers comparing versions should be aware.
    """
    global _TF_VERS  # pylint:disable=global-statement
    if _TF_VERS is not None:
        return _TF_VERS

    import tensorflow as tf  # pylint:disable=import-outside-toplevel
    major_minor = tf.__version__.split(".")[:2]  # pylint:disable=no-member
    _TF_VERS = float(".".join(major_minor))
    return _TF_VERS
|
|
|
|
|
|
def get_folder(path, make_folder=True):
    """ Hand back the path to a folder, optionally creating the folder first.

    Parameters
    ----------
    path: str
        The path to the folder to obtain
    make_folder: bool, optional
        ``True`` to create the folder (and any missing parents) when it does not already exist,
        ``False`` to leave the file system untouched. Default: ``True``

    Returns
    -------
    str or `None`
        The path to the requested folder. ``None`` when `make_folder` is ``False`` and the
        requested path is not an existing folder
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Requested path: '%s'", path)

    if make_folder or os.path.isdir(path):
        # A no-op when the folder is already there
        os.makedirs(path, exist_ok=True)
        logger.debug("Returning: '%s'", path)
        return path

    logger.debug("%s does not exist", path)
    return None
|
|
|
|
|
|
def get_image_paths(directory, extension=None):
    """ Obtain a list of full paths to the images that reside within a folder.

    Entries are matched on their file name suffix only, case-insensitively, and are returned
    sorted by file name. If the folder does not exist it is created (and an empty list results).

    Parameters
    ----------
    directory: str
        The folder that contains the images to be returned
    extension: str, optional
        The specific image extension that should be returned (e.g. ``".png"``). Matching is
        case-insensitive. Default: ``None`` (all of the module's known image extensions)

    Returns
    -------
    list
        The list of full paths to the images contained within the given folder
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    if extension is None:
        image_extensions = tuple(_image_extensions)
    else:
        # File names are lower-cased before comparison, so the requested extension must be too,
        # otherwise a request for ".PNG" could never match anything.
        image_extensions = (extension.lower(), )

    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)

    # Context manager releases the directory handle deterministically
    with os.scandir(directory) as scanner:
        dir_scanned = sorted(scanner, key=lambda entry: entry.name)
    logger.debug("Scanned Folder contains %s files", len(dir_scanned))
    logger.trace("Scanned Folder Contents: %s", dir_scanned)

    dir_contents = []
    for chkfile in dir_scanned:
        if chkfile.name.lower().endswith(image_extensions):
            logger.trace("Adding '%s' to image list", chkfile.path)
            dir_contents.append(chkfile.path)

    logger.debug("Returning %s images", len(dir_contents))
    return dir_contents
|
|
|
|
|
|
def convert_to_secs(*args):
    """ Convert a time to seconds.

    Parameters
    ----------
    args: tuple
        1, 2 or 3 values, each convertible with ``float``. 1 value is taken as (`seconds`),
        2 values as (`minutes`, `seconds`) and 3 values as (`hours`, `minutes`, `seconds`).

    Returns
    -------
    float
        The given time converted to seconds. ``0.0`` if any other number of values is given
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("from time: %s", args)

    count = len(args)
    if count in (1, 2, 3):
        values = [float(arg) for arg in args]
        if count == 1:
            seconds = values[0]
        elif count == 2:
            seconds = 60 * values[0] + values[1]
        else:
            seconds = 3600 * values[0] + 60 * values[1] + values[2]
    else:
        seconds = 0.0

    logger.debug("to secs: %s", seconds)
    return seconds
|
|
|
|
|
|
def full_path_split(path):
    """ Break a path down into a list holding each of its components.

    Parameters
    ----------
    path: str
        The full path to be split

    Returns
    -------
    list
        The full path split into a separate item for each part. For an absolute path the first
        item is the root

    Example
    -------
    >>> path = "/foo/baz/bar"
    >>> full_path_split(path)
    ["/", "foo", "baz", "bar"]
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    # Components are collected leaf-first and flipped at the end
    collected = []
    while True:
        head, tail = os.path.split(path)
        if head == path:  # sentinel for absolute paths
            collected.append(head)
            break
        if tail == path:  # sentinel for relative paths
            collected.append(tail)
            break
        collected.append(tail)
        path = head
    allparts = collected[::-1]
    logger.trace("path: %s, allparts: %s", path, allparts)
    return allparts
|
|
|
|
|
|
def set_system_verbosity(log_level):
    """ Set the verbosity level of tensorflow and, unless verbose logging was requested, suppress
    future, deprecation and user warnings from any modules

    Parameters
    ----------
    log_level: str
        The requested Faceswap log level

    References
    ----------
    https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
    `TF_CPP_MIN_LOG_LEVEL` can be set to:
    0: all logs shown. 1: filter out INFO logs. 2: filter out WARNING logs. 3: filter out ERROR
    logs.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    from lib.logger import get_loglevel  # pylint:disable=import-outside-toplevel

    # Numeric levels above 15 silence tensorflow entirely; anything lower shows everything
    tf_verbosity = "3" if get_loglevel(log_level) > 15 else "0"
    logger.debug("System Verbosity level: %s", tf_verbosity)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = tf_verbosity

    if tf_verbosity == '0':
        return
    for category in (FutureWarning, DeprecationWarning, UserWarning):
        warnings.simplefilter(action='ignore', category=category)
|
|
|
|
|
|
def deprecation_warning(function, additional_info=None):
    """ Emit a warning level log message stating that a function is deprecated and is due to be
    removed in a future update.

    Parameters
    ----------
    function: str
        The function that will be deprecated.
    additional_info: str, optional
        Extra text to append (after a space) to the deprecation message. Default: ``None``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("func_name: %s, additional_info: %s", function, additional_info)
    suffix = "" if additional_info is None else f" {additional_info}"
    message = f"{function} has been deprecated and will be removed from a future update." + suffix
    logger.warning(message)
|
|
|
|
|
|
def camel_case_split(identifier):
    """ Break a camel case name up into its individual words.

    A new word starts wherever a lower case letter is followed by an upper case letter, or where
    a run of upper case letters is followed by an upper case letter that begins a lower case word
    (so ``"HTTPResponse"`` gives ``["HTTP", "Response"]``).

    Parameters
    ----------
    identifier: str
        The camel case text to be split

    Returns
    -------
    list
        A list of the given identifier split into its constituent parts


    References
    ----------
    https://stackoverflow.com/questions/29916065
    """
    pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    words = []
    for match in finditer(pattern, identifier):
        words.append(match.group(0))
    return words
|
|
|
|
|
|
def safe_shutdown(got_error=False):
    """ Terminate all queues tracked by the queue manager and then exit the process. Intended for
    use in event of a crash or on shut down.

    Parameters
    ----------
    got_error: bool, optional
        ``True`` if this function is being called as the result of raised error (the process
        exits with status 1), otherwise ``False`` (exit status 0). Default: ``False``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
    if got_error:
        exit_code = 1
    else:
        exit_code = 0
    sys.exit(exit_code)
|
|
|
|
|
|
class FaceswapError(Exception):
    """ Faceswap specific exception, raised for anticipated failures so that they can carry a
    useful, human readable message.

    Raises
    ------
    FaceswapError
        on a captured error
    """
|
|
|
|
|
|
class GetModel():  # pylint:disable=too-few-public-methods
    """ Check for models in their cache path.

    If available, return the path, if not available, get, unzip and install model. The check (and
    any download) is performed when the class is instantiated.

    Parameters
    ----------
    model_filename: str or list
        The name of the model to be loaded (see notes below)
    cache_dir: str
        The model cache folder of the current plugin calling this class. IE: The folder that holds
        the model to be loaded.
    git_model_id: int
        The second digit in the github tag that identifies this model. See
        https://github.com/deepfakes-models/faceswap-models for more information

    Notes
    ------
    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
    (eg: `s3fd_v1.pb`).

    Multiple models can exist within the model_filename. They should be passed as a list and follow
    the same naming convention as above. Any differences in filename should occur AFTER the version
    number: `<model_name>_v<version_number><differentiating_information>.<extension>` (eg:
    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`, `["resnet_ssd_v1.caffemodel"
    ,"resnet_ssd_v1.prototext"]`
    """

    def __init__(self, model_filename, cache_dir, git_model_id):
        self.logger = logging.getLogger(__name__)
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self._model_filename = model_filename
        self._cache_dir = cache_dir
        self._git_model_id = git_model_id
        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self._chunk_size = 1024  # Chunk size for downloading and unzipping
        self._retries = 6
        self._get()

    @property
    def _model_full_name(self):
        """ str: The full model name from the filename(s). """
        # For multiple files the shared prefix ends at the "." after the version number, which
        # splitext then strips.
        common_prefix = os.path.commonprefix(self._model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)
        return retval

    @property
    def _model_name(self):
        """ str: The model name from the model's full name. """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)
        return retval

    @property
    def _model_version(self):
        """ int: The model's version number from the model full name. """
        # + 2 skips over the "_v" that precedes the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)
        return retval

    @property
    def model_path(self):
        """ str or list: The model path(s) in the cache folder. A single path is returned as a
        string, multiple paths as a list. """
        retval = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
        retval = retval[0] if len(retval) == 1 else retval
        self.logger.trace(retval)
        return retval

    @property
    def _model_zip_path(self):
        """ str: The full path to downloaded zip file. """
        retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
        self.logger.trace(retval)
        return retval

    @property
    def _model_exists(self):
        """ bool: ``True`` if the model exists in the cache folder otherwise ``False``. """
        if isinstance(self.model_path, list):
            retval = all(os.path.exists(pth) for pth in self.model_path)
        else:
            retval = os.path.exists(self.model_path)
        self.logger.trace(retval)
        return retval

    @property
    def _plugin_section(self):
        """ str: The plugin section, taken as the folder directly below "plugins" in the cache
        folder's path. """
        path = os.path.normpath(self._cache_dir)
        split = path.split(os.sep)
        retval = split[split.index("plugins") + 1]
        self.logger.trace(retval)
        return retval

    @property
    def _url_section(self):
        """ int: The section ID in github for this plugin type. """
        sections = {"extract": 1, "train": 2, "convert": 3}
        retval = sections[self._plugin_section]
        self.logger.trace(retval)
        return retval

    @property
    def _url_download(self):
        """ str: Base download URL for models. """
        tag = f"v{self._url_section}.{self._git_model_id}.{self._model_version}"
        retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
        self.logger.trace("Download url: %s", retval)
        return retval

    @property
    def _url_partial_size(self):
        """ int: How many bytes have already been downloaded. """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)
        return retval

    def _get(self):
        """ Check the model exists, if not, download the model, unzip it and place it in the
        model's cache folder. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self.model_path)
            return
        self._download_model()
        self._unzip_model()
        os.remove(self._model_zip_path)

    def _download_model(self):
        """ Download the model zip from github to the cache folder.

        A partially downloaded zip is resumed by way of a http Range request. Network errors are
        retried; once the retries are used up the process exits with status 1. """
        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self._retries):
            try:
                # Re-read on every attempt: a failed attempt may have written more of the file
                downloaded_size = self._url_partial_size
                req = urllib.request.Request(self._url_download)
                if downloaded_size != 0:
                    req.add_header("Range", f"bytes={downloaded_size}-")
                with urllib.request.urlopen(req, timeout=10) as response:
                    self.logger.debug("header info: {%s}", response.info())
                    self.logger.debug("Return Code: %s", response.getcode())
                    self._write_zipfile(response, downloaded_size)
                break
            except (socket_error, socket_timeout,
                    urllib.error.HTTPError, urllib.error.URLError) as err:
                if attempt + 1 < self._retries:
                    self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                        str(err), attempt + 2, self._retries)
                else:
                    self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                                      "'%s')", str(err), self._url_download)
                    self.logger.info("You can try running again to resume the download.")
                    self.logger.info("Alternatively, you can manually download the model from: %s "
                                     "and unzip the contents to: %s",
                                     self._url_download, self._cache_dir)
                    sys.exit(1)

    def _write_zipfile(self, response, downloaded_size):
        """ Write the model zip file to disk.

        Parameters
        ----------
        response: :class:`http.client.HTTPResponse`
            The response from the model download task
        downloaded_size: int
            The amount of bytes downloaded so far
        """
        if downloaded_size != 0 and response.getcode() != 206:
            # A resume was requested but the server did not answer with "206 Partial Content",
            # so the body is the whole file. Appending it to the partial zip would corrupt it.
            self.logger.debug("Range request not honoured. Restarting download from scratch")
            downloaded_size = 0

        # Content-Length may legitimately be absent (e.g. chunked transfer encoding)
        content_length = response.getheader("content-length")
        remaining = None if content_length is None else int(content_length)
        if remaining == 0:
            self.logger.info("Zip already exists. Skipping download")
            return

        total = None if remaining is None else remaining + downloaded_size
        write_type = "wb" if downloaded_size == 0 else "ab"
        # Context managers make sure the file and the progress bar are closed on error too
        with open(self._model_zip_path, write_type) as out_file, \
                tqdm(desc="Downloading",
                     unit="B",
                     total=total,
                     unit_scale=True,
                     unit_divisor=1024) as pbar:
            if downloaded_size != 0:
                pbar.update(downloaded_size)
            while True:
                buffer = response.read(self._chunk_size)
                if not buffer:
                    break
                pbar.update(len(buffer))
                out_file.write(buffer)

    def _unzip_model(self):
        """ Unzip the model file to the cache folder. Exits the process with status 1 if the
        archive cannot be extracted. """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
                self._write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def _write_model(self, zip_file):
        """ Extract files from zip file and write, with progress bar.

        Parameters
        ----------
        zip_file: :class:`zipfile.ZipFile`
            The opened model zip file. Closing it is left to the caller
        """
        length = sum(f.file_size for f in zip_file.infolist())
        fnames = zip_file.namelist()
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        with tqdm(desc="Decompressing",
                  unit="B",
                  total=length,
                  unit_scale=True,
                  unit_divisor=1024) as pbar:
            for fname in fnames:
                out_fname = os.path.join(self._cache_dir, fname)
                self.logger.debug("Extracting from: '%s' to '%s'", self._model_zip_path, out_fname)
                # Both handles are released as soon as the member has been copied
                with zip_file.open(fname) as zipped, open(out_fname, "wb") as out_file:
                    while True:
                        buffer = zipped.read(self._chunk_size)
                        if not buffer:
                            break
                        pbar.update(len(buffer))
                        out_file.write(buffer)
|