mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-08 11:53:26 -04:00
705 lines
25 KiB
Python
705 lines
25 KiB
Python
#!/usr/bin/env python3
|
|
""" Utilities available across all scripts """
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tkinter as tk
|
|
import urllib
|
|
import warnings
|
|
import zipfile
|
|
|
|
from re import finditer
|
|
from multiprocessing import current_process
|
|
from socket import timeout as socket_timeout, error as socket_error
|
|
from time import time
|
|
from typing import cast, List, Optional, Union, TYPE_CHECKING
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
if sys.version_info < (3, 8):
|
|
from typing_extensions import get_args, Literal
|
|
else:
|
|
from typing import get_args, Literal
|
|
|
|
if TYPE_CHECKING:
|
|
from http.client import HTTPResponse
|
|
|
|
# Global variables
|
|
# Lower case file extensions (with leading dot) that are recognised as images
_image_extensions = [  # pylint:disable=invalid-name
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
]
# Lower case file extensions (with leading dot) for video files
_video_extensions = [  # pylint:disable=invalid-name
    ".avi",
    ".flv",
    ".mkv",
    ".mov",
    ".mp4",
    ".mpeg",
    ".mpg",
    ".webm",
    ".wmv",
    ".ts",
    ".vob",
]
# Cache for the installed Tensorflow version. Populated lazily by get_tf_version()
_TF_VERS = None
# The backend names that Faceswap can be configured to use
ValidBackends = Literal["amd", "nvidia", "cpu", "apple_silicon"]
|
|
|
|
|
|
class _Backend():  # pylint:disable=too-few-public-methods
    """ Return the backend from config/.faceswap or from the `FACESWAP_BACKEND` Environment
    Variable.

    If file doesn't exist and a variable hasn't been set, create the config file.

    Attributes
    ----------
    backend: str
        The backend that was resolved when the object was created
    """
    def __init__(self) -> None:
        # Menu selection -> backend name, as offered by :func:`_configure_backend`
        self._backends = {"1": "amd", "2": "cpu", "3": "nvidia", "4": "apple_silicon"}
        self._valid_backends = list(self._backends.values())
        self._config_file = self._get_config_file()
        self.backend = self._get_backend()

    @classmethod
    def _get_config_file(cls) -> str:
        """ Obtain the location of the main Faceswap configuration file.

        The file is located relative to the folder that holds the running script
        (``sys.argv[0]``).

        Returns
        -------
        str
            The path to the Faceswap configuration file
        """
        pypath = os.path.dirname(os.path.realpath(sys.argv[0]))
        return os.path.join(pypath, "config", ".faceswap")

    def _get_backend(self) -> ValidBackends:
        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
        the :file:`config/.faceswap` configuration file. If neither of these exist, prompt the user
        to select a backend.

        Returns
        -------
        str
            The backend configuration in use by Faceswap

        Raises
        ------
        ValueError
            If the `FACESWAP_BACKEND` Environment Variable is set to an unsupported backend
        """
        # Check if environment variable is set, if so use that
        if "FACESWAP_BACKEND" in os.environ:
            fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
            # Explicit raise rather than assert, so the check survives ``python -O``
            if fs_backend not in get_args(ValidBackends):
                raise ValueError(f"Faceswap backend must be one of {get_args(ValidBackends)}")
            print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
            return fs_backend
        # Intercept for sphinx docs build
        if sys.argv[0].endswith("sphinx-build"):
            return "nvidia"
        if not os.path.isfile(self._config_file):
            self._configure_backend()
        while True:
            try:
                with open(self._config_file, "r", encoding="utf8") as cnf:
                    config = json.load(cnf)
                break
            except json.decoder.JSONDecodeError:
                # Corrupt config file. Ask the user again and re-read what was written
                self._configure_backend()
                continue
        # Valid json that is not an object (or holds a non-string backend) is treated the same
        # as a missing backend, rather than crashing on ``.get``/``.lower``
        stored = config.get("backend", "") if isinstance(config, dict) else ""
        fs_backend = stored.lower() if isinstance(stored, str) else ""
        if fs_backend not in self._valid_backends:
            fs_backend = self._configure_backend()
        if current_process().name == "MainProcess":
            print(f"Setting Faceswap backend to {fs_backend.upper()}")
        return fs_backend

    def _configure_backend(self) -> ValidBackends:
        """ Get user input to select the backend that Faceswap should use and store the
        selection in the configuration file.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        print("First time configuration. Please select the required backend")
        while True:
            selection = input("1: AMD, 2: CPU, 3: NVIDIA, 4: APPLE SILICON: ")
            if selection in self._backends:
                break
            print(f"'{selection}' is not a valid selection. Please try again")
        fs_backend = cast(ValidBackends, self._backends[selection].lower())
        config = {"backend": fs_backend}
        # The config folder may not exist yet, which would make the write below fail
        os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
        with open(self._config_file, "w", encoding="utf8") as cnf:
            json.dump(config, cnf)
        print(f"Faceswap config written to: {self._config_file}")
        return fs_backend
