1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 11:53:26 -04:00
faceswap/lib/utils.py

705 lines
25 KiB
Python

#!/usr/bin/env python3
""" Utilities available across all scripts """
import json
import logging
import os
import sys
import tkinter as tk
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from re import finditer
from multiprocessing import current_process
from socket import timeout as socket_timeout, error as socket_error
from time import time
from typing import cast, List, Optional, Union, TYPE_CHECKING

import numpy as np
from tqdm import tqdm

if sys.version_info < (3, 8):
    from typing_extensions import get_args, Literal
else:
    from typing import get_args, Literal

if TYPE_CHECKING:
    from http.client import HTTPResponse
# Global variables
# Lower case file extensions (including the leading dot) that identify image files. Used as the
# default filter when listing images in a folder.
_image_extensions = [  # pylint:disable=invalid-name
    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
# Lower case file extensions (including the leading dot) that identify video files.
_video_extensions = [  # pylint:disable=invalid-name
    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
    ".ts", ".vob"]
# Cache for the installed Tensorflow major.minor version. ``None`` until it is first requested.
_TF_VERS: Optional[float] = None
# The names of the backends that Faceswap accepts.
ValidBackends = Literal["amd", "nvidia", "cpu", "apple_silicon"]
class _Backend():  # pylint:disable=too-few-public-methods
    """ Return the backend from config/.faceswap or from the `FACESWAP_BACKEND` Environment
    Variable.

    If file doesn't exist and a variable hasn't been set, create the config file.

    Attributes
    ----------
    backend: str
        The backend that was resolved when the object was created
    """
    def __init__(self) -> None:
        # Menu selection -> backend name, as presented to the user in :func:`_configure_backend`
        self._backends = {"1": "amd", "2": "cpu", "3": "nvidia", "4": "apple_silicon"}
        self._valid_backends = list(self._backends.values())
        self._config_file = self._get_config_file()
        self.backend = self._get_backend()

    @classmethod
    def _get_config_file(cls) -> str:
        """ Obtain the location of the main Faceswap configuration file.

        The location is resolved relative to the folder of the executing script (``sys.argv[0]``).

        Returns
        -------
        str
            The path to the Faceswap configuration file
        """
        pypath = os.path.dirname(os.path.realpath(sys.argv[0]))
        return os.path.join(pypath, "config", ".faceswap")

    def _get_backend(self) -> ValidBackends:
        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
        the :file:`config/.faceswap` configuration file. If neither of these exist, prompt the user
        to select a backend.

        Returns
        -------
        str
            The backend configuration in use by Faceswap

        Raises
        ------
        ValueError
            If the `FACESWAP_BACKEND` Environment Variable is set to an unsupported backend
        """
        # Check if environment variable is set, if so use that
        if "FACESWAP_BACKEND" in os.environ:
            fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
            # Explicit raise rather than assert: asserts are removed when Python runs with -O,
            # which would let an invalid backend through silently
            if fs_backend not in get_args(ValidBackends):
                raise ValueError(f"Faceswap backend must be one of {get_args(ValidBackends)}")
            print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
            return fs_backend

        # Intercept for sphinx docs build
        if sys.argv[0].endswith("sphinx-build"):
            return "nvidia"

        if not os.path.isfile(self._config_file):
            self._configure_backend()

        while True:
            try:
                with open(self._config_file, "r", encoding="utf8") as cnf:
                    config = json.load(cnf)
                break
            except json.decoder.JSONDecodeError:
                # Unreadable config: ask the user again (rewrites the file) and re-read it
                self._configure_backend()
                continue

        # Valid json that is not an object, or a non-string backend, is treated as "not set"
        stored = config.get("backend", "") if isinstance(config, dict) else ""
        fs_backend = stored.lower() if isinstance(stored, str) else ""
        if fs_backend not in self._valid_backends:
            fs_backend = self._configure_backend()

        # Only announce from the main process so spawned workers do not repeat the message
        if current_process().name == "MainProcess":
            print(f"Setting Faceswap backend to {fs_backend.upper()}")
        return cast(ValidBackends, fs_backend)

    def _configure_backend(self) -> ValidBackends:
        """ Get user input to select the backend that Faceswap should use and write the selection
        to the configuration file.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        print("First time configuration. Please select the required backend")
        while True:
            selection = input("1: AMD, 2: CPU, 3: NVIDIA, 4: APPLE SILICON: ")
            if selection in self._backends:
                break
            print(f"'{selection}' is not a valid selection. Please try again")

        fs_backend = cast(ValidBackends, self._backends[selection].lower())
        config = {"backend": fs_backend}
        # The config folder may not exist on a fresh checkout
        os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
        with open(self._config_file, "w", encoding="utf8") as cnf:
            json.dump(config, cnf)
        print(f"Faceswap config written to: {self._config_file}")
        return fs_backend
# Resolved once, when this module is first imported. Creating `_Backend` reads the environment
# variable / config file and may prompt the user for input if no backend has been configured.
_FS_BACKEND: ValidBackends = _Backend().backend
def get_backend() -> ValidBackends:
    """ Get the backend that Faceswap is currently configured to use.

    This is the value held in the module level ``_FS_BACKEND`` variable, so it reflects any
    override applied through :func:`set_backend`.

    Returns
    -------
    str
        The backend configuration in use by Faceswap
    """
    return _FS_BACKEND
def set_backend(backend: str) -> None:
    """ Override the configured backend with the given backend.

    The given name is lower-cased and stored in the module level backend variable. No validation
    is performed on the value.

    Parameters
    ----------
    backend: ["amd", "cpu", "nvidia", "apple_silicon"]
        The backend to set faceswap to
    """
    global _FS_BACKEND  # pylint:disable=global-statement
    _FS_BACKEND = cast(ValidBackends, backend.lower())
def get_tf_version() -> float:
    """ Obtain the major.minor version of currently installed Tensorflow.

    Tensorflow is imported on first call only. The result is cached in a module level variable
    and returned directly on subsequent calls.

    Notes
    -----
    The version is returned as a float, so a minor version with a trailing zero loses that zero
    (the version string "2.10" is returned as ``2.1``).

    Returns
    -------
    float
        The currently installed tensorflow version
    """
    global _TF_VERS  # pylint:disable=global-statement
    if _TF_VERS is not None:
        return _TF_VERS

    import tensorflow as tf  # pylint:disable=import-outside-toplevel
    major_minor = tf.__version__.split(".")[:2]  # pylint:disable=no-member
    _TF_VERS = float(".".join(major_minor))
    return _TF_VERS
def get_folder(path: str, make_folder: bool = True) -> str:
    """ Return a path to a folder, creating it if it doesn't exist

    Parameters
    ----------
    path: str
        The path to the folder to obtain
    make_folder: bool, optional
        ``True`` if the folder should be created if it does not already exist, ``False`` if the
        folder should not be created. Default: ``True``

    Returns
    -------
    str
        The path to the requested folder. If `make_folder` is set to ``False`` and the requested
        path does not exist, then an empty string is returned
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Requested path: '%s'", path)

    if make_folder or os.path.isdir(path):
        # No-op when the folder already exists
        os.makedirs(path, exist_ok=True)
        logger.debug("Returning: '%s'", path)
        return path

    logger.debug("%s does not exist", path)
    return ""
def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]:
    """ Obtain a list of full paths that reside within a folder.

    Filenames are matched case-insensitively against the requested extension(s). If the folder
    does not exist it is created, and an empty list is returned.

    Parameters
    ----------
    directory: str
        The folder that contains the images to be returned
    extension: str, optional
        The specific image extension that should be returned (e.g. ``".png"``). The comparison is
        case-insensitive. If ``None`` then all of the default image extensions are matched.
        Default: ``None``

    Returns
    -------
    list
        The list of full paths to the images contained within the given folder, sorted by filename
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    # Filenames are lower-cased before comparison, so the requested extension must be too
    if extension is None:
        image_extensions = tuple(_image_extensions)
    else:
        image_extensions = (extension.lower(), )

    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)

    with os.scandir(directory) as entries:
        dir_scanned = sorted(entries, key=lambda entry: entry.name)
    logger.debug("Scanned Folder contains %s files", len(dir_scanned))
    logger.trace("Scanned Folder Contents: %s", dir_scanned)  # type:ignore

    dir_contents = []
    for chkfile in dir_scanned:
        if chkfile.name.lower().endswith(image_extensions):
            logger.trace("Adding '%s' to image list", chkfile.path)  # type:ignore
            dir_contents.append(chkfile.path)

    logger.debug("Returning %s images", len(dir_contents))
    return dir_contents
