#!/usr/bin/env python3
""" Utilities available across all scripts """

import logging
import os
import subprocess
import sys
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from hashlib import sha1
from pathlib import Path
from re import finditer
from socket import timeout as socket_timeout, error as socket_error

import imageio_ffmpeg as im_ffm
from tqdm import tqdm
import numpy as np
import cv2

from lib.faces_detect import DetectedFace

# Global variables
_image_extensions = [  # pylint:disable=invalid-name
    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
_video_extensions = [  # pylint:disable=invalid-name
    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm"]


def get_folder(path, make_folder=True):
    """ Return a :class:`pathlib.Path` to a folder, creating it if it doesn't exist.

    If ``make_folder`` is ``False`` and the folder does not exist, nothing is created
    and ``None`` is returned.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Requested path: '%s'", path)
    output_dir = Path(path)
    if not make_folder and not output_dir.exists():
        logger.debug("%s does not exist", path)
        return None
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.debug("Returning: '%s'", output_dir)
    return output_dir


def get_image_paths(directory):
    """ Return a name-sorted list of paths to the images that reside in a folder.

    The folder is created (via :func:`get_folder`) if it does not exist. A file is treated
    as an image when its lower-cased name ends with one of ``_image_extensions``.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    # str.endswith accepts a tuple, so a single call tests every extension
    image_extensions = tuple(_image_extensions)
    dir_contents = list()

    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)

    dir_scanned = sorted(os.scandir(directory), key=lambda x: x.name)
    logger.debug("Scanned Folder contains %s files", len(dir_scanned))
    logger.trace("Scanned Folder Contents: %s", dir_scanned)

    for chkfile in dir_scanned:
        if chkfile.name.lower().endswith(image_extensions):
            logger.trace("Adding '%s' to image list", chkfile.path)
            dir_contents.append(chkfile.path)

    logger.debug("Returning %s images", len(dir_contents))
    return dir_contents


def full_path_split(path):
    """ Split a given path into all of its separate components and return them as a list """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    allparts = list()
    while True:
        parts = os.path.split(path)
        if parts[0] == path:   # sentinel for absolute paths
            allparts.insert(0, parts[0])
            break
        elif parts[1] == path:  # sentinel for relative paths
            allparts.insert(0, parts[1])
            break
        else:
            path = parts[0]
            allparts.insert(0, parts[1])
    logger.trace("path: %s, allparts: %s", path, allparts)
    return allparts


def cv2_read_img(filename, raise_error=False):
    """ Read an image with cv2 and check that an image was actually loaded.

    Logs an error if the image returned is None or an error has occurred, in which case
    ``None`` is returned. Pass ``raise_error=True`` if an :class:`Exception` should be
    raised instead.
    """
    # pylint:disable=no-member,c-extension-no-member
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.trace("Requested image: '%s'", filename)
    success = True
    image = None
    try:
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising
            raise ValueError
    except TypeError:
        success = False
        msg = "Error while reading image (TypeError): '{}'".format(filename)
        logger.error(msg)
        if raise_error:
            raise Exception(msg)
    except ValueError:
        success = False
        msg = ("Error while reading image. This is most likely caused by special characters in "
               "the filename: '{}'".format(filename))
        logger.error(msg)
        if raise_error:
            raise Exception(msg)
    except Exception as err:  # pylint:disable=broad-except
        success = False
        msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
        logger.error(msg)
        if raise_error:
            raise Exception(msg)
    logger.trace("Loaded image: '%s'. Success: %s", filename, success)
    return image


def hash_image_file(filename):
    """ Return the sha1 hex digest of an image file's decoded pixel data.

    The image is loaded with :func:`cv2_read_img` using ``raise_error=True``, so a failed
    read raises rather than hashing ``None``.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    img = cv2_read_img(filename, raise_error=True)
    img_hash = sha1(img).hexdigest()
    logger.trace("filename: '%s', hash: %s", filename, img_hash)
    return img_hash


def hash_encode_image(image, extension):
    """ Encode the image, get the hash and return the hash with the encoded image.

    The hash is the sha1 hex digest of the image as decoded back from the encoded buffer.
    Returns a tuple of ``(hash, encoded_image)``.
    """
    img = cv2.imencode(extension, image)[1]  # pylint:disable=no-member,c-extension-no-member
    f_hash = sha1(
        cv2.imdecode(  # pylint:disable=no-member,c-extension-no-member
            img,
            cv2.IMREAD_UNCHANGED)).hexdigest()  # pylint:disable=no-member,c-extension-no-member
    return f_hash, img


def convert_to_secs(*args):
    """ Convert a time to seconds.

    Either ``convert_to_secs(secs)``, ``convert_to_secs(mins, secs)`` or
    ``convert_to_secs(hours, mins, secs)``. Any other number of arguments returns ``0.0``.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("from time: %s", args)
    retval = 0.0
    if len(args) == 1:
        retval = float(args[0])
    elif len(args) == 2:
        retval = 60 * float(args[0]) + float(args[1])
    elif len(args) == 3:
        retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
    logger.debug("to secs: %s", retval)
    return retval


def count_frames_and_secs(path, timeout=15):
    """ Adapted From ffmpeg_imageio, to handle occasional hanging issue:
    https://github.com/imageio/imageio-ffmpeg

    Get the number of frames and number of seconds for the given video file.
    Note that this operation can be quite slow for large files.

    Disclaimer: I've seen this produce different results from actually reading
    the frames with older versions of ffmpeg (2.x). Therefore I cannot say
    with 100% certainty that the returned values are always exact.

    ffmpeg is given ``timeout`` seconds to respond and is attempted up to 3 times.

    Returns a tuple of ``(nframes, nsecs)``.

    Raises :class:`FaceswapError` if ffmpeg hangs on every attempt and
    :class:`RuntimeError` if no frame information can be found in the ffmpeg output.
    """
    # https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg

    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    assert isinstance(path, str), "Video path must be a string"
    exe = im_ffm.get_ffmpeg_exe()
    iswin = sys.platform.startswith("win")
    logger.debug("iswin: '%s'", iswin)
    cmd = [exe, "-i", path, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
    logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
    attempts = 3
    for attempt in range(attempts):
        try:
            logger.debug("attempt: %s of %s", attempt + 1, attempts)
            proc = subprocess.Popen(cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    shell=iswin)
            out, _ = proc.communicate(timeout=timeout)
            logger.debug("Succesfully communicated with FFMPEG")
            break
        except subprocess.CalledProcessError as err:
            out = err.output.decode(errors="ignore")
            raise RuntimeError("FFMEG call failed with {}:\n{}".format(err.returncode, out))
        except subprocess.TimeoutExpired as err:
            proc.kill()
            this_attempt = attempt + 1
            if this_attempt == attempts:
                msg = ("FFMPEG hung while attempting to obtain the frame count. "
                       "Sometimes this issue resolves itself, so you can try running again. "
                       "Otherwise use the Effmpeg Tool to extract the frames from your video into "
                       "a folder, and then run the requested Faceswap process on that folder.")
                raise FaceswapError(msg) from err
            logger.warning("FFMPEG hung while attempting to obtain the frame count. "
                           "Retrying %s of %s", this_attempt + 1, attempts)
            continue

    # Note that other than with the subprocess calls above, ffmpeg wont hang here.
    # Worst case Python will stop/crash and ffmpeg will continue running until done.

    nframes = nsecs = None
    # ffmpeg emits a "frame=" progress line repeatedly whilst it works, so the totals are on
    # the LAST such line. Search from the end so interim progress values are never returned.
    for line in reversed(out.splitlines()):
        if not line.startswith(b"frame="):
            continue
        line = line.decode(errors="ignore")
        logger.debug("frame line: '%s'", line)
        idx = line.find("frame=")
        if idx >= 0:
            splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
            nframes = int(splitframes)
        idx = line.find("time=")
        if idx >= 0:
            splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
            nsecs = convert_to_secs(*splittime.split(":"))
        logger.debug("nframes: %s, nsecs: %s", nframes, nsecs)
        return nframes, nsecs

    raise RuntimeError("Could not get number of frames")  # pragma: no cover


def backup_file(directory, filename):
    """ Backup a given file by appending .bk to the end.

    Any existing backup of the file is removed first. Nothing is renamed if the original
    file does not exist.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.trace("Backing up: '%s'", filename)
    origfile = os.path.join(directory, filename)
    backupfile = origfile + '.bk'
    if os.path.exists(backupfile):
        logger.trace("Removing existing file: '%s'", backupfile)
        os.remove(backupfile)
    if os.path.exists(origfile):
        logger.trace("Renaming: '%s' to '%s'", origfile, backupfile)
        os.rename(origfile, backupfile)


def keras_backend_quiet():
    """ Suppresses the "Using x backend" message when importing backend from keras.

    ``sys.stderr`` is pointed at the null device for the duration of the import and is
    always restored afterwards, even if the import fails. Returns the keras backend module.
    """
    stderr = sys.stderr
    with open(os.devnull, 'w') as devnull:
        sys.stderr = devnull
        try:
            from keras import backend as K  # pylint:disable=import-outside-toplevel
        finally:
            sys.stderr = stderr
    return K


def set_system_verbosity(loglevel):
    """ Set the verbosity level of tensorflow and suppresses
        future and deprecation warnings from any modules
        From:
        https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
        TF_CPP_MIN_LOG_LEVEL can be set to:
        0 - all logs shown
        1 - filter out INFO logs
        2 - filter out WARNING logs
        3 - filter out ERROR logs

        This function sets it to "2" when the numeric value of ``loglevel`` is above 15,
        otherwise to "0". Future, Deprecation and User warnings are only ignored when it
        is not "0".
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    from lib.logger import get_loglevel
    numeric_level = get_loglevel(loglevel)
    loglevel = "2" if numeric_level > 15 else "0"
    logger.debug("System Verbosity level: %s", loglevel)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = loglevel
    if loglevel != '0':
        for warncat in (FutureWarning, DeprecationWarning, UserWarning):
            warnings.simplefilter(action='ignore', category=warncat)


def deprecation_warning(func_name, additional_info=None):
    """ Log at warning level that a function will be removed in future.

    ``additional_info``, when given, is appended to the logged message.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("func_name: %s, additional_info: %s", func_name, additional_info)
    msg = "{} has been deprecated and will be removed from a future update.".format(func_name)
    if additional_info is not None:
        msg += " {}".format(additional_info)
    logger.warning(msg)


def rotate_landmarks(face, rotation_matrix):
    # pylint:disable=c-extension-no-member
    """ Rotate the landmarks and bounding box for faces found in rotated images.

    Pass in a DetectedFace object, Alignments dict or bounding box dict
    (as defined in lib/plugins/extract/detect/_base.py)

    The inverse of ``rotation_matrix`` is applied to the points. ``face`` is updated in
    place and also returned. Raises :class:`ValueError` for an unsupported ``face`` type.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s",
                 rotation_matrix, type(face))
    # Only re-bound when there are landmarks to rotate. Pre-set so that the closing trace
    # log cannot hit an unbound name for faces that have no landmarks
    rotated_landmarks = list()
    # Detected Face Object
    if isinstance(face, DetectedFace):
        bounding_box = [[face.x, face.y],
                        [face.x + face.w, face.y],
                        [face.x + face.w, face.y + face.h],
                        [face.x, face.y + face.h]]
        landmarks = face.landmarksXY

    # Alignments Dict
    elif isinstance(face, dict) and "x" in face:
        bounding_box = [[face.get("x", 0), face.get("y", 0)],
                        [face.get("x", 0) + face.get("w", 0),
                         face.get("y", 0)],
                        [face.get("x", 0) + face.get("w", 0),
                         face.get("y", 0) + face.get("h", 0)],
                        [face.get("x", 0),
                         face.get("y", 0) + face.get("h", 0)]]
        landmarks = face.get("landmarksXY", list())

    # Bounding Box Dict
    elif isinstance(face, dict) and "left" in face:
        bounding_box = [[face["left"], face["top"]],
                        [face["right"], face["top"]],
                        [face["right"], face["bottom"]],
                        [face["left"], face["bottom"]]]
        landmarks = list()

    else:
        raise ValueError("Unsupported face type")

    logger.trace("Original landmarks: %s", landmarks)

    rotation_matrix = cv2.invertAffineTransform(  # pylint:disable=no-member
        rotation_matrix)
    rotated = list()
    for item in (bounding_box, landmarks):
        if not item:
            continue
        points = np.array(item, np.int32)
        points = np.expand_dims(points, axis=0)
        transformed = cv2.transform(points,  # pylint:disable=no-member
                                    rotation_matrix).astype(np.int32)
        rotated.append(transformed.squeeze())

    # Bounding box should follow x, y planes, so get min/max
    # for non-90 degree rotations
    pt_x = min([pnt[0] for pnt in rotated[0]])
    pt_y = min([pnt[1] for pnt in rotated[0]])
    pt_x1 = max([pnt[0] for pnt in rotated[0]])
    pt_y1 = max([pnt[1] for pnt in rotated[0]])
    width = pt_x1 - pt_x
    height = pt_y1 - pt_y

    if isinstance(face, DetectedFace):
        face.x = int(pt_x)
        face.y = int(pt_y)
        face.w = int(width)
        face.h = int(height)
        face.r = 0
        if len(rotated) > 1:
            rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
            face.landmarksXY = rotated_landmarks
    elif isinstance(face, dict) and "x" in face:
        face["x"] = int(pt_x)
        face["y"] = int(pt_y)
        face["w"] = int(width)
        face["h"] = int(height)
        face["r"] = 0
        if len(rotated) > 1:
            rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
            face["landmarksXY"] = rotated_landmarks
    else:
        face["left"] = int(pt_x)
        face["top"] = int(pt_y)
        face["right"] = int(pt_x1)
        face["bottom"] = int(pt_y1)
        rotated_landmarks = face

    logger.trace("Rotated landmarks: %s", rotated_landmarks)
    return face


def camel_case_split(identifier):
    """ Split a camel case name into a list of its component words
        from: https://stackoverflow.com/questions/29916065 """
    matches = finditer(
        ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)",
        identifier)
    return [m.group(0) for m in matches]