|
|
|
|
|
|
# Resolved once, when this module is imported. Creating the _Backend object reads the
# FACESWAP_BACKEND environment variable or the config file, and prompts on stdin when the
# config file is missing, unreadable as json or holds no valid backend.
_FS_BACKEND: ValidBackends = _Backend().backend
|
|
|
|
|
|
def get_backend() -> ValidBackends:
    """ Obtain the name of the backend that Faceswap is set to run on.

    The value is read from the module level backend setting, so it reflects any override made
    with :func:`set_backend`.

    Returns
    -------
    str
        The backend configuration in use by Faceswap
    """
    return _FS_BACKEND
|
|
|
|
|
|
def set_backend(backend: str) -> None:
    """ Replace the module level backend setting with the given backend.

    The given name is lower-cased before it is stored. No validation is performed on it.

    Parameters
    ----------
    backend: ["amd", "cpu", "nvidia", "apple_silicon"]
        The backend to set faceswap to
    """
    global _FS_BACKEND  # pylint:disable=global-statement
    _FS_BACKEND = cast(ValidBackends, backend.lower())
|
|
|
|
|
|
def get_tf_version() -> float:
    """ Obtain the major.minor version of the installed Tensorflow as a float.

    Tensorflow is only imported on the first call. The result is cached at module level and
    returned directly on subsequent calls.

    Notes
    -----
    The version is parsed as a float, so a minor version with a trailing zero loses that zero
    (the version string "2.10" is returned as ``2.1``).

    Returns
    -------
    float
        The currently installed tensorflow version
    """
    global _TF_VERS  # pylint:disable=global-statement
    if _TF_VERS is not None:
        return _TF_VERS
    import tensorflow as tf  # pylint:disable=import-outside-toplevel
    major_minor = tf.__version__.split(".")[:2]  # pylint:disable=no-member
    _TF_VERS = float(".".join(major_minor))
    return _TF_VERS
|
|
|
|
|
|
def get_folder(path: str, make_folder: bool = True) -> str:
    """ Obtain the path to a folder, optionally creating the folder when it is missing.

    Parameters
    ----------
    path: str
        The path to the folder to obtain
    make_folder: bool, optional
        ``True`` to create the folder (and any missing parents) if it does not already exist,
        ``False`` to leave the file system untouched. Default: ``True``

    Returns
    -------
    str
        The path to the requested folder. If `make_folder` is ``False`` and the requested path
        is not an existing folder, then an empty string is returned
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Requested path: '%s'", path)
    if make_folder or os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
        logger.debug("Returning: '%s'", path)
        return path
    logger.debug("%s does not exist", path)
    return ""
|
|
|
|
|
|
def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]:
    """ Collect the full paths of the images that sit directly inside a folder.

    Entries are matched on their lower-cased file name ending and are returned sorted by file
    name. If the folder does not exist, it is created.

    Parameters
    ----------
    directory: str
        The folder that contains the images to be returned
    extension: str
        The specific image extension that should be returned. If ``None`` then all of the
        module's known image extensions are matched. Default: ``None``

    Returns
    -------
    list
        The list of full paths to the images contained within the given folder
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    wanted = tuple(_image_extensions if extension is None else [extension])
    dir_contents = []

    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)

    dir_scanned = sorted(os.scandir(directory), key=lambda entry: entry.name)
    logger.debug("Scanned Folder contains %s files", len(dir_scanned))
    logger.trace("Scanned Folder Contents: %s", dir_scanned)  # type:ignore

    for entry in dir_scanned:
        if not entry.name.lower().endswith(wanted):
            continue
        logger.trace("Adding '%s' to image list", entry.path)  # type:ignore
        dir_contents.append(entry.path)

    logger.debug("Returning %s images", len(dir_contents))
    return dir_contents