def get_dpi() -> float:
    """ Obtain the DPI of the running screen.

    A new Tk root is created to query the number of pixels in one inch.

    Returns
    -------
    float
        The obtained dots per inch of the running monitor
    """
    # NOTE(review): the Tk root created here is never explicitly destroyed - confirm whether
    # callers rely on that before changing it.
    window = tk.Tk()
    return float(window.winfo_fpixels('1i'))
def convert_to_secs(*args: int) -> int:
    """ Convert a time to seconds.

    Parameters
    ----------
    args: tuple
        1, 2 or 3 ints. If 1 int is supplied then (`seconds`) is implied. If 2 ints are supplied,
        then (`minutes`, `seconds`) is implied. If 3 ints are supplied then (`hours`, `minutes`,
        `seconds`) is implied. Any other number of arguments results in ``0`` being returned.

    Returns
    -------
    int
        The given time converted to seconds
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("from time: %s", args)

    total = 0.0
    if 1 <= len(args) <= 3:
        # Left-pad the missing leading units with zero so one formula serves every case
        hours, minutes, seconds = (0, ) * (3 - len(args)) + args
        total = 3600 * float(hours) + 60 * float(minutes) + float(seconds)

    retval = int(total)
    logger.debug("to secs: %s", retval)
    return retval
def full_path_split(path: str) -> List[str]:
    """ Split a full path to a location into all of its separate components.

    Parameters
    ----------
    path: str
        The full path to be split

    Returns
    -------
    list
        The full path split into a separate item for each part. For an absolute path the first
        item is the root of the path

    Example
    -------
    >>> path = "/foo/baz/bar"
    >>> full_path_split(path)
    ['/', 'foo', 'baz', 'bar']
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    # Components are collected leaf first and reversed once the top of the path is reached
    components: List[str] = []
    remaining = path
    while True:
        head, tail = os.path.split(remaining)
        if head == remaining:  # sentinel for absolute paths
            components.append(head)
            break
        if tail == remaining:  # sentinel for relative paths
            components.append(tail)
            break
        remaining = head
        components.append(tail)
    components.reverse()
    logger.trace("path: %s, allparts: %s", remaining, components)  # type:ignore
    return components