def safe_shutdown():
    """ Close queues, threads and processes in event of crash """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager
    from lib.multithreading import terminate_processes
    queue_manager.terminate_queues()
    terminate_processes()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
    queue_manager._log_queue.put(None)  # pylint:disable=protected-access
    # Wait for the log queue to drain before the manager is shut down
    while not queue_manager._log_queue.empty():  # pylint:disable=protected-access
        continue
    queue_manager.manager.shutdown()


class FaceswapError(Exception):
    """ Faceswap Error for handling specific errors with useful information """
    pass  # pylint:disable=unnecessary-pass


class GetModel():
    """ Check for models in their cache path
        If available, return the path, if not available, get, unzip and install model

        model_filename: The name of the model to be loaded (see notes below)
        cache_dir:      The model cache folder of the current plugin calling this class
                        IE: The folder that holds the model to be loaded.
        git_model_id:   The second digit in the github tag that identifies this model.
                        See https://github.com/deepfakes-models/faceswap-models for more
                        information

        NB: Models must have a certain naming convention:
            IE: {model_name}_v{version_number}.{extension}
            EG: s3fd_v1.pb

            Multiple models can exist within the model_filename. They should be passed as a list
            and follow the same naming convention as above. Any differences in filename should
            occur AFTER the version number.
            IE: [{model_name}_v{version_number}{differentiating_information}.{extension}]
            EG: [mtcnn_det_v1.1.py, mtcnn_det_v1.2.py, mtcnn_det_v1.3.py]
                [resnet_ssd_v1.caffemodel, resnet_ssd_v1.prototext]

        The model is fetched during initialization. Afterwards ``model_path`` holds the path
        to the model file, or a list of paths when several filenames were given.
        """

    def __init__(self, model_filename, cache_dir, git_model_id):
        self.logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self.model_filename = model_filename
        self.cache_dir = cache_dir
        self.git_model_id = git_model_id
        self.url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self.chunk_size = 1024  # Chunk size for downloading and unzipping
        self.retries = 6
        self.get()
        self.model_path = self._model_path

    @property
    def _model_full_name(self):
        """ Return the model full name from the filename(s) """
        common_prefix = os.path.commonprefix(self.model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)
        return retval

    @property
    def _model_name(self):
        """ Return the model name from the model full name """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)
        return retval

    @property
    def _model_version(self):
        """ Return the model version from the model full name """
        # + 2 skips both the underscore and the "v" that precede the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)
        return retval

    @property
    def _model_path(self):
        """ Return the model path(s) in the cache folder.

        A single string for one model file, otherwise a list of strings. """
        retval = [os.path.join(self.cache_dir, fname) for fname in self.model_filename]
        retval = retval[0] if len(retval) == 1 else retval
        self.logger.trace(retval)
        return retval

    @property
    def _model_zip_path(self):
        """ Full path to downloaded zip file """
        retval = os.path.join(self.cache_dir, "{}.zip".format(self._model_full_name))
        self.logger.trace(retval)
        return retval

    @property
    def _model_exists(self):
        """ Check model(s) exist """
        if isinstance(self._model_path, list):
            retval = all(os.path.exists(pth) for pth in self._model_path)
        else:
            retval = os.path.exists(self._model_path)
        self.logger.trace(retval)
        return retval

    @property
    def _plugin_section(self):
        """ Get the plugin section from the cache_dir.

        This is the path component that immediately follows "plugins". """
        path = os.path.normpath(self.cache_dir)
        split = path.split(os.sep)
        retval = split[split.index("plugins") + 1]
        self.logger.trace(retval)
        return retval

    @property
    def _url_section(self):
        """ Return the section ID in github for this plugin type """
        sections = dict(extract=1, train=2, convert=3)
        retval = sections[self._plugin_section]
        self.logger.trace(retval)
        return retval

    @property
    def _url_download(self):
        """ Full URL to download the zip file for this model """
        tag = "v{}.{}.{}".format(self._url_section, self.git_model_id, self._model_version)
        retval = "{}/{}/{}.zip".format(self.url_base, tag, self._model_full_name)
        self.logger.trace("Download url: %s", retval)
        return retval

    @property
    def _url_partial_size(self):
        """ Return how many bytes have already been downloaded """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)
        return retval

    def get(self):
        """ Check the model exists, if not, download and unzip into location.

        The downloaded zip file is deleted once it has been extracted. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self._model_path)
            return
        self.download_model()
        self.unzip_model()
        os.remove(self._model_zip_path)

    def download_model(self):
        """ Download model zip to cache dir.

        A partially downloaded zip is resumed with an HTTP Range request. The download is
        attempted ``self.retries`` times, after which the process exits with status 1. """
        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self.retries):
            try:
                downloaded_size = self._url_partial_size
                req = urllib.request.Request(self._url_download)
                if downloaded_size != 0:
                    req.add_header("Range", "bytes={}-".format(downloaded_size))
                response = urllib.request.urlopen(req, timeout=10)
                self.logger.debug("header info: {%s}", response.info())
                self.logger.debug("Return Code: %s", response.getcode())
                self.write_zipfile(response, downloaded_size)
                break
            except (socket_error, socket_timeout,
                    urllib.error.HTTPError, urllib.error.URLError) as err:
                if attempt + 1 < self.retries:
                    self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                        str(err), attempt + 2, self.retries)
                else:
                    self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                                      "'%s')", str(err), self._url_download)
                    self.logger.info("You can try running again to resume the download.")
                    self.logger.info("Alternatively, you can manually download the model from: %s "
                                     "and unzip the contents to: %s",
                                     self._url_download, self.cache_dir)
                    # sys.exit rather than the site-provided exit(), which is not guaranteed
                    # to exist (e.g. when running with -S or from a frozen build)
                    sys.exit(1)

    def write_zipfile(self, response, downloaded_size):
        """ Write the model zip file to disk.

        ``response`` is read in ``self.chunk_size`` chunks. When ``downloaded_size`` is
        non-zero the data is appended to the existing partial zip file. """
        length = int(response.getheader("content-length")) + downloaded_size
        if length == downloaded_size:
            self.logger.info("Zip already exists. Skipping download")
            return
        write_type = "wb" if downloaded_size == 0 else "ab"
        with open(self._model_zip_path, write_type) as out_file, \
                tqdm(desc="Downloading",
                     unit="B",
                     total=length,
                     unit_scale=True,
                     unit_divisor=1024) as pbar:
            if downloaded_size != 0:
                # Start the bar at the point the previous download reached
                pbar.update(downloaded_size)
            while True:
                buffer = response.read(self.chunk_size)
                if not buffer:
                    break
                pbar.update(len(buffer))
                out_file.write(buffer)

    def unzip_model(self):
        """ Unzip the model file to the cachedir.

        Any failure is logged and the process exits with status 1. """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            # Context manager guarantees the archive is closed if extraction fails part way
            with zipfile.ZipFile(self._model_zip_path, "r") as zip_file:
                self.write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def write_model(self, zip_file):
        """ Extract files from zipfile and write, with progress bar.

        ``zip_file`` is closed once every member has been written. """
        length = sum(f.file_size for f in zip_file.infolist())
        fnames = zip_file.namelist()
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        with tqdm(desc="Decompressing",
                  unit="B",
                  total=length,
                  unit_scale=True,
                  unit_divisor=1024) as pbar:
            for fname in fnames:
                out_fname = os.path.join(self.cache_dir, fname)
                self.logger.debug("Extracting from: '%s' to '%s'",
                                  self._model_zip_path, out_fname)
                with zip_file.open(fname) as zipped, open(out_fname, "wb") as out_file:
                    while True:
                        buffer = zipped.read(self.chunk_size)
                        if not buffer:
                            break
                        pbar.update(len(buffer))
                        out_file.write(buffer)
        zip_file.close()