|
|
|
|
|
|
def get_dpi() -> float:
    """ Obtain the DPI of the running screen.

    A temporary Tk root is created to query the display and is destroyed again before
    returning, so no Tk window is left behind by this call.

    Returns
    -------
    float
        The obtained dots per inch of the running monitor
    """
    root = tk.Tk()
    try:
        # '1i' is the Tk screen distance for 1 inch, so this is the number of pixels per inch
        dpi = root.winfo_fpixels('1i')
    finally:
        root.destroy()
    return float(dpi)
|
|
|
|
|
|
def convert_to_secs(*args: int) -> int:
    """ Convert a time, given as separate units, to a whole number of seconds.

    Parameters
    ----------
    args: tuple
        1, 2 or 3 values. 1 value is taken as (`seconds`), 2 values as (`minutes`, `seconds`)
        and 3 values as (`hours`, `minutes`, `seconds`). Any other number of values gives a
        result of 0.

    Returns
    -------
    int
        The given time converted to seconds, truncated to an integer
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("from time: %s", args)
    multipliers = (3600, 60, 1)
    total = 0.0
    if 1 <= len(args) <= 3:
        # Line the given values up against the smallest units: (h, m, s) -> (m, s) -> (s)
        for weight, value in zip(multipliers[-len(args):], args):
            total += weight * float(value)
    seconds = int(total)
    logger.debug("to secs: %s", seconds)
    return seconds
|
|
|
|
|
|
def full_path_split(path: str) -> List[str]:
    """ Break a path down into a list of its individual components.

    For an absolute path the root is kept as the first component.

    Parameters
    ----------
    path: str
        The full path to be split

    Returns
    -------
    list
        The full path split into a separate item for each part

    Example
    -------
    >>> full_path_split("/foo/baz/bar")
    ['/', 'foo', 'baz', 'bar']
    >>> full_path_split("foo/baz/bar")
    ['foo', 'baz', 'bar']
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    # Components are peeled off from the right, so they are collected in reverse order
    collected: List[str] = []
    while True:
        head, tail = os.path.split(path)
        if head == path:  # nothing more can be removed: root of an absolute path
            collected.append(head)
            break
        if tail == path:  # nothing more can be removed: first part of a relative path
            collected.append(tail)
            break
        path = head
        collected.append(tail)
    allparts = collected[::-1]
    logger.trace("path: %s, allparts: %s", path, allparts)  # type:ignore
    return allparts
|
|
|
|
|
|
def set_system_verbosity(log_level: str):
    """ Configure how noisy Tensorflow's native logging and Python warnings are.

    The Faceswap log level is converted to a number. Anything above 15 sets
    ``TF_CPP_MIN_LOG_LEVEL`` to "3" and silences Future, Deprecation and User warnings from any
    module. Otherwise ``TF_CPP_MIN_LOG_LEVEL`` is set to "0" and warnings are left alone.

    Parameters
    ----------
    log_level: str
        The requested Faceswap log level

    References
    ----------
    https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
    Can be set to:
    0: all logs shown. 1: filter out INFO logs. 2: filter out WARNING logs. 3: filter out ERROR
    logs.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    from lib.logger import get_loglevel  # pylint:disable=import-outside-toplevel
    tf_level = "3" if get_loglevel(log_level) > 15 else "0"
    logger.debug("System Verbosity level: %s", tf_level)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = tf_level
    if tf_level == '0':
        return
    for warncat in (FutureWarning, DeprecationWarning, UserWarning):
        warnings.simplefilter(action='ignore', category=warncat)
|
|
|
|
|
|
def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None:
    """ Emit a warning level log message stating that a function is deprecated.

    Parameters
    ----------
    function: str
        The name of the function that is deprecated
    additional_info: str, optional
        Extra text to append to the deprecation message, separated by a space.
        Default: ``None``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("func_name: %s, additional_info: %s", function, additional_info)
    parts = [f"{function} has been deprecated and will be removed from a future update."]
    if additional_info is not None:
        parts.append(f"{additional_info}")
    logger.warning(" ".join(parts))
