1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/lib/utils.py
torzdf 6a3b674bef
Rebase code (#1326)
* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* Setup.py - stdout to utf-8

* Add more OSes to github Actions

* suppress mac-os end to end test
2023-06-27 11:27:47 +01:00

887 lines
31 KiB
Python

#!/usr/bin/env python3
""" Utilities available across all scripts """
from __future__ import annotations
import json
import logging
import os
import sys
import tkinter as tk
import typing as T
import warnings
import zipfile
from multiprocessing import current_process
from re import finditer
from socket import timeout as socket_timeout, error as socket_error
from threading import get_ident
from time import time
from urllib import request, error as urlliberror
import numpy as np
from tqdm import tqdm
if T.TYPE_CHECKING:
from http.client import HTTPResponse
# Global variables
# File extensions treated as still images (see get_image_paths below)
_image_extensions = [ # pylint:disable=invalid-name
    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
# File extensions treated as video files
_video_extensions = [ # pylint:disable=invalid-name
    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
    ".ts", ".vob"]
# Cached (major, minor) Tensorflow version. Populated lazily by get_tf_version()
_TF_VERS: tuple[int, int] | None = None
# The set of backends that Faceswap can be configured to run on
ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
class _Backend(): # pylint:disable=too-few-public-methods
    """ Return the backend from config/.faceswap or from the `FACESWAP_BACKEND` Environment
    Variable.

    If the file doesn't exist and the variable hasn't been set, prompt the user and create the
    config file. """
    def __init__(self) -> None:
        # Menu key to backend name mapping used when prompting the user for a selection
        self._backends: dict[str, ValidBackends] = {"1": "cpu",
                                                    "2": "directml",
                                                    "3": "nvidia",
                                                    "4": "apple_silicon",
                                                    "5": "rocm"}
        self._valid_backends = list(self._backends.values())
        self._config_file = self._get_config_file()
        # Resolved backend for this session
        self.backend = self._get_backend()

    @classmethod
    def _get_config_file(cls) -> str:
        """ Obtain the location of the main Faceswap configuration file.

        Returns
        -------
        str
            The path to the Faceswap configuration file
        """
        # Anchor on the entry-point script's folder, not the current working directory
        pypath = os.path.dirname(os.path.realpath(sys.argv[0]))
        config_file = os.path.join(pypath, "config", ".faceswap")
        return config_file

    def _get_backend(self) -> ValidBackends:
        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
        the :file:`config/.faceswap` configuration file. If neither of these exist, prompt the user
        to select a backend.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        # Check if environment variable is set, if so use that
        if "FACESWAP_BACKEND" in os.environ:
            fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
            assert fs_backend in T.get_args(ValidBackends), (
                f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
            print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
            return fs_backend
        # Intercept for sphinx docs build
        if sys.argv[0].endswith("sphinx-build"):
            return "nvidia"
        # No config file yet: prompt the user to create one
        if not os.path.isfile(self._config_file):
            self._configure_backend()
        # Loop so a corrupt/unparseable config file re-prompts rather than crashing
        while True:
            try:
                with open(self._config_file, "r", encoding="utf8") as cnf:
                    config = json.load(cnf)
                break
            except json.decoder.JSONDecodeError:
                self._configure_backend()
                continue
        fs_backend = config.get("backend", "").lower()
        # Re-prompt if the stored value is missing or not a recognized backend
        if not fs_backend or fs_backend not in self._backends.values():
            fs_backend = self._configure_backend()
        # Only announce from the main process to avoid duplicate output from workers
        if current_process().name == "MainProcess":
            print(f"Setting Faceswap backend to {fs_backend.upper()}")
        return fs_backend

    def _configure_backend(self) -> ValidBackends:
        """ Get user input to select the backend that Faceswap should use.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        print("First time configuration. Please select the required backend")
        # Keep prompting until a valid menu key is entered
        while True:
            txt = ", ".join([": ".join([key, val.upper().replace("_", " ")])
                             for key, val in self._backends.items()])
            selection = input(f"{txt}: ")
            if selection not in self._backends:
                print(f"'{selection}' is not a valid selection. Please try again")
                continue
            break
        fs_backend = self._backends[selection]
        # Persist the choice so subsequent runs do not prompt again
        config = {"backend": fs_backend}
        with open(self._config_file, "w", encoding="utf8") as cnf:
            json.dump(config, cnf)
        print(f"Faceswap config written to: {self._config_file}")
        return fs_backend
# Resolved once at import time. Read via get_backend() and override via set_backend()
_FS_BACKEND: ValidBackends = _Backend().backend
def get_backend() -> ValidBackends:
    """ Obtain the backend that Faceswap is currently configured to use.

    Returns
    -------
    str
        The active Faceswap backend. One of ["cpu", "directml", "nvidia", "rocm",
        "apple_silicon"]

    Example
    -------
    >>> from lib.utils import get_backend
    >>> get_backend()
    'nvidia'
    """
    # The backend is resolved once at module import; simply expose it here
    backend = _FS_BACKEND
    return backend
def set_backend(backend: str) -> None:
    """ Override the configured backend with the given backend.

    Parameters
    ----------
    backend: ["cpu", "directml", "nvidia", "rocm", "apple_silicon"]
        The backend to set faceswap to. Case-insensitive

    Raises
    ------
    ValueError
        If the given backend is not a valid Faceswap backend

    Example
    -------
    >>> from lib.utils import set_backend
    >>> set_backend("nvidia")
    """
    global _FS_BACKEND  # pylint:disable=global-statement
    backend = backend.lower()
    # Validate here rather than failing obscurely downstream. This mirrors the check applied
    # to the FACESWAP_BACKEND environment variable in _Backend._get_backend
    if backend not in T.get_args(ValidBackends):
        raise ValueError(f"Faceswap backend must be one of {T.get_args(ValidBackends)}, "
                         f"got '{backend}'")
    _FS_BACKEND = T.cast(ValidBackends, backend)
def get_tf_version() -> tuple[int, int]:
    """ Obtain the major. minor version of currently installed Tensorflow.

    Returns
    -------
    tuple[int, int]
        A tuple of the form (major, minor) representing the version of TensorFlow that is installed

    Example
    -------
    >>> from lib.utils import get_tf_version
    >>> get_tf_version()
    (2, 10)
    """
    global _TF_VERS  # pylint:disable=global-statement
    # Import lazily and cache so Tensorflow is only loaded (and parsed) once
    if _TF_VERS is None:
        import tensorflow as tf  # pylint:disable=import-outside-toplevel
        parts = tf.__version__.split(".")
        _TF_VERS = (int(parts[0]), int(parts[1]))
    return _TF_VERS
def get_folder(path: str, make_folder: bool = True) -> str:
    """ Return a path to a folder, creating it if it doesn't exist

    Parameters
    ----------
    path: str
        The path to the folder to obtain
    make_folder: bool, optional
        ``True`` if the folder should be created if it does not already exist, ``False`` if the
        folder should not be created

    Returns
    -------
    str
        The path to the requested folder. If `make_folder` is set to ``False`` and the requested
        path does not exist, then an empty string is returned

    Example
    -------
    >>> from lib.utils import get_folder
    >>> get_folder('/tmp/myfolder')
    '/tmp/myfolder'
    >>> get_folder('/tmp/does_not_exist', make_folder=False)
    ''
    """
    logger = logging.getLogger(__name__)
    logger.debug("Requested path: '%s'", path)
    if not make_folder and not os.path.isdir(path):
        logger.debug("%s does not exist", path)
        return ""
    # exist_ok avoids a race between the isdir check and creation
    os.makedirs(path, exist_ok=True)
    logger.debug("Returning: '%s'", path)
    return path
def get_image_paths(directory: str, extension: str | None = None) -> list[str]:
    """ Gets the image paths from a given directory.

    Searches the given directory for files matching the requested extension(s) and returns the
    full path for each match. When no extension is given, all of the known image extensions
    ('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff') are matched.

    Parameters
    ----------
    directory: str
        The directory to search in
    extension: str
        The file extension to search for. If not provided, all image file types will be searched
        for

    Returns
    -------
    list[str]
        The list of full paths to the images contained within the given folder

    Example
    -------
    >>> from lib.utils import get_image_paths
    >>> get_image_paths('/path/to/directory')
    ['/path/to/directory/image1.jpg', '/path/to/directory/image2.png']
    >>> get_image_paths('/path/to/directory', '.jpg')
    ['/path/to/directory/image1.jpg']
    """
    logger = logging.getLogger(__name__)
    valid_exts = _image_extensions if extension is None else [extension]
    # Create the folder if it does not exist so scandir cannot fail
    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)
    scanned = sorted(os.scandir(directory), key=lambda entry: entry.name)
    logger.debug("Scanned Folder contains %s files", len(scanned))
    logger.trace("Scanned Folder Contents: %s", scanned)  # type:ignore[attr-defined]
    retval = []
    for entry in scanned:
        if not any(entry.name.lower().endswith(ext) for ext in valid_exts):
            continue
        logger.trace("Adding '%s' to image list", entry.path)  # type:ignore[attr-defined]
        retval.append(entry.path)
    logger.debug("Returning %s images", len(retval))
    return retval
def get_dpi() -> float | None:
""" Gets the DPI (dots per inch) of the display screen.
Returns
-------
float or ``None``
The DPI of the display screen or ``None`` if the dpi couldn't be obtained (ie: if the
function is called on a headless system)
Example
-------
>>> from lib.utils import get_dpi
>>> get_dpi()
96.0
"""
logger = logging.getLogger(__name__)
try:
root = tk.Tk()
dpi = root.winfo_fpixels('1i')
except tk.TclError:
logger.warning("Display not detected. Could not obtain DPI")
return None
return float(dpi)
def convert_to_secs(*args: int) -> int:
    """ Convert time in hours, minutes, and seconds to seconds.

    Parameters
    ----------
    *args: int
        1, 2 or 3 ints. If 2 ints are supplied, then (`minutes`, `seconds`) is implied. If 3 ints
        are supplied then (`hours`, `minutes`, `seconds`) is implied.

    Returns
    -------
    int
        The given time converted to seconds

    Example
    -------
    >>> from lib.utils import convert_to_secs
    >>> convert_to_secs(1, 30, 0)
    5400
    >>> convert_to_secs(0, 15, 30)
    930
    >>> convert_to_secs(0, 0, 45)
    45
    """
    logger = logging.getLogger(__name__)
    logger.debug("from time: %s", args)
    # Weight each positional value right-aligned against (hours, minutes, seconds)
    weights = (3600, 60, 1)
    if 1 <= len(args) <= 3:
        total = sum(weight * float(value)
                    for weight, value in zip(weights[-len(args):], args))
    else:
        total = 0.0
    retval = int(total)
    logger.debug("to secs: %s", retval)
    return retval
def full_path_split(path: str) -> list[str]:
    """ Split a file path into all of its parts.

    Parameters
    ----------
    path: str
        The full path to be split

    Returns
    -------
    list
        The full path split into a separate item for each part. For absolute paths the root
        separator is retained as the first item

    Example
    -------
    >>> from lib.utils import full_path_split
    >>> full_path_split("/usr/local/bin/python")
    ['/', 'usr', 'local', 'bin', 'python']
    >>> full_path_split("relative/path/to/file.txt")
    ['relative', 'path', 'to', 'file.txt']
    """
    logger = logging.getLogger(__name__)
    allparts: list[str] = []
    while True:
        parts = os.path.split(path)
        if parts[0] == path:  # sentinel for absolute paths
            allparts.insert(0, parts[0])
            break
        if parts[1] == path:  # sentinel for relative paths
            allparts.insert(0, parts[1])
            break
        path = parts[0]
        allparts.insert(0, parts[1])
    logger.trace("path: %s, allparts: %s", path, allparts)  # type:ignore[attr-defined]
    # Remove any empty strings which may have got inserted (eg: from a trailing separator)
    allparts = [part for part in allparts if part]
    return allparts
def set_system_verbosity(log_level: str):
    """ Set the verbosity level of tensorflow and suppresses future and deprecation warnings from
    any modules.

    Sets the `TF_CPP_MIN_LOG_LEVEL` environment variable to control the verbosity of TensorFlow
    output and, when not running at verbose levels, filters Future/Deprecation/User warnings. The
    level is derived from the given Faceswap log level.

    Parameters
    ----------
    log_level: str
        The requested Faceswap log level.

    References
    ----------
    https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information

    Example
    -------
    >>> from lib.utils import set_system_verbosity
    >>> set_system_verbosity('warning')
    """
    logger = logging.getLogger(__name__)
    from lib.logger import get_loglevel  # pylint:disable=import-outside-toplevel
    numeric_level = get_loglevel(log_level)
    # Verbose (<= 15) keeps full TF output; anything quieter suppresses TF chatter
    log_level = "0" if numeric_level <= 15 else "3"
    logger.debug("System Verbosity level: %s", log_level)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = log_level
    if log_level == '0':
        return
    for warncat in (FutureWarning, DeprecationWarning, UserWarning):
        warnings.simplefilter(action='ignore', category=warncat)
def deprecation_warning(function: str, additional_info: str | None = None) -> None:
""" Log a deprecation warning message.
This function logs a warning message to indicate that the specified function has been
deprecated and will be removed in future. An optional additional message can also be included.
Parameters
----------
function: str
The name of the function that will be deprecated.
additional_info: str, optional
Any additional information to display with the deprecation message. Default: ``None``
Example
-------
>>> from lib.utils import deprecation_warning
>>> deprecation_warning('old_function', 'Use new_function instead.')
"""
logger = logging.getLogger(__name__)
logger.debug("func_name: %s, additional_info: %s", function, additional_info)
msg = f"{function} has been deprecated and will be removed from a future update."
if additional_info is not None:
msg += f" {additional_info}"
logger.warning(msg)
def camel_case_split(identifier: str) -> list[str]:
    """ Split a camelCase string into a list of its individual parts

    Parameters
    ----------
    identifier: str
        The camelCase text to be split

    Returns
    -------
    list[str]
        A list of the individual parts of the camelCase string.

    References
    ----------
    https://stackoverflow.com/questions/29916065

    Example
    -------
    >>> from lib.utils import camel_case_split
    >>> camel_case_split('camelCaseExample')
    ['camel', 'Case', 'Example']
    """
    # Match up to each lower->upper boundary, each acronym boundary, or end of string
    pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    return [match[0] for match in finditer(pattern, identifier)]
def safe_shutdown(got_error: bool = False) -> None:
    """ Safely shut down the system.

    Terminates the queue manager's queues then exits the process with an exit code reflecting
    whether an error occurred.

    Parameters
    ----------
    got_error: bool, optional
        ``True`` if this function is being called as the result of raised error. Default: ``False``

    Example
    -------
    >>> from lib.utils import safe_shutdown
    >>> safe_shutdown()
    >>> safe_shutdown(True)
    """
    logger = logging.getLogger(__name__)
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
    exit_code = 1 if got_error else 0
    sys.exit(exit_code)
class FaceswapError(Exception):
    """ Faceswap Error for handling specific errors with useful information.

    Raises
    ------
    FaceswapError
        on a captured error

    Example
    -------
    >>> from lib.utils import FaceswapError
    >>> try:
    ...     # Some code that may raise an error
    ... except SomeError:
    ...     raise FaceswapError("There was an error while running the code")
    FaceswapError: There was an error while running the code
    """
class GetModel(): # pylint:disable=too-few-public-methods
    """ Check for models in the cache path.

    If available, return the path, if not available, get, unzip and install model

    Parameters
    ----------
    model_filename: str or list
        The name of the model to be loaded (see notes below)
    git_model_id: int
        The second digit in the github tag that identifies this model. See
        https://github.com/deepfakes-models/faceswap-models for more information

    Notes
    ------
    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
    (eg: `s3fd_v1.pb`).

    Multiple models can exist within the model_filename. They should be passed as a list and
    follow the same naming convention as above. Any differences in filename should occur AFTER
    the version number:
    `<model_name>_v<version_number><differentiating_information>.<extension>` (eg:
    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`,
    `["resnet_ssd_v1.caffemodel", "resnet_ssd_v1.prototext"]`)

    Example
    -------
    >>> from lib.utils import GetModel
    >>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
    """
    def __init__(self, model_filename: str | list[str], git_model_id: int) -> None:
        self.logger = logging.getLogger(__name__)
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self._model_filename = model_filename
        # Models are cached beside the entry-point script in a hidden folder
        self._cache_dir = os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), ".fs_cache")
        self._git_model_id = git_model_id
        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self._chunk_size = 1024  # Chunk size for downloading and unzipping
        self._retries = 6
        self._get()

    @property
    def _model_full_name(self) -> str:
        """ str: The full model name from the filename(s). """
        common_prefix = os.path.commonprefix(self._model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _model_name(self) -> str:
        """ str: The model name from the model's full name. """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _model_version(self) -> int:
        """ int: The model's version number from the model full name. """
        # + 2 skips over the "_v" that precedes the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def model_path(self) -> str | list[str]:
        """ str or list[str]: The model path(s) in the cache folder.

        Example
        -------
        >>> from lib.utils import GetModel
        >>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
        >>> model_downloader.model_path
        '/path/to/s3fd_keras_v2.h5'
        """
        paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
        # A single filename is returned as a bare string for convenience
        retval: str | list[str] = paths[0] if len(paths) == 1 else paths
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _model_zip_path(self) -> str:
        """ str: The full path to downloaded zip file. """
        retval = os.path.join(self._cache_dir, f"{self._model_full_name}.zip")
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _model_exists(self) -> bool:
        """ bool: ``True`` if the model exists in the cache folder otherwise ``False``. """
        if isinstance(self.model_path, list):
            retval = all(os.path.exists(pth) for pth in self.model_path)
        else:
            retval = os.path.exists(self.model_path)
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _url_download(self) -> str:
        """ str: Base download URL for models. """
        tag = f"v{self._git_model_id}.{self._model_version}"
        retval = f"{self._url_base}/{tag}/{self._model_full_name}.zip"
        self.logger.trace("Download url: %s", retval)  # type:ignore[attr-defined]
        return retval

    @property
    def _url_partial_size(self) -> int:
        """ int: How many bytes have already been downloaded. """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)  # type:ignore[attr-defined]
        return retval

    def _get(self) -> None:
        """ Check the model exists, if not, download the model, unzip it and place it in the
        model's cache folder. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self.model_path)
            return
        # Ensure the cache folder exists before attempting to write the downloaded zip
        os.makedirs(self._cache_dir, exist_ok=True)
        self._download_model()
        self._unzip_model()
        os.remove(self._model_zip_path)

    def _download_model(self) -> None:
        """ Download the model zip from github to the cache folder, resuming a partial download
        if one exists. Exits the process if all retries are exhausted. """
        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self._retries):
            try:
                downloaded_size = self._url_partial_size
                req = request.Request(self._url_download)
                if downloaded_size != 0:
                    # Resume a previously interrupted download from where it left off
                    req.add_header("Range", f"bytes={downloaded_size}-")
                with request.urlopen(req, timeout=10) as response:
                    self.logger.debug("header info: {%s}", response.info())
                    self.logger.debug("Return Code: %s", response.getcode())
                    self._write_zipfile(response, downloaded_size)
                break
            except (socket_error, socket_timeout,
                    urlliberror.HTTPError, urlliberror.URLError) as err:
                if attempt + 1 < self._retries:
                    self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                        str(err), attempt + 2, self._retries)
                else:
                    self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                                      "'%s')", str(err), self._url_download)
                    self.logger.info("You can try running again to resume the download.")
                    self.logger.info("Alternatively, you can manually download the model from: %s "
                                     "and unzip the contents to: %s",
                                     self._url_download, self._cache_dir)
                    sys.exit(1)

    def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None:
        """ Write the model zip file to disk.

        Parameters
        ----------
        response: :class:`http.client.HTTPResponse`
            The response from the model download task
        downloaded_size: int
            The amount of bytes downloaded so far
        """
        content_length = response.getheader("content-length")
        content_length = "0" if content_length is None else content_length
        length = int(content_length) + downloaded_size
        if length == downloaded_size:
            # Server reports no further bytes to fetch: the zip is already complete
            self.logger.info("Zip already exists. Skipping download")
            return
        # Append when resuming a partial download, otherwise write from scratch
        write_type = "wb" if downloaded_size == 0 else "ab"
        with open(self._model_zip_path, write_type) as out_file:
            pbar = tqdm(desc="Downloading",
                        unit="B",
                        total=length,
                        unit_scale=True,
                        unit_divisor=1024)
            if downloaded_size != 0:
                pbar.update(downloaded_size)
            while True:
                buffer = response.read(self._chunk_size)
                if not buffer:
                    break
                pbar.update(len(buffer))
                out_file.write(buffer)
            pbar.close()

    def _unzip_model(self) -> None:
        """ Unzip the model file to the cache folder. Exits the process on failure. """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
                self._write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def _write_model(self, zip_file: zipfile.ZipFile) -> None:
        """ Extract files from zip file and write, with progress bar.

        Parameters
        ----------
        zip_file: :class:`zipfile.ZipFile`
            The downloaded model zip file
        """
        length = sum(f.file_size for f in zip_file.infolist())
        fnames = zip_file.namelist()
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        pbar = tqdm(desc="Decompressing",
                    unit="B",
                    total=length,
                    unit_scale=True,
                    unit_divisor=1024)
        for fname in fnames:
            out_fname = os.path.join(self._cache_dir, fname)
            self.logger.debug("Extracting from: '%s' to '%s'", self._model_zip_path, out_fname)
            # Context-manage the archive member so its handle is always released
            with zip_file.open(fname) as zipped, open(out_fname, "wb") as out_file:
                while True:
                    buffer = zipped.read(self._chunk_size)
                    if not buffer:
                        break
                    pbar.update(len(buffer))
                    out_file.write(buffer)
        pbar.close()
class DebugTimes():
    """ A simple tool to help debug timings.

    Parameters
    ----------
    show_min: bool, Optional
        Display minimum time taken in summary stats. Default: ``True``
    show_mean: bool, Optional
        Display mean time taken in summary stats. Default: ``True``
    show_max: bool, Optional
        Display maximum time taken in summary stats. Default: ``True``

    Example
    -------
    >>> from lib.utils import DebugTimes
    >>> debug_times = DebugTimes()
    >>> debug_times.step_start("step 1")
    >>> # do something here
    >>> debug_times.step_end("step 1")
    >>> debug_times.summary()
    ----------------------------------
    Step           Count   Min
    ----------------------------------
    step 1         1       0.000000
    """
    def __init__(self,
                 show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None:
        # Completed timings: step name -> list of elapsed times
        self._times: dict[str, list[float]] = {}
        # In-flight start times keyed by step name + thread id
        self._steps: dict[str, float] = {}
        # Counts calls to summary() when printing on an interval
        self._interval = 1
        self._display = {"min": show_min, "mean": show_mean, "max": show_max}

    def step_start(self, name: str, record: bool = True) -> None:
        """ Start the timer for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to start the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`

        Example
        -------
        >>> from lib.utils import DebugTimes
        >>> debug_times = DebugTimes()
        >>> debug_times.step_start("Example Step")
        >>> # do something here
        >>> debug_times.step_end("Example Step")
        """
        if not record:
            return
        # Key on thread id too, so the same step can be timed concurrently from multiple threads
        storename = name + str(get_ident())
        self._steps[storename] = time()

    def step_end(self, name: str, record: bool = True) -> None:
        """ Stop the timer and record elapsed time for the given step name.

        Parameters
        ----------
        name: str
            The name of the step to end the timer for
        record: bool, optional
            ``True`` to record the step time, ``False`` to not record it.
            Used for when you have conditional code to time, but do not want to insert if/else
            statements in the code. Default: `True`

        Example
        -------
        >>> from lib.utils import DebugTimes
        >>> debug_times = DebugTimes()
        >>> debug_times.step_start("Example Step")
        >>> # do something here
        >>> debug_times.step_end("Example Step")
        """
        if not record:
            return
        storename = name + str(get_ident())
        self._times.setdefault(name, []).append(time() - self._steps.pop(storename))

    @classmethod
    def _format_column(cls, text: str, width: int) -> str:
        """ Pad the given text to be aligned to the given width.

        Parameters
        ----------
        text: str
            The text to be formatted
        width: int
            The size of the column to insert the text into

        Returns
        -------
        str
            The text with the correct amount of padding applied
        """
        return f"{text}{' ' * (width - len(text))}"

    def summary(self, decimal_places: int = 6, interval: int = 1) -> None:
        """ Print a summary of step times.

        Parameters
        ----------
        decimal_places: int, optional
            The number of decimal places to display the summary elapsed times to. Default: 6
        interval: int, optional
            How many times summary must be called before printing to console. Default: 1

        Example
        -------
        >>> from lib.utils import DebugTimes
        >>> debug = DebugTimes()
        >>> debug.step_start("test")
        >>> time.sleep(0.5)
        >>> debug.step_end("test")
        >>> debug.summary()
        ----------------------------------
        Step           Count   Min
        ----------------------------------
        test           1       0.500000
        """
        interval = max(1, interval)
        if interval != self._interval:
            self._interval += 1
            return
        if not self._times:
            # Nothing recorded yet: printing would crash on max() of an empty sequence
            self._interval = 1
            return
        name_col = max(len(key) for key in self._times) + 4
        items_col = 8
        time_col = (decimal_places + 4) * sum(1 for v in self._display.values() if v)
        separator = "-" * (name_col + items_col + time_col)
        print("")
        print(separator)
        header = (f"{self._format_column('Step', name_col)}"
                  f"{self._format_column('Count', items_col)}")
        header += f"{self._format_column('Min', time_col)}" if self._display["min"] else ""
        header += f"{self._format_column('Avg', time_col)}" if self._display["mean"] else ""
        header += f"{self._format_column('Max', time_col)}" if self._display["max"] else ""
        print(header)
        print(separator)
        for key, val in self._times.items():
            num = str(len(val))
            contents = f"{self._format_column(key, name_col)}{self._format_column(num, items_col)}"
            if self._display["min"]:
                _min = f"{np.min(val):.{decimal_places}f}"
                contents += f"{self._format_column(_min, time_col)}"
            if self._display["mean"]:
                avg = f"{np.mean(val):.{decimal_places}f}"
                contents += f"{self._format_column(avg, time_col)}"
            if self._display["max"]:
                _max = f"{np.max(val):.{decimal_places}f}"
                contents += f"{self._format_column(_max, time_col)}"
            print(contents)
        self._interval = 1