
Optimize Data Augmentation (#881)

* Move image utils to lib.image
* Add .pylintrc file
* Remove some cv2 pylint ignores
* TrainingData: Load images from disk in batches
* TrainingData: get_landmarks to batch
* TrainingData: transform and flip to batches
* TrainingData: Optimize color augmentation
* TrainingData: Optimize target and random_warp
* TrainingData: Convert _get_closest_match for batching
* TrainingData: Warp To Landmarks optimized
* Save models to threadpoolexecutor
* Move stack_images; rename ImageManipulation to ImageAugmentation; add docstrings
* Masks: Set dtype and threshold for lib.masks based on input face
* Docstrings and Documentation
Commit 66ed005ef3 (parent 78bd012a99) by torzdf, committed via GitHub on 2019-09-24 12:16:05 +01:00
22 changed files with 1709 additions and 665 deletions
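The thread running through the changes below is the replacement of per-image Python loops with single vectorized NumPy operations over whole batches. A minimal sketch of the idea (the helper is hypothetical, not code from this commit):

import numpy as np

def flip_batch(batch, flip_chance=0.5):
    """ Randomly horizontally flip a batch of images shaped (N, H, W, C). """
    flips = np.random.rand(batch.shape[0]) < flip_chance
    # Reverse the width axis of the selected images in one vectorized step
    batch[flips] = batch[flips, :, ::-1]
    return batch

images = np.random.rand(16, 64, 64, 3).astype("float32")
images = flip_batch(images)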

.gitignore

@@ -28,6 +28,7 @@
!plugins/extract/*
!plugins/train/*
!plugins/convert/*
!.pylintrc
!tools
!tools/lib*
!_travis

.pylintrc (new file)

@@ -0,0 +1,570 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=cv2
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[STRING]
# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception

docs/full/lib.image.rst (new file)

@@ -0,0 +1,7 @@
lib.image module
========================
.. automodule:: lib.image
:members:
:undoc-members:
:show-inheritance:


@@ -8,6 +8,8 @@ Subpackages
lib.model
lib.faces_detect
lib.image
lib.training_data
Module contents
---------------


@@ -0,0 +1,7 @@
lib.training\_data module
=========================
.. automodule:: lib.training_data
:members:
:undoc-members:
:show-inheritance:


@@ -7,7 +7,7 @@ faceswap.dev Developer Documentation
====================================
.. toctree::
:maxdepth: 4
:maxdepth: 2
:caption: Contents:
full/modules


@@ -8,8 +8,8 @@ from datetime import datetime
import cv2
from lib.faces_detect import rotate_landmarks
from lib import Serializer
from lib.utils import rotate_landmarks
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@@ -4,7 +4,7 @@
import logging
from lib.vgg_face import VGGFace
from lib.utils import cv2_read_img
from lib.image import read_image
from plugins.extract.pipeline import Extractor
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -47,10 +47,10 @@ class FaceFilter():
""" Load the images """
retval = dict()
for fpath in reference_file_paths:
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
retval[fpath] = {"image": read_image(fpath, raise_error=True),
"type": "filter"}
for fpath in nreference_file_paths:
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
retval[fpath] = {"image": read_image(fpath, raise_error=True),
"type": "nfilter"}
logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()})
return retval


@@ -2,6 +2,7 @@
""" Face and landmarks detection for faceswap.py """
import logging
import cv2
import numpy as np
from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling
@@ -399,3 +400,89 @@ class DetectedFace():
if not self.reference:
return None
return get_matrix_scaling(self.reference_matrix)
def rotate_landmarks(face, rotation_matrix):
""" Rotates the 68 point landmarks and detection bounding box by the given rotation matrix.
Parameters
----------
face: DetectedFace or dict
A :class:`DetectedFace` or an `alignments file` ``dict`` containing the 68 point landmarks
and the `x`, `w`, `y`, `h` detection bounding box points.
rotation_matrix: numpy.ndarray
The rotation matrix to rotate the given object by.
Returns
-------
DetectedFace or dict
The rotated :class:`DetectedFace` or `alignments file` ``dict`` with the landmarks and
detection bounding box points rotated by the given matrix. The return type is the same as
the input type for ``face``
"""
logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s",
rotation_matrix, type(face))
rotated_landmarks = None
# Detected Face Object
if isinstance(face, DetectedFace):
bounding_box = [[face.x, face.y],
[face.x + face.w, face.y],
[face.x + face.w, face.y + face.h],
[face.x, face.y + face.h]]
landmarks = face.landmarks_xy
# Alignments Dict
elif isinstance(face, dict) and "x" in face:
bounding_box = [[face.get("x", 0), face.get("y", 0)],
[face.get("x", 0) + face.get("w", 0),
face.get("y", 0)],
[face.get("x", 0) + face.get("w", 0),
face.get("y", 0) + face.get("h", 0)],
[face.get("x", 0),
face.get("y", 0) + face.get("h", 0)]]
landmarks = face.get("landmarks_xy", list())
else:
raise ValueError("Unsupported face type")
logger.trace("Original landmarks: %s", landmarks)
rotation_matrix = cv2.invertAffineTransform(
rotation_matrix)
rotated = list()
for item in (bounding_box, landmarks):
if not item:
continue
points = np.array(item, np.int32)
points = np.expand_dims(points, axis=0)
transformed = cv2.transform(points,
rotation_matrix).astype(np.int32)
rotated.append(transformed.squeeze())
# Bounding box should follow x, y planes, so get min/max
# for non-90 degree rotations
pt_x = min([pnt[0] for pnt in rotated[0]])
pt_y = min([pnt[1] for pnt in rotated[0]])
pt_x1 = max([pnt[0] for pnt in rotated[0]])
pt_y1 = max([pnt[1] for pnt in rotated[0]])
width = pt_x1 - pt_x
height = pt_y1 - pt_y
if isinstance(face, DetectedFace):
face.x = int(pt_x)
face.y = int(pt_y)
face.w = int(width)
face.h = int(height)
face.r = 0
if len(rotated) > 1:
rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
face.landmarks_xy = rotated_landmarks
else:
face["left"] = int(pt_x)
face["top"] = int(pt_y)
face["right"] = int(pt_x1)
face["bottom"] = int(pt_y1)
rotated_landmarks = face
logger.trace("Rotated landmarks: %s", rotated_landmarks)
return face
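A quick usage sketch for the relocated function; the rotation matrix and face values below are invented for illustration:

import cv2
rotation_matrix = cv2.getRotationMatrix2D((128, 128), 90, 1.0)  # centre, angle, scale
face = {"x": 10, "y": 20, "w": 100, "h": 100, "landmarks_xy": [(30, 40), (50, 60)]}
face = rotate_landmarks(face, rotation_matrix)  # sets "left"/"top"/"right"/"bottom"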

lib/image.py (new file)

@@ -0,0 +1,302 @@
#!/usr/bin/env python3
""" Utilities for working with images and videos """
import logging
import subprocess
import sys
from concurrent import futures
from hashlib import sha1
import cv2
import imageio_ffmpeg as im_ffm
import numpy as np
from lib.utils import convert_to_secs, FaceswapError
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
# ################### #
# <<< IMAGE UTILS >>> #
# ################### #
# <<< IMAGE IO >>> #
def read_image(filename, raise_error=False):
""" Read an image file from a file location.
Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually
loaded. Errors can be logged and ignored so that the process can continue on an image load
failure.
Parameters
----------
filename: str
Full path to the image to be loaded.
raise_error: bool, optional
If ``True``, then any failures (including the returned image being ``None``) will be
raised. If ``False`` then an error message will be logged, but the error will not be
raised. Default: ``False``
Returns
-------
numpy.ndarray
The image in `BGR` channel order.
Example
-------
>>> image_file = "/path/to/image.png"
>>> try:
>>> image = read_image(image_file, raise_error=True)
>>> except:
>>> raise ValueError("There was an error")
"""
logger.trace("Requested image: '%s'", filename)
success = True
image = None
try:
image = cv2.imread(filename)
if image is None:
raise ValueError
except TypeError:
success = False
msg = "Error while reading image (TypeError): '{}'".format(filename)
logger.error(msg)
if raise_error:
raise Exception(msg)
except ValueError:
success = False
msg = ("Error while reading image. This is most likely caused by special characters in "
"the filename: '{}'".format(filename))
logger.error(msg)
if raise_error:
raise Exception(msg)
except Exception as err: # pylint:disable=broad-except
success = False
msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
logger.error(msg)
if raise_error:
raise Exception(msg)
logger.trace("Loaded image: '%s'. Success: %s", filename, success)
return image
def read_image_batch(filenames):
""" Load a batch of images from the given file locations.
Leverages multi-threading to load multiple images from disk at the same time
leading to vastly reduced image read times.
Parameters
----------
filenames: list
A list of ``str`` full paths to the images to be loaded.
Returns
-------
numpy.ndarray
The batch of images in `BGR` channel order.
Notes
-----
As the images are compiled into a batch, they must all have the same dimensions.
Example
-------
>>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
>>> images = read_image_batch(image_filenames)
"""
logger.trace("Requested batch: '%s'", filenames)
executor = futures.ThreadPoolExecutor()
with executor:
images = [executor.submit(read_image, filename, raise_error=True)
for filename in filenames]
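# NB: futures.as_completed() yields futures in completion order, so the
# stacked batch is not guaranteed to preserve the order of `filenames`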
batch = np.array([future.result() for future in futures.as_completed(images)])
logger.trace("Returning images: %s", batch.shape)
return batch
def read_image_hash(filename):
""" Return the `sha1` hash of an image saved on disk.
Parameters
----------
filename: str
Full path to the image to be loaded.
Returns
-------
str
The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the given image.
Example
-------
>>> image_file = "/path/to/image.png"
>>> image_hash = read_image_hash(image_file)
"""
img = read_image(filename, raise_error=True)
image_hash = sha1(img).hexdigest()
logger.trace("filename: '%s', hash: %s", filename, image_hash)
return image_hash
def encode_image_with_hash(image, extension):
""" Encode an image, and get the encoded image back with its `sha1` hash.
Parameters
----------
image: numpy.ndarray
The image to be encoded in `BGR` channel order.
extension: str
A compatible `cv2` image file extension that the final image is to be saved to.
Returns
-------
image_hash: str
The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the encoded image
encoded_image: bytes
The image encoded into the correct file format
Example
-------
>>> image_file = "/path/to/image.png"
>>> image = read_image(image_file)
>>> image_hash, encoded_image = encode_image_with_hash(image, ".jpg")
"""
encoded_image = cv2.imencode(extension, image)[1]
image_hash = sha1(cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)).hexdigest()
return image_hash, encoded_image
def batch_convert_color(batch, colorspace):
""" Convert a batch of images from one color space to another.
Converts a batch of images by reshaping the batch prior to conversion rather than iterating
over the images. This leads to a significant speed up in the convert process.
Parameters
----------
batch: numpy.ndarray
A batch of images.
colorspace: str
The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be
``'BGR2LAB'``.
See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full
list of color codes.
Returns
-------
numpy.ndarray
The batch converted to the requested color space.
Example
-------
>>> images_bgr = numpy.array([image1, image2, image3])
>>> images_lab = batch_convert_color(images_bgr, "BGR2LAB")
Notes
-----
This function is only compatible for color space conversions that have the same image shape
for source and destination color spaces.
If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some
information lost. For many cases, this will not be noticeable but it is recommended
to use 32-bit images in cases that need the full range of colors or that convert an image
before an operation and then convert back.
"""
logger.trace("Batch converting: (batch shape: %s, colorspace: %s)", batch.shape, colorspace)
original_shape = batch.shape
batch = batch.reshape((original_shape[0] * original_shape[1], *original_shape[2:]))
batch = cv2.cvtColor(batch, getattr(cv2, "COLOR_{}".format(colorspace)))
return batch.reshape(original_shape)
# ################### #
# <<< VIDEO UTILS >>> #
# ################### #
def count_frames_and_secs(filename, timeout=60):
""" Count the number of frames and seconds in a video file.
Adapted from :mod:`imageio_ffmpeg` to handle the issue of ffmpeg occasionally hanging
inside a subprocess.
If the operation times out then the process will try to read the data again, up to a total
of 3 times. If the data still cannot be read then an exception will be raised.
Note that this operation can be quite slow for large files.
Parameters
----------
filename: str
Full path to the video to be analyzed.
timeout: int, optional
The amount of time in seconds to wait for the video data before aborting.
Default: ``60``
Returns
-------
nframes: int
The number of frames in the given video file.
nsecs: float
The duration, in seconds, of the given video file.
Example
-------
>>> video = "/path/to/video.mp4"
>>> frames, secs = count_frames_and_secs(video)
"""
# https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg
assert isinstance(filename, str), "Video path must be a string"
exe = im_ffm.get_ffmpeg_exe()
iswin = sys.platform.startswith("win")
logger.debug("iswin: '%s'", iswin)
cmd = [exe, "-i", filename, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
attempts = 3
for attempt in range(attempts):
try:
logger.debug("attempt: %s of %s", attempt + 1, attempts)
out = subprocess.check_output(cmd,
stderr=subprocess.STDOUT,
shell=iswin,
timeout=timeout)
logger.debug("Successfully communicated with FFMPEG")
break
except subprocess.CalledProcessError as err:
out = err.output.decode(errors="ignore")
raise RuntimeError("FFMPEG call failed with {}:\n{}".format(err.returncode, out))
except subprocess.TimeoutExpired as err:
this_attempt = attempt + 1
if this_attempt == attempts:
msg = ("FFMPEG hung while attempting to obtain the frame count. "
"Sometimes this issue resolves itself, so you can try running again. "
"Otherwise use the Effmpeg Tool to extract the frames from your video into "
"a folder, and then run the requested Faceswap process on that folder.")
raise FaceswapError(msg) from err
logger.warning("FFMPEG hung while attempting to obtain the frame count. "
"Retrying %s of %s", this_attempt + 1, attempts)
continue
# Note that other than with the subprocess calls below, ffmpeg won't hang here.
# Worst case Python will stop/crash and ffmpeg will continue running until done.
nframes = nsecs = None
for line in reversed(out.splitlines()):
if not line.startswith(b"frame="):
continue
line = line.decode(errors="ignore")
logger.debug("frame line: '%s'", line)
idx = line.find("frame=")
if idx >= 0:
splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
nframes = int(splitframes)
idx = line.find("time=")
if idx >= 0:
splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
nsecs = convert_to_secs(*splittime.split(":"))
logger.debug("nframes: %s, nsecs: %s", nframes, nsecs)
return nframes, nsecs
raise RuntimeError("Could not get number of frames") # pragma: no cover
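A worked example of the status-line parsing above; the ffmpeg output line is assumed for illustration:

line = "frame= 1234 fps=250 q=-1.0 Lsize=N/A time=00:00:49.36 bitrate=N/A"
nframes = int(line.split("frame=", 1)[-1].lstrip().split(" ", 1)[0])    # 1234
splittime = line.split("time=", 1)[-1].lstrip().split(" ", 1)[0]        # "00:00:49.36"
hours, mins, secs = splittime.split(":")
nsecs = 3600 * float(hours) + 60 * float(mins) + float(secs)            # 49.36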


@@ -43,6 +43,8 @@ class Mask():
self.__class__.__name__, face.shape, channels, landmarks)
self.landmarks = landmarks
self.face = face
self.dtype = face.dtype
self.threshold = 255 if self.dtype == "uint8" else 255.0
self.channels = channels
mask = self.build_mask()
@@ -73,7 +75,7 @@ class Mask():
class dfl_full(Mask): # pylint: disable=invalid-name
""" DFL facial mask """
def build_mask(self):
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
nose_ridge = (self.landmarks[27:31], self.landmarks[33:34])
jaw = (self.landmarks[0:17],
@@ -90,14 +92,14 @@ class dfl_full(Mask): # pylint: disable=invalid-name
for item in parts:
merged = np.concatenate(item)
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
return mask
class components(Mask): # pylint: disable=invalid-name
""" Component model mask """
def build_mask(self):
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
r_jaw = (self.landmarks[0:9], self.landmarks[17:18])
l_jaw = (self.landmarks[8:17], self.landmarks[26:27])
@@ -117,7 +119,7 @@ class components(Mask): # pylint: disable=invalid-name
for item in parts:
merged = np.concatenate(item)
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
return mask
@@ -126,7 +128,7 @@ class extended(Mask): # pylint: disable=invalid-name
Based on components mask. Attempts to extend the eyebrow points up the forehead
"""
def build_mask(self):
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
landmarks = self.landmarks.copy()
# mid points between the side of face and eye point
@@ -161,15 +163,15 @@ class extended(Mask): # pylint: disable=invalid-name
for item in parts:
merged = np.concatenate(item)
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
return mask
class facehull(Mask): # pylint: disable=invalid-name
""" Basic face hull mask """
def build_mask(self):
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
hull = cv2.convexHull( # pylint: disable=no-member
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
hull = cv2.convexHull(
np.array(self.landmarks).reshape((-1, 2)))
cv2.fillConvexPoly(mask, hull, 255.0, lineType=cv2.LINE_AA) # pylint: disable=no-member
cv2.fillConvexPoly(mask, hull, self.threshold, lineType=cv2.LINE_AA)
return mask
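The effect of the dtype changes above, sketched with assumed inputs: masks now inherit the input face's dtype, so a uint8 face gets an integer fill value while a float face keeps a float one.

import numpy as np
face = np.zeros((256, 256, 3), dtype="uint8")                # assumed input
threshold = 255 if face.dtype == "uint8" else 255.0          # mirrors Mask.__init__
mask = np.zeros(face.shape[0:2] + (1, ), dtype=face.dtype)   # same dtype as the face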

File diff suppressed because it is too large.


@@ -4,27 +4,18 @@
import json
import logging
import os
import subprocess
import sys
import urllib
import warnings
import zipfile
from hashlib import sha1
from pathlib import Path
from re import finditer
from multiprocessing import current_process
from socket import timeout as socket_timeout, error as socket_error
import imageio_ffmpeg as im_ffm
from tqdm import tqdm
import numpy as np
import cv2
from lib.faces_detect import DetectedFace
# Global variables
_image_extensions = [ # pylint:disable=invalid-name
".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
@@ -132,6 +123,22 @@ def get_image_paths(directory):
return dir_contents
def convert_to_secs(*args):
""" Converts a time to seconds. Either convert_to_secs(mins, secs) or
convert_to_secs(hours, mins, secs). """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
logger.debug("from time: %s", args)
retval = 0.0
if len(args) == 1:
retval = float(args[0])
elif len(args) == 2:
retval = 60 * float(args[0]) + float(args[1])
elif len(args) == 3:
retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
logger.debug("to secs: %s", retval)
return retval
def full_path_split(path):
""" Split a given path into all of its separate components """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
@@ -151,147 +158,6 @@ def full_path_split(path):
return allparts
def cv2_read_img(filename, raise_error=False):
""" Read an image with cv2 and check that an image was actually loaded.
Logs an error if the image returned is None or an error has occurred.
Pass raise_error=True if error should be raised """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
logger.trace("Requested image: '%s'", filename)
success = True
image = None
try:
image = cv2.imread(filename) # pylint:disable=no-member,c-extension-no-member
if image is None:
raise ValueError
except TypeError:
success = False
msg = "Error while reading image (TypeError): '{}'".format(filename)
logger.error(msg)
if raise_error:
raise Exception(msg)
except ValueError:
success = False
msg = ("Error while reading image. This is most likely caused by special characters in "
"the filename: '{}'".format(filename))
logger.error(msg)
if raise_error:
raise Exception(msg)
except Exception as err: # pylint:disable=broad-except
success = False
msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
logger.error(msg)
if raise_error:
raise Exception(msg)
logger.trace("Loaded image: '%s'. Success: %s", filename, success)
return image
def hash_image_file(filename):
""" Return an image file's sha1 hash """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
img = cv2_read_img(filename, raise_error=True)
img_hash = sha1(img).hexdigest()
logger.trace("filename: '%s', hash: %s", filename, img_hash)
return img_hash
def hash_encode_image(image, extension):
""" Encode the image, get the hash and return the hash with
encoded image """
img = cv2.imencode(extension, image)[1] # pylint:disable=no-member,c-extension-no-member
f_hash = sha1(
cv2.imdecode( # pylint:disable=no-member,c-extension-no-member
img,
cv2.IMREAD_UNCHANGED)).hexdigest() # pylint:disable=no-member,c-extension-no-member
return f_hash, img
def convert_to_secs(*args):
""" Converts a time to seconds. Either convert_to_secs(mins, secs) or
convert_to_secs(hours, mins, secs). """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
logger.debug("from time: %s", args)
retval = 0.0
if len(args) == 1:
retval = float(args[0])
elif len(args) == 2:
retval = 60 * float(args[0]) + float(args[1])
elif len(args) == 3:
retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
logger.debug("to secs: %s", retval)
return retval
def count_frames_and_secs(path, timeout=60):
"""
Adapted From ffmpeg_imageio, to handle occasional hanging issue:
https://github.com/imageio/imageio-ffmpeg
Get the number of frames and number of seconds for the given video
file. Note that this operation can be quite slow for large files.
Disclaimer: I've seen this produce different results from actually reading
the frames with older versions of ffmpeg (2.x). Therefore I cannot say
with 100% certainty that the returned values are always exact.
"""
# https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
assert isinstance(path, str), "Video path must be a string"
exe = im_ffm.get_ffmpeg_exe()
iswin = sys.platform.startswith("win")
logger.debug("iswin: '%s'", iswin)
cmd = [exe, "-i", path, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
attempts = 3
for attempt in range(attempts):
try:
logger.debug("attempt: %s of %s", attempt + 1, attempts)
out = subprocess.check_output(cmd,
stderr=subprocess.STDOUT,
shell=iswin,
timeout=timeout)
logger.debug("Successfully communicated with FFMPEG")
break
except subprocess.CalledProcessError as err:
out = err.output.decode(errors="ignore")
raise RuntimeError("FFMPEG call failed with {}:\n{}".format(err.returncode, out))
except subprocess.TimeoutExpired as err:
this_attempt = attempt + 1
if this_attempt == attempts:
msg = ("FFMPEG hung while attempting to obtain the frame count. "
"Sometimes this issue resolves itself, so you can try running again. "
"Otherwise use the Effmpeg Tool to extract the frames from your video into "
"a folder, and then run the requested Faceswap process on that folder.")
raise FaceswapError(msg) from err
logger.warning("FFMPEG hung while attempting to obtain the frame count. "
"Retrying %s of %s", this_attempt + 1, attempts)
continue
# Note that other than with the subprocess calls below, ffmpeg won't hang here.
# Worst case Python will stop/crash and ffmpeg will continue running until done.
nframes = nsecs = None
for line in reversed(out.splitlines()):
if not line.startswith(b"frame="):
continue
line = line.decode(errors="ignore")
logger.debug("frame line: '%s'", line)
idx = line.find("frame=")
if idx >= 0:
splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
nframes = int(splitframes)
idx = line.find("time=")
if idx >= 0:
splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
nsecs = convert_to_secs(*splittime.split(":"))
logger.debug("nframes: %s, nsecs: %s", nframes, nsecs)
return nframes, nsecs
raise RuntimeError("Could not get number of frames") # pragma: no cover
def backup_file(directory, filename):
""" Backup a given file by appending .bk to the end """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
@@ -348,80 +214,6 @@ def deprecation_warning(func_name, additional_info=None):
logger.warning(msg)
def rotate_landmarks(face, rotation_matrix):
# pylint:disable=c-extension-no-member
""" Rotate the landmarks and bounding box for faces
found in rotated images.
Pass in a DetectedFace object or Alignments dict """
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s",
rotation_matrix, type(face))
rotated_landmarks = None
# Detected Face Object
if isinstance(face, DetectedFace):
bounding_box = [[face.x, face.y],
[face.x + face.w, face.y],
[face.x + face.w, face.y + face.h],
[face.x, face.y + face.h]]
landmarks = face.landmarks_xy
# Alignments Dict
elif isinstance(face, dict) and "x" in face:
bounding_box = [[face.get("x", 0), face.get("y", 0)],
[face.get("x", 0) + face.get("w", 0),
face.get("y", 0)],
[face.get("x", 0) + face.get("w", 0),
face.get("y", 0) + face.get("h", 0)],
[face.get("x", 0),
face.get("y", 0) + face.get("h", 0)]]
landmarks = face.get("landmarks_xy", list())
else:
raise ValueError("Unsupported face type")
logger.trace("Original landmarks: %s", landmarks)
rotation_matrix = cv2.invertAffineTransform( # pylint:disable=no-member
rotation_matrix)
rotated = list()
for item in (bounding_box, landmarks):
if not item:
continue
points = np.array(item, np.int32)
points = np.expand_dims(points, axis=0)
transformed = cv2.transform(points, # pylint:disable=no-member
rotation_matrix).astype(np.int32)
rotated.append(transformed.squeeze())
# Bounding box should follow x, y planes, so get min/max
# for non-90 degree rotations
pt_x = min([pnt[0] for pnt in rotated[0]])
pt_y = min([pnt[1] for pnt in rotated[0]])
pt_x1 = max([pnt[0] for pnt in rotated[0]])
pt_y1 = max([pnt[1] for pnt in rotated[0]])
width = pt_x1 - pt_x
height = pt_y1 - pt_y
if isinstance(face, DetectedFace):
face.x = int(pt_x)
face.y = int(pt_y)
face.w = int(width)
face.h = int(height)
face.r = 0
if len(rotated) > 1:
rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
face.landmarks_xy = rotated_landmarks
else:
face["left"] = int(pt_x)
face["top"] = int(pt_y)
face["right"] = int(pt_x1)
face["bottom"] = int(pt_y1)
rotated_landmarks = face
logger.trace("Rotated landmarks: %s", rotated_landmarks)
return face
def camel_case_split(identifier):
""" Split a camel case name
from: https://stackoverflow.com/questions/29916065 """


@@ -18,8 +18,7 @@ To get a :class:`~lib.faces_detect.DetectedFace` object use the function:
import cv2
import numpy as np
from lib.faces_detect import DetectedFace
from lib.utils import rotate_landmarks
from lib.faces_detect import DetectedFace, rotate_landmarks
from plugins.extract._base import Extractor, logger


@@ -9,6 +9,7 @@ import os
import sys
import time
from concurrent import futures
from json import JSONDecodeError
import keras
@@ -24,7 +25,6 @@ from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask
generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
from lib.model.nn_blocks import NNBlocks
from lib.model.optimizers import Adam
from lib.multithreading import MultiThread
from lib.utils import deprecation_warning, FaceswapError
from plugins.train._config import Config
@@ -466,21 +466,13 @@ class ModelBase():
backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
if backup_func:
logger.info("Backing up models...")
save_threads = list()
for network in self.networks.values():
name = "save_{}".format(network.name)
save_threads.append(MultiThread(network.save,
name=name,
backup_func=backup_func))
save_threads.append(MultiThread(self.state.save,
name="save_state",
backup_func=backup_func))
for thread in save_threads:
thread.start()
for thread in save_threads:
if thread.has_error:
logger.error(thread.errors[0])
thread.join()
executor = futures.ThreadPoolExecutor()
save_threads = [executor.submit(network.save, backup_func=backup_func)
for network in self.networks.values()]
save_threads.append(executor.submit(self.state.save, backup_func=backup_func))
futures.wait(save_threads)
# call result() to capture errors
_ = [thread.result() for thread in save_threads]
msg = "[Saved models]"
if save_averages:
lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],

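The executor-based save above relies on a subtle point: futures.wait() blocks until every save has finished, but an exception raised inside a worker thread is only re-raised when .result() is called, hence the final comprehension. A standalone sketch of the pattern (save_all and tasks are illustrative, not names from this commit):

from concurrent import futures

def save_all(tasks):
    with futures.ThreadPoolExecutor() as executor:
        jobs = [executor.submit(task) for task in tasks]
        futures.wait(jobs)
        for job in jobs:
            job.result()  # re-raises any exception captured during a save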

@@ -33,7 +33,7 @@ from tensorflow.python import errors_impl as tf_errors # pylint:disable=no-name
from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.training_data import TrainingDataGenerator, stack_images
from lib.training_data import TrainingDataGenerator
from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.train._config import Config
@@ -292,10 +292,10 @@ class Batcher():
""" Return the next batch from the generator
Items should come out as: (warped, target [, mask]) """
batch = next(self.feed)
feed = batch[1]
batch = batch[2:] # Remove full size samples and feed from batch
mask = batch[-1]
batch = [[feed, mask], batch] if self.use_mask else [feed, batch]
if self.use_mask:
batch = [[batch["feed"], batch["masks"]], batch["targets"] + [batch["masks"]]]
else:
batch = [batch["feed"], batch["targets"]]
self.generate_preview(do_preview)
return batch
@@ -309,13 +309,10 @@ class Batcher():
if self.preview_feed is None:
self.set_preview_feed()
batch = next(self.preview_feed)
self.samples, feed = batch[:2]
batch = batch[2:] # Remove full size samples and feed from batch
self.target = batch[self.model.largest_face_index]
self.samples = batch["samples"]
self.target = [batch["targets"][self.model.largest_face_index]]
if self.use_mask:
mask = batch[-1]
batch = [[feed, mask], batch]
self.target = [self.target, mask]
self.target += [batch["masks"]]
def set_preview_feed(self):
""" Set the preview dictionary """
@@ -347,15 +344,11 @@ class Batcher():
def compile_timelapse_sample(self):
""" Timelapse samples """
batch = next(self.timelapse_feed)
samples, feed = batch[:2]
batchsize = len(samples)
batch = batch[2:] # Remove full size samples and feed from batch
images = batch[self.model.largest_face_index]
batchsize = len(batch["samples"])
images = [batch["targets"][self.model.largest_face_index]]
if self.use_mask:
mask = batch[-1]
batch = [[feed, mask], batch]
images = [images, mask]
sample = self.compile_sample(batchsize, samples=samples, images=images)
images = images + [batch["masks"]]
sample = self.compile_sample(batchsize, samples=batch["samples"], images=images)
return sample
def set_timelapse_feed(self, images, batchsize):
@@ -405,10 +398,10 @@ class Samples():
for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
predictions = [preds["{}_{}".format(side, side)],
predictions = [preds["{0}_{0}".format(side)],
preds["{}_{}".format(other_side, side)]]
display = self.to_full_frame(side, samples, predictions)
headers[side] = self.get_headers(side, other_side, display[0].shape[1])
headers[side] = self.get_headers(side, display[0].shape[1])
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
if self.images[side][0].shape[0] % 2 == 1:
figures[side] = np.concatenate([figures[side],
@@ -547,22 +540,22 @@ class Samples():
logger.debug("Overlayed foreground. Shape: %s", retval.shape)
return retval
def get_headers(self, side, other_side, width):
def get_headers(self, side, width):
""" Set headers for images """
logger.debug("side: '%s', other_side: '%s', width: %s",
side, other_side, width)
logger.debug("side: '%s', width: %s",
side, width)
titles = ("Original", "Swap") if side == "a" else ("Swap", "Original")
side = side.upper()
other_side = other_side.upper()
height = int(64 * self.scaling)
total_width = width * 3
logger.debug("height: %s, total_width: %s", height, total_width)
font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member
texts = ["Target {}".format(side),
"{} > {}".format(side, side),
"{} > {}".format(side, other_side)]
texts = ["{} ({})".format(titles[0], side),
"{0} > {0}".format(titles[0]),
"{} > {}".format(titles[0], titles[1])]
text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member
font,
self.scaling,
self.scaling * 0.8,
1)[0]
for idx in range(len(texts))]
text_y = int((height + text_sizes[0][1]) / 2)
@@ -576,7 +569,7 @@ class Samples():
text,
(text_x[idx], text_y),
font,
self.scaling,
self.scaling * 0.8,
(0, 0, 0),
1,
lineType=cv2.LINE_AA) # pylint: disable=no-member
@@ -703,3 +696,25 @@ class Landmarks():
detected_face.load_aligned(None, size=self.size)
landmarks[detected_face.hash] = detected_face.aligned_landmarks
return landmarks
def stack_images(images):
""" Stack images """
logger.debug("Stack images")
def get_transpose_axes(num):
if num % 2 == 0:
logger.debug("Even number of images to stack")
y_axes = list(range(1, num - 1, 2))
x_axes = list(range(0, num - 1, 2))
else:
logger.debug("Odd number of images to stack")
y_axes = list(range(0, num - 1, 2))
x_axes = list(range(1, num - 1, 2))
return y_axes, x_axes, [num - 1]
images_shape = np.array(images.shape)
new_axes = get_transpose_axes(len(images_shape))
new_shape = [np.prod(images_shape[x]) for x in new_axes]
logger.debug("Stacked images")
return np.transpose(images, axes=np.concatenate(new_axes)).reshape(new_shape)
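An assumed usage of stack_images: given a (2, 2, 64, 64, 3) grid of preview faces, the transpose/reshape collapses the grid axes into a single (128, 128, 3) tiled canvas.

import numpy as np
grid = np.zeros((2, 2, 64, 64, 3), dtype="float32")
canvas = stack_images(grid)
print(canvas.shape)  # (128, 128, 3)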


@@ -17,9 +17,10 @@ from lib import Serializer
from lib.convert import Converter
from lib.faces_detect import DetectedFace
from lib.gpu_stats import GPUStats
from lib.image import read_image_hash
from lib.multithreading import MultiThread, total_cpus
from lib.queue_manager import queue_manager
from lib.utils import FaceswapError, get_folder, get_image_paths, hash_image_file
from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.extract.pipeline import Extractor
from plugins.plugin_loader import PluginLoader
@@ -682,7 +683,7 @@ class OptionalActions():
file_list = [path for path in get_image_paths(input_aligned_dir)]
logger.info("Getting Face Hashes for selected Aligned Images")
for face in tqdm(file_list, desc="Hashing Faces"):
face_hashes.append(hash_image_file(face))
face_hashes.append(read_image_hash(face))
logger.debug("Face Hashes: %s", (len(face_hashes)))
if not face_hashes:
raise FaceswapError("Aligned directory is empty, no faces will be converted!")
@@ -746,5 +747,5 @@ class Legacy():
continue
hash_faces = all_faces[frame]
for index, face_path in hash_faces.items():
hash_faces[index] = hash_image_file(face_path)
hash_faces[index] = read_image_hash(face_path)
self.alignments.add_face_hashes(frame, hash_faces)


@@ -8,9 +8,10 @@ from pathlib import Path
from tqdm import tqdm
from lib.image import encode_image_with_hash
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import get_folder, hash_encode_image, deprecation_warning
from lib.utils import get_folder, deprecation_warning
from plugins.extract.pipeline import Extractor
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
@@ -255,7 +256,7 @@ class Extract():
face = detected_face["face"]
resized_face = face.aligned_face
face.hash, img = hash_encode_image(resized_face, extension)
face.hash, img = encode_image_with_hash(resized_face, extension)
self.save_queue.put((out_filename, img))
final_faces.append(face.to_alignment())
self.alignments.data[os.path.basename(filename)] = final_faces


@@ -16,8 +16,9 @@ import numpy as np
from lib.aligner import Extract as AlignerExtract
from lib.alignments import Alignments as AlignmentsBase
from lib.face_filter import FaceFilter as FilterFunc
from lib.utils import (camel_case_split, count_frames_and_secs, cv2_read_img, get_folder,
get_image_paths, set_system_verbosity, _video_extensions)
from lib.image import count_frames_and_secs, read_image
from lib.utils import (camel_case_split, get_folder, get_image_paths, set_system_verbosity,
_video_extensions)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -183,7 +184,7 @@ class Images():
""" Load frames from disk """
logger.debug("Input is separate Frames. Loading images")
for filename in self.input_images:
image = cv2_read_img(filename, raise_error=False)
image = read_image(filename, raise_error=False)
if image is None:
continue
yield filename, image
@@ -212,7 +213,7 @@ class Images():
logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
retval = self.load_one_video_frame(int(frame_no))
else:
retval = cv2_read_img(filename, raise_error=True)
retval = read_image(filename, raise_error=True)
return retval
def load_one_video_frame(self, frame_no):


@@ -12,10 +12,11 @@ import cv2
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from lib.image import read_image
from lib.keypress import KBHit
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import cv2_read_img, get_folder, get_image_paths, set_system_verbosity
from lib.queue_manager import queue_manager # noqa pylint:disable=unused-import
from lib.utils import get_folder, get_image_paths, set_system_verbosity
from plugins.plugin_loader import PluginLoader
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -176,7 +177,7 @@ class Train():
@property
def image_size(self):
""" Get the training set image size for storing in model data """
image = cv2_read_img(self.images["a"][0], raise_error=True)
image = read_image(self.images["a"][0], raise_error=True)
size = image.shape[0]
logger.debug("Training image size: %s", size)
return size


@@ -14,8 +14,8 @@ from tqdm import tqdm
from lib.aligner import Extract as AlignerExtract
from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.utils import (_image_extensions, _video_extensions, count_frames_and_secs, cv2_read_img,
hash_image_file, hash_encode_image)
from lib.image import count_frames_and_secs, encode_image_with_hash, read_image, read_image_hash
from lib.utils import _image_extensions, _video_extensions
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -175,7 +175,7 @@ class MediaLoader():
else:
src = os.path.join(self.folder, filename)
logger.trace("Loading image: '%s'", src)
image = cv2_read_img(src, raise_error=True)
image = read_image(src, raise_error=True)
return image
def load_video_frame(self, filename):
@@ -210,7 +210,7 @@ class Faces(MediaLoader):
continue
filename = os.path.splitext(face)[0]
file_extension = os.path.splitext(face)[1]
face_hash = hash_image_file(os.path.join(self.folder, face))
face_hash = read_image_hash(os.path.join(self.folder, face))
retval = {"face_fullname": face,
"face_name": filename,
"face_extension": file_extension,
@@ -358,7 +358,7 @@ class ExtractedFaces():
@staticmethod
def save_face_with_hash(filename, extension, face):
""" Save a face and return it's hash """
f_hash, img = hash_encode_image(face, extension)
f_hash, img = encode_image_with_hash(face, extension)
logger.trace("Saving face: '%s'", filename)
with open(filename, "wb") as out_file:
out_file.write(img)


@@ -16,8 +16,8 @@ from tqdm import tqdm
from lib.cli import FullHelpArgumentParser
from lib import Serializer
from lib.faces_detect import DetectedFace
from lib.image import read_image
from lib.queue_manager import queue_manager
from lib.utils import cv2_read_img
from lib.vgg_face2_keras import VGGFace2 as VGGFace
from plugins.plugin_loader import PluginLoader
@@ -106,7 +107,7 @@ class Sort():
@staticmethod
def get_landmarks(filename):
""" Extract the face from a frame (If not alignments file found) """
image = cv2_read_img(filename, raise_error=True)
image = read_image(filename, raise_error=True)
feed = Sort.alignment_dict(image)
feed["filename"] = filename
queue_manager.get_queue("in").put(feed)
@@ -161,7 +161,7 @@ class Sort():
logger.info("Sorting by face similarity...")
images = np.array(self.find_images(input_dir))
preds = np.array([self.vgg_face.predict(cv2_read_img(img, raise_error=True))
preds = np.array([self.vgg_face.predict(read_image(img, raise_error=True))
for img in tqdm(images, desc="loading", file=sys.stdout)])
logger.info("Sorting. Depending on the size of your dataset, this may take a few "
"minutes...")
@@ -264,7 +264,7 @@ class Sort():
logger.info("Sorting by histogram similarity...")
img_list = [
[img, cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
[img, cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])]
for img in
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
]
@@ -294,7 +294,7 @@ class Sort():
img_list = [
[img,
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256]), 0]
cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256]), 0]
for img in
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
]
@@ -548,7 +548,7 @@ class Sort():
input_dir = self.args.input_dir
logger.info("Preparing to group...")
if group_method == 'group_blur':
temp_list = [[img, self.estimate_blur(cv2_read_img(img, raise_error=True))]
temp_list = [[img, self.estimate_blur(read_image(img, raise_error=True))]
for img in
tqdm(self.find_images(input_dir),
desc="Reloading",
@@ -576,7 +576,7 @@ class Sort():
elif group_method == 'group_hist':
temp_list = [
[img,
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])]
for img in
tqdm(self.find_images(input_dir),
desc="Reloading",
@@ -632,7 +632,7 @@ class Sort():
Estimate the amount of blur an image has with the variance of the Laplacian.
Normalize by pixel number to offset the effect of image size on pixel gradients & variance
"""
image = cv2_read_img(image_file, raise_error=True)
image = read_image(image_file, raise_error=True)
if image.ndim == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
blur_map = cv2.Laplacian(image, cv2.CV_32F)