|
|
|
|
|
|
def camel_case_split(identifier: str) -> List[str]:
    """ Break a camel case name up into its individual words.

    A split is made where a lower case letter is followed by an upper case letter, and before
    the last capital of a run of capitals that is followed by a lower case letter (so
    "XMLHttp" becomes "XML", "Http").

    Parameters
    ----------
    identifier: str
        The camel case text to be split

    Returns
    -------
    list
        A list of the given identifier split into its constituent parts

    References
    ----------
    https://stackoverflow.com/questions/29916065
    """
    pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    words: List[str] = []
    for match in finditer(pattern, identifier):
        words.append(match.group(0))
    return words
|
|
|
|
|
|
def safe_shutdown(got_error: bool = False) -> None:
    """ Terminate the queues held by the queue manager and exit the process.

    Parameters
    ----------
    got_error: bool, optional
        ``True`` if this function is being called as the result of raised error, in which case
        the process exits with code 1. ``False`` to exit with code 0. Default: ``False``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
    exit_code = 0
    if got_error:
        exit_code = 1
    sys.exit(exit_code)
|
|
|
|
|
|
class FaceswapError(Exception):
    """ Exception type for Faceswap specific errors that carry a useful message.

    Adds nothing to :class:`Exception` other than a distinct type to raise and catch.

    Raises
    ------
    FaceswapError
        on a captured error
    """
|
|
|
|
|
|
class GetModel():  # pylint:disable=too-few-public-methods
    """ Check for models in the cache path.

    If available, return the path, if not available, get, unzip and install model

    The check (and any download) is performed when the object is created. The location of the
    model file(s) is then available from :attr:`model_path`.

    Parameters
    ----------
    model_filename: str or list
        The name of the model to be loaded (see notes below)
    git_model_id: int
        The model identifier that forms the first number of the github release tag
        (``v<git_model_id>.<model_version>``). See
        https://github.com/deepfakes-models/faceswap-models for more information

    Notes
    ------
    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
    (eg: `s3fd_v1.pb`).

    Multiple models can exist within the model_filename. They should be passed as a list and follow
    the same naming convention as above. Any differences in filename should occur AFTER the version
    number: `<model_name>_v<version_number><differentiating_information>.<extension>` (eg:
    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`, `["resnet_ssd_v1.caffemodel"
    ,"resnet_ssd_v1.prototext"]`
    """

    def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None:
        self.logger = logging.getLogger(__name__)
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self._model_filename = model_filename
        # The cache lives next to the running script
        self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
        self._git_model_id = git_model_id
        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self._chunk_size = 1024  # Chunk size for downloading and unzipping
        self._retries = 6
        self._get()

    @property
    def _model_full_name(self) -> str:
        """ str: The full model name from the filename(s). This is the part that all of the
        filenames have in common, with any extension removed. """
        common_prefix = os.path.commonprefix(self._model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_name(self) -> str:
        """ str: The model name from the model's full name (everything before the final
        underscore). """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_version(self) -> int:
        """ int: The model's version number from the model full name. """
        # + 2 skips over the "_v" that precedes the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def model_path(self) -> Union[str, List[str]]:
        """ str or list: The model path(s) in the cache folder. A single string when there is
        one model file, otherwise a list with one path per model file. """
        paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
        retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_zip_path(self) -> str:
        """ str: The full path to downloaded zip file. """
        retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _model_exists(self) -> bool:
        """ bool: ``True`` if the model exists in the cache folder otherwise ``False``. For
        multiple model files, every file must exist. """
        paths = self.model_path
        if not isinstance(paths, list):
            paths = [paths]
        retval = all(os.path.exists(pth) for pth in paths)
        self.logger.trace(retval)  # type: ignore
        return retval

    @property
    def _url_download(self) -> str:
        """ str: Base download URL for models. """
        tag = f"v{self._git_model_id}.{self._model_version}"
        retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
        self.logger.trace("Download url: %s", retval)  # type: ignore
        return retval

    @property
    def _url_partial_size(self) -> int:
        """ int: How many bytes have already been downloaded. """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)  # type: ignore
        return retval

    def _get(self) -> None:
        """ Check the model exists, if not, download the model, unzip it and place it in the
        model's cache folder. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self.model_path)
            return
        # The zip and the extracted files are written here, so the folder has to exist
        os.makedirs(self._cache_dir, exist_ok=True)
        self._download_model()
        self._unzip_model()
        os.remove(self._model_zip_path)

    def _download_model(self) -> None:
        """ Download the model zip from github to the cache folder.

        A partially downloaded zip is resumed with a ``Range`` request. The download is
        attempted up to the configured number of retries, after which the process exits with
        exit code 1.
        """
        # ``import urllib`` on its own does not load the submodules that are used below
        import urllib.error  # pylint:disable=import-outside-toplevel
        import urllib.request  # pylint:disable=import-outside-toplevel

        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self._retries):
            try:
                downloaded_size = self._url_partial_size
                req = urllib.request.Request(self._url_download)
                if downloaded_size != 0:
                    req.add_header("Range", f"bytes={downloaded_size}-")
                with urllib.request.urlopen(req, timeout=10) as response:
                    self.logger.debug("header info: {%s}", response.info())
                    self.logger.debug("Return Code: %s", response.getcode())
                    self._write_zipfile(response, downloaded_size)
                break
            except (socket_error, socket_timeout,
                    urllib.error.HTTPError, urllib.error.URLError) as err:
                if attempt + 1 < self._retries:
                    self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                        str(err), attempt + 2, self._retries)
                    continue
                self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                                  "'%s')", str(err), self._url_download)
                self.logger.info("You can try running again to resume the download.")
                self.logger.info("Alternatively, you can manually download the model from: %s "
                                 "and unzip the contents to: %s",
                                 self._url_download, self._cache_dir)
                sys.exit(1)

    def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None:
        """ Write the model zip file to disk.

        Parameters
        ----------
        response: :class:`http.client.HTTPResponse`
            The response from the model download task
        downloaded_size: int
            The amount of bytes downloaded so far
        """
        content_length = response.getheader("content-length")
        content_length = "0" if content_length is None else content_length
        length = int(content_length) + downloaded_size
        if length == downloaded_size:
            # Nothing further to receive
            self.logger.info("Zip already exists. Skipping download")
            return
        # Append to a partial download, otherwise start a new file
        write_type = "wb" if downloaded_size == 0 else "ab"
        with open(self._model_zip_path, write_type) as out_file:
            pbar = tqdm(desc="Downloading",
                        unit="B",
                        total=length,
                        unit_scale=True,
                        unit_divisor=1024)
            try:
                if downloaded_size != 0:
                    pbar.update(downloaded_size)
                while True:
                    buffer = response.read(self._chunk_size)
                    if not buffer:
                        break
                    pbar.update(len(buffer))
                    out_file.write(buffer)
            finally:
                # Release the progress bar even when the connection drops mid-download
                pbar.close()

    def _unzip_model(self) -> None:
        """ Unzip the model file to the cache folder. Exits the process with exit code 1 if the
        zip cannot be extracted. """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
                self._write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def _write_model(self, zip_file: zipfile.ZipFile) -> None:
        """ Extract files from zip file and write, with progress bar.

        The zip file is closed once all of its members have been written.

        Parameters
        ----------
        zip_file: :class:`zipfile.ZipFile`
            The downloaded model zip file
        """
        length = sum(f.file_size for f in zip_file.infolist())
        fnames = zip_file.namelist()
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        pbar = tqdm(desc="Decompressing",
                    unit="B",
                    total=length,
                    unit_scale=True,
                    unit_divisor=1024)
        try:
            for fname in fnames:
                out_fname = os.path.join(self._cache_dir, fname)
                self.logger.debug("Extracting from: '%s' to '%s'",
                                  self._model_zip_path, out_fname)
                zipped = zip_file.open(fname)
                try:
                    with open(out_fname, "wb") as out_file:
                        while True:
                            buffer = zipped.read(self._chunk_size)
                            if not buffer:
                                break
                            pbar.update(len(buffer))
                            out_file.write(buffer)
                finally:
                    # Each member gets its own handle, which must be released
                    zipped.close()
            zip_file.close()
        finally:
            pbar.close()
|
|
|
|
|
|
class DebugTimes():
    """ A simple tool to help debug timings.

    Wrap the code to be timed between :func:`step_start` and :func:`step_end` calls that share
    the same step name, then call :func:`summary` to print the minimum, average and maximum
    elapsed time for each step.

    Example
    -------
    >>> timer = DebugTimes()
    >>> timer.step_start("load")
    >>> timer.step_end("load")
    >>> timer.summary()
    """
    def __init__(self):
        self._times = {}  # step name -> list of recorded elapsed times in seconds
        self._steps = {}  # step name -> start time of a step that is currently running
        self._interval = 1  # number of summary calls made since the last print (1 based)

    def step_start(self, name: str, record: bool = True) -> None:
        """ Start the timer for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to start the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`
        """
        if not record:
            return
        self._steps[name] = time()

    def step_end(self, name: str, record: bool = True) -> None:
        """ Stop the timer and record elapsed time for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to end the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`

        Raises
        ------
        KeyError
            If the step is being recorded but was never started with :func:`step_start`
        """
        if not record:
            return
        self._times.setdefault(name, []).append(time() - self._steps.pop(name))

    @classmethod
    def _format_column(cls, text: str, width: int) -> str:
        """ Pad the given text to be aligned to the given width.

        Parameters
        ----------
        text: str
            The text to be formatted
        width: int
            The size of the column to insert the text into

        Returns
        -------
        str
            The text with the correct amount of padding applied. Text that is already wider than
            the column is returned unchanged
        """
        return f"{text}{' ' * (width - len(text))}"

    def summary(self, decimal_places: int = 6, interval: int = 1) -> None:
        """ Output a summary of step times.

        Nothing is printed if no step times have been recorded yet.

        Parameters
        ----------
        decimal_places: int, optional
            The number of decimal places to display the summary elapsed times to
        interval: int, optional
            How many times summary must be called before printing to console. Default: 1
        """
        interval = max(1, interval)
        # Compare with "<" rather than "!=", so that passing a smaller interval than on earlier
        # calls still triggers output instead of counting upwards forever
        if self._interval < interval:
            self._interval += 1
            return
        self._interval = 1

        if not self._times:
            # Nothing recorded. Avoids calling max() on an empty sequence below
            return

        name_col = max(len(key) for key in self._times) + 4
        items_col = 8
        time_col = decimal_places + 4
        separator = "-" * (name_col + items_col + (3 * time_col))
        print("")
        print(separator)
        print(f"{self._format_column('Step', name_col)}{self._format_column('Count', items_col)}"
              f"{self._format_column('Min', time_col)}{self._format_column('Avg', time_col)}"
              f"{self._format_column('Max', time_col)}")
        print(separator)
        for key, val in self._times.items():
            _min = f"{np.min(val):.{decimal_places}f}"
            avg = f"{np.mean(val):.{decimal_places}f}"
            _max = f"{np.max(val):.{decimal_places}f}"
            num = str(len(val))
            print(f"{self._format_column(key, name_col)}{self._format_column(num, items_col)}"
                  f"{self._format_column(_min, time_col)}{self._format_column(avg, time_col)}"
                  f"{self._format_column(_max, time_col)}")
|