def set_system_verbosity(log_level: str):
    """ Set the verbosity level of tensorflow and suppresses future and deprecation warnings from
    any modules

    The Tensorflow C++ log level is set to "3" when the numeric value of the requested Faceswap
    log level is greater than 15, otherwise it is set to "0". When it is set to "3" then future,
    deprecation and user warnings are also ignored.

    Parameters
    ----------
    log_level: str
        The requested Faceswap log level

    References
    ----------
    https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
    Can be set to:
    0: all logs shown. 1: filter out INFO logs. 2: filter out WARNING logs. 3: filter out ERROR
    logs.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    from lib.logger import get_loglevel  # pylint:disable=import-outside-toplevel

    tf_level = "3" if get_loglevel(log_level) > 15 else "0"
    logger.debug("System Verbosity level: %s", tf_level)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = tf_level

    if tf_level == '0':
        return
    for warncat in (FutureWarning, DeprecationWarning, UserWarning):
        warnings.simplefilter(action='ignore', category=warncat)
def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None:
    """ Log at warning level that a function will be removed in a future update.

    Parameters
    ----------
    function: str
        The function that will be deprecated.
    additional_info: str, optional
        Any additional information to display with the deprecation message. It is appended to the
        standard message, separated by a space. Default: ``None``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("func_name: %s, additional_info: %s", function, additional_info)

    notice = f"{function} has been deprecated and will be removed from a future update."
    msg = notice if additional_info is None else f"{notice} {additional_info}"
    logger.warning(msg)
def camel_case_split(identifier: str) -> List[str]:
    """ Split a camel case name

    A split is made wherever a lower case letter is followed by an upper case letter, and before
    the last upper case letter of a run of capitals that is followed by a lower case letter.

    Parameters
    ----------
    identifier: str
        The camel case text to be split

    Returns
    -------
    list
        A list of the given identifier split into its constituent parts. An empty identifier
        gives an empty list

    References
    ----------
    https://stackoverflow.com/questions/29916065
    """
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    words: List[str] = []
    for match in finditer(word_pattern, identifier):
        words.append(match.group(0))
    return words
def safe_shutdown(got_error: bool = False) -> None:
    """ Close all tracked queues and threads in event of crash or on shut down.

    The queue manager is asked to terminate its queues and the process then exits, with exit
    code ``1`` if an error was flagged, otherwise ``0``.

    Parameters
    ----------
    got_error: bool, optional
        ``True`` if this function is being called as the result of raised error, otherwise
        ``False``. Default: ``False``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")

    exit_code = 1 if got_error else 0
    sys.exit(exit_code)
class FaceswapError(Exception):
    """ Faceswap Error for handling specific errors with useful information.

    Adds no behavior of its own to :class:`Exception`. It exists so that Faceswap specific
    failures can be raised and caught as a distinct type.

    Raises
    ------
    FaceswapError
        on a captured error
    """
class GetModel():  # pylint:disable=too-few-public-methods
    """ Check for models in the cache path.

    If available, return the path, if not available, get, unzip and install model

    Parameters
    ----------
    model_filename: str or list
        The name of the model to be loaded (see notes below)
    git_model_id: int
        The model identifier used to build the github release tag
        (``v<git_model_id>.<model_version>``). See
        https://github.com/deepfakes-models/faceswap-models for more information

    Notes
    ------
    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
    (eg: `s3fd_v1.pb`).

    Multiple models can exist within the model_filename. They should be passed as a list and follow
    the same naming convention as above. Any differences in filename should occur AFTER the version
    number: `<model_name>_v<version_number><differentiating_information>.<extension>` (eg:
    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`, `["resnet_ssd_v1.caffemodel"
    ,"resnet_ssd_v1.prototext"]`
    """

    def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None:
        self.logger = logging.getLogger(__name__)
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self._model_filename = model_filename
        self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
        self._git_model_id = git_model_id
        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self._chunk_size = 1024  # Chunk size for downloading and unzipping
        self._retries = 6
        self._get()

    @property
    def _model_full_name(self) -> str:
        """ str: The full model name from the filename(s). This is the part of the filename(s)
        that is shared by every file, with any extension removed. """
        common_prefix = os.path.commonprefix(self._model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_name(self) -> str:
        """ str: The model name from the model's full name (everything before the final
        underscore). """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_version(self) -> int:
        """ int: The model's version number from the model full name. """
        # + 2 skips the underscore and the "v" that precede the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def model_path(self) -> Union[str, List[str]]:
        """ str or list: The model path(s) in the cache folder. A single path is returned as a
        string, multiple paths as a list. """
        paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
        retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_zip_path(self) -> str:
        """ str: The full path to downloaded zip file. """
        retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_exists(self) -> bool:
        """ bool: ``True`` if the model exists in the cache folder otherwise ``False``. """
        paths = self.model_path
        if isinstance(paths, str):
            paths = [paths]
        retval = all(os.path.exists(pth) for pth in paths)
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _url_download(self) -> str:
        """ str: Base download URL for models. """
        tag = f"v{self._git_model_id}.{self._model_version}"
        retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
        self.logger.trace("Download url: %s", retval)  # type: ignore
        return retval

    @property
    def _url_partial_size(self) -> int:
        """ int: How many bytes have already been downloaded. """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)  # type: ignore
        return retval

    def _get(self) -> None:
        """ Check the model exists, if not, download the model, unzip it and place it in the
        model's cache folder. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self.model_path)
            return
        # The cache folder must exist before the zip file can be written into it
        os.makedirs(self._cache_dir, exist_ok=True)
        self._download_model()
        self._unzip_model()
        os.remove(self._model_zip_path)

    def _download_model(self) -> None:
        """ Download the model zip from github to the cache folder.

        A partially downloaded zip file is resumed with an HTTP range request. The download is
        attempted up to the configured number of retries, after which the process exits.
        """
        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self._retries):
            try:
                downloaded_size = self._url_partial_size
                req = urllib.request.Request(self._url_download)
                if downloaded_size != 0:
                    req.add_header("Range", f"bytes={downloaded_size}-")
                with urllib.request.urlopen(req, timeout=10) as response:
                    self.logger.debug("header info: {%s}", response.info())
                    self.logger.debug("Return Code: %s", response.getcode())
                    self._write_zipfile(response, downloaded_size)
                break
            except urllib.error.HTTPError as err:
                # 416 (Range Not Satisfiable) on a resume request means the requested start is
                # at or past the end of the file: the zip on disk is already complete
                if downloaded_size != 0 and err.code == 416:
                    self.logger.info("Zip already exists. Skipping download")
                    break
                self._handle_download_error(err, attempt)
            except (socket_error, socket_timeout, urllib.error.URLError) as err:
                self._handle_download_error(err, attempt)

    def _handle_download_error(self, err: Exception, attempt: int) -> None:
        """ Log a failed download attempt and exit the process if no retries remain.

        Parameters
        ----------
        err: Exception
            The error that was raised by the download attempt
        attempt: int
            The zero based index of the download attempt that failed
        """
        if attempt + 1 < self._retries:
            self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                str(err), attempt + 2, self._retries)
            return
        self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                          "'%s')", str(err), self._url_download)
        self.logger.info("You can try running again to resume the download.")
        self.logger.info("Alternatively, you can manually download the model from: %s "
                         "and unzip the contents to: %s",
                         self._url_download, self._cache_dir)
        sys.exit(1)

    def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None:
        """ Write the model zip file to disk.

        Parameters
        ----------
        response: :class:`http.client.HTTPResponse`
            The response from the model download task
        downloaded_size: int
            The amount of bytes downloaded so far
        """
        if downloaded_size != 0 and response.getcode() != 206:
            # A resume was requested but the server did not reply with Partial Content, so the
            # body is the whole file. Appending it to the partial zip would corrupt the archive.
            self.logger.debug("Range request not honoured. Restarting download")
            downloaded_size = 0

        content_length = response.getheader("content-length")
        content_length = "0" if content_length is None else content_length
        length = int(content_length) + downloaded_size
        if length == downloaded_size:
            self.logger.info("Zip already exists. Skipping download")
            return

        write_type = "wb" if downloaded_size == 0 else "ab"
        with open(self._model_zip_path, write_type) as out_file:
            with tqdm(desc="Downloading",
                      unit="B",
                      total=length,
                      unit_scale=True,
                      unit_divisor=1024) as pbar:
                if downloaded_size != 0:
                    pbar.update(downloaded_size)
                while True:
                    buffer = response.read(self._chunk_size)
                    if not buffer:
                        break
                    pbar.update(len(buffer))
                    out_file.write(buffer)

    def _unzip_model(self) -> None:
        """ Unzip the model file to the cache folder. Exits the process if extraction fails. """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
                self._write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def _extract_path(self, fname: str) -> str:
        """ Obtain the location on disk that an archive member should be extracted to.

        Parameters
        ----------
        fname: str
            The name of the file within the zip archive

        Returns
        -------
        str
            The full path, within the cache folder, to write the file to

        Raises
        ------
        ValueError
            If the archive member would be written outside of the cache folder
        """
        cache_root = os.path.realpath(self._cache_dir)
        out_fname = os.path.join(self._cache_dir, fname)
        # The archive is downloaded content: refuse absolute names and "../" traversal
        if os.path.commonpath([cache_root, os.path.realpath(out_fname)]) != cache_root:
            raise ValueError(f"Unsafe path in model archive: '{fname}'")
        return out_fname

    def _write_model(self, zip_file: zipfile.ZipFile) -> None:
        """ Extract files from zip file and write, with progress bar.

        The zip file is not closed here. It remains the responsibility of the caller.

        Parameters
        ----------
        zip_file: :class:`zipfile.ZipFile`
            The downloaded model zip file
        """
        # Directory entries hold no data and cannot be opened for writing as a file
        members = [info for info in zip_file.infolist() if not info.is_dir()]
        length = sum(info.file_size for info in members)
        fnames = [info.filename for info in members]
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        with tqdm(desc="Decompressing",
                  unit="B",
                  total=length,
                  unit_scale=True,
                  unit_divisor=1024) as pbar:
            for fname in fnames:
                out_fname = self._extract_path(fname)
                self.logger.debug("Extracting from: '%s' to '%s'",
                                  self._model_zip_path, out_fname)
                os.makedirs(os.path.dirname(out_fname), exist_ok=True)
                with zip_file.open(fname) as zipped, open(out_fname, "wb") as out_file:
                    while True:
                        buffer = zipped.read(self._chunk_size)
                        if not buffer:
                            break
                        pbar.update(len(buffer))
                        out_file.write(buffer)
class DebugTimes():
    """ A simple tool to help debug timings.

    Time a named step by calling :func:`step_start` and :func:`step_end` around the code to be
    measured. Call :func:`summary` to print the count, minimum, average and maximum elapsed time
    for each recorded step.
    """
    def __init__(self) -> None:
        self._times: dict = {}  # step name -> list of elapsed times
        self._steps: dict = {}  # step name -> start time of the currently running step
        self._interval = 1  # number of summary calls counted towards the requested interval

    def step_start(self, name: str, record: bool = True) -> None:
        """ Start the timer for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to start the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`
        """
        if not record:
            return
        self._steps[name] = time()

    def step_end(self, name: str, record: bool = True) -> None:
        """ Stop the timer and record elapsed time for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to end the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`

        Raises
        ------
        KeyError
            If the timer for the given step name has not been started
        """
        if not record:
            return
        self._times.setdefault(name, []).append(time() - self._steps.pop(name))

    @classmethod
    def _format_column(cls, text: str, width: int) -> str:
        """ Pad the given text to be aligned to the given width.

        Parameters
        ----------
        text: str
            The text to be formatted
        width: int
            The size of the column to insert the text into

        Returns
        -------
        str
            The text with the correct amount of padding applied. Text that is longer than the
            column width is returned unchanged
        """
        return text.ljust(width)

    def summary(self, decimal_places: int = 6, interval: int = 1) -> None:
        """ Output a summary of step times.

        Nothing is printed if no step times have been recorded.

        Parameters
        ----------
        decimal_places: int, optional
            The number of decimal places to display the summary elapsed times to
        interval: int, optional
            How many times summary must be called before printing to console. Default: 1
        """
        interval = max(1, interval)
        # "<" rather than "!=" so that the summary still prints if a smaller interval is
        # requested after the counter has already passed it
        if self._interval < interval:
            self._interval += 1
            return
        self._interval = 1

        if not self._times:
            return

        name_col = max(len(key) for key in self._times) + 4
        items_col = 8
        time_col = decimal_places + 4
        separator = "-" * (name_col + items_col + (3 * time_col))

        print("")
        print(separator)
        print(f"{self._format_column('Step', name_col)}{self._format_column('Count', items_col)}"
              f"{self._format_column('Min', time_col)}{self._format_column('Avg', time_col)}"
              f"{self._format_column('Max', time_col)}")
        print(separator)
        for key, val in self._times.items():
            _min = f"{np.min(val):.{decimal_places}f}"
            avg = f"{np.mean(val):.{decimal_places}f}"
            _max = f"{np.max(val):.{decimal_places}f}"
            num = str(len(val))
            print(f"{self._format_column(key, name_col)}{self._format_column(num, items_col)}"
                  f"{self._format_column(_min, time_col)}{self._format_column(avg, time_col)}"
                  f"{self._format_column(_max, time_col)}")