mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-09 04:36:50 -04:00
Optimize Data Augmentation (#881)
* Move image utils to lib.image * Add .pylintrc file * Remove some cv2 pylint ignores * TrainingData: Load images from disk in batches * TrainingData: get_landmarks to batch * TrainingData: transform and flip to batches * TrainingData: Optimize color augmentation * TrainingData: Optimize target and random_warp * TrainingData - Convert _get_closest_match for batching * TrainingData: Warp To Landmarks optimized * Save models to threadpoolexecutor * Move stack_images, Rename ImageManipulation. ImageAugmentation Docstrings * Masks: Set dtype and threshold for lib.masks based on input face * Docstrings and Documentation
This commit is contained in:
parent
78bd012a99
commit
66ed005ef3
22 changed files with 1709 additions and 665 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -28,6 +28,7 @@
|
||||||
!plugins/extract/*
|
!plugins/extract/*
|
||||||
!plugins/train/*
|
!plugins/train/*
|
||||||
!plugins/convert/*
|
!plugins/convert/*
|
||||||
|
!.pylintrc
|
||||||
!tools
|
!tools
|
||||||
!tools/lib*
|
!tools/lib*
|
||||||
!_travis
|
!_travis
|
||||||
|
|
570
.pylintrc
Normal file
570
.pylintrc
Normal file
|
@ -0,0 +1,570 @@
|
||||||
|
[MASTER]
|
||||||
|
|
||||||
|
# A comma-separated list of package or module names from where C extensions may
|
||||||
|
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||||
|
# run arbitrary code.
|
||||||
|
extension-pkg-whitelist=cv2
|
||||||
|
|
||||||
|
# Add files or directories to the blacklist. They should be base names, not
|
||||||
|
# paths.
|
||||||
|
ignore=CVS
|
||||||
|
|
||||||
|
# Add files or directories matching the regex patterns to the blacklist. The
|
||||||
|
# regex matches against base names, not paths.
|
||||||
|
ignore-patterns=
|
||||||
|
|
||||||
|
# Python code to execute, usually for sys.path manipulation such as
|
||||||
|
# pygtk.require().
|
||||||
|
#init-hook=
|
||||||
|
|
||||||
|
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||||
|
# number of processors available to use.
|
||||||
|
jobs=1
|
||||||
|
|
||||||
|
# Control the amount of potential inferred values when inferring a single
|
||||||
|
# object. This can help the performance when dealing with large functions or
|
||||||
|
# complex, nested conditions.
|
||||||
|
limit-inference-results=100
|
||||||
|
|
||||||
|
# List of plugins (as comma separated values of python modules names) to load,
|
||||||
|
# usually to register additional checkers.
|
||||||
|
load-plugins=
|
||||||
|
|
||||||
|
# Pickle collected data for later comparisons.
|
||||||
|
persistent=yes
|
||||||
|
|
||||||
|
# Specify a configuration file.
|
||||||
|
#rcfile=
|
||||||
|
|
||||||
|
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
||||||
|
# user-friendly hints instead of false-positive error messages.
|
||||||
|
suggestion-mode=yes
|
||||||
|
|
||||||
|
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||||
|
# active Python interpreter and may run arbitrary code.
|
||||||
|
unsafe-load-any-extension=no
|
||||||
|
|
||||||
|
|
||||||
|
[MESSAGES CONTROL]
|
||||||
|
|
||||||
|
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||||
|
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
|
||||||
|
confidence=
|
||||||
|
|
||||||
|
# Disable the message, report, category or checker with the given id(s). You
|
||||||
|
# can either give multiple identifiers separated by comma (,) or put this
|
||||||
|
# option multiple times (only on the command line, not in the configuration
|
||||||
|
# file where it should appear only once). You can also use "--disable=all" to
|
||||||
|
# disable everything first and then reenable specific checks. For example, if
|
||||||
|
# you want to run only the similarities checker, you can use "--disable=all
|
||||||
|
# --enable=similarities". If you want to run only the classes checker, but have
|
||||||
|
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||||
|
# --disable=W".
|
||||||
|
disable=print-statement,
|
||||||
|
parameter-unpacking,
|
||||||
|
unpacking-in-except,
|
||||||
|
old-raise-syntax,
|
||||||
|
backtick,
|
||||||
|
long-suffix,
|
||||||
|
old-ne-operator,
|
||||||
|
old-octal-literal,
|
||||||
|
import-star-module-level,
|
||||||
|
non-ascii-bytes-literal,
|
||||||
|
raw-checker-failed,
|
||||||
|
bad-inline-option,
|
||||||
|
locally-disabled,
|
||||||
|
file-ignored,
|
||||||
|
suppressed-message,
|
||||||
|
useless-suppression,
|
||||||
|
deprecated-pragma,
|
||||||
|
use-symbolic-message-instead,
|
||||||
|
apply-builtin,
|
||||||
|
basestring-builtin,
|
||||||
|
buffer-builtin,
|
||||||
|
cmp-builtin,
|
||||||
|
coerce-builtin,
|
||||||
|
execfile-builtin,
|
||||||
|
file-builtin,
|
||||||
|
long-builtin,
|
||||||
|
raw_input-builtin,
|
||||||
|
reduce-builtin,
|
||||||
|
standarderror-builtin,
|
||||||
|
unicode-builtin,
|
||||||
|
xrange-builtin,
|
||||||
|
coerce-method,
|
||||||
|
delslice-method,
|
||||||
|
getslice-method,
|
||||||
|
setslice-method,
|
||||||
|
no-absolute-import,
|
||||||
|
old-division,
|
||||||
|
dict-iter-method,
|
||||||
|
dict-view-method,
|
||||||
|
next-method-called,
|
||||||
|
metaclass-assignment,
|
||||||
|
indexing-exception,
|
||||||
|
raising-string,
|
||||||
|
reload-builtin,
|
||||||
|
oct-method,
|
||||||
|
hex-method,
|
||||||
|
nonzero-method,
|
||||||
|
cmp-method,
|
||||||
|
input-builtin,
|
||||||
|
round-builtin,
|
||||||
|
intern-builtin,
|
||||||
|
unichr-builtin,
|
||||||
|
map-builtin-not-iterating,
|
||||||
|
zip-builtin-not-iterating,
|
||||||
|
range-builtin-not-iterating,
|
||||||
|
filter-builtin-not-iterating,
|
||||||
|
using-cmp-argument,
|
||||||
|
eq-without-hash,
|
||||||
|
div-method,
|
||||||
|
idiv-method,
|
||||||
|
rdiv-method,
|
||||||
|
exception-message-attribute,
|
||||||
|
invalid-str-codec,
|
||||||
|
sys-max-int,
|
||||||
|
bad-python3-import,
|
||||||
|
deprecated-string-function,
|
||||||
|
deprecated-str-translate-call,
|
||||||
|
deprecated-itertools-function,
|
||||||
|
deprecated-types-field,
|
||||||
|
next-method-defined,
|
||||||
|
dict-items-not-iterating,
|
||||||
|
dict-keys-not-iterating,
|
||||||
|
dict-values-not-iterating,
|
||||||
|
deprecated-operator-function,
|
||||||
|
deprecated-urllib-function,
|
||||||
|
xreadlines-attribute,
|
||||||
|
deprecated-sys-function,
|
||||||
|
exception-escape,
|
||||||
|
comprehension-escape
|
||||||
|
|
||||||
|
# Enable the message, report, category or checker with the given id(s). You can
|
||||||
|
# either give multiple identifier separated by comma (,) or put this option
|
||||||
|
# multiple time (only on the command line, not in the configuration file where
|
||||||
|
# it should appear only once). See also the "--disable" option for examples.
|
||||||
|
enable=c-extension-no-member
|
||||||
|
|
||||||
|
|
||||||
|
[REPORTS]
|
||||||
|
|
||||||
|
# Python expression which should return a note less than 10 (10 is the highest
|
||||||
|
# note). You have access to the variables errors warning, statement which
|
||||||
|
# respectively contain the number of errors / warnings messages and the total
|
||||||
|
# number of statements analyzed. This is used by the global evaluation report
|
||||||
|
# (RP0004).
|
||||||
|
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||||
|
|
||||||
|
# Template used to display messages. This is a python new-style format string
|
||||||
|
# used to format the message information. See doc for all details.
|
||||||
|
#msg-template=
|
||||||
|
|
||||||
|
# Set the output format. Available formats are text, parseable, colorized, json
|
||||||
|
# and msvs (visual studio). You can also give a reporter class, e.g.
|
||||||
|
# mypackage.mymodule.MyReporterClass.
|
||||||
|
output-format=text
|
||||||
|
|
||||||
|
# Tells whether to display a full report or only the messages.
|
||||||
|
reports=no
|
||||||
|
|
||||||
|
# Activate the evaluation score.
|
||||||
|
score=yes
|
||||||
|
|
||||||
|
|
||||||
|
[REFACTORING]
|
||||||
|
|
||||||
|
# Maximum number of nested blocks for function / method body
|
||||||
|
max-nested-blocks=5
|
||||||
|
|
||||||
|
# Complete name of functions that never returns. When checking for
|
||||||
|
# inconsistent-return-statements if a never returning function is called then
|
||||||
|
# it will be considered as an explicit return statement and no message will be
|
||||||
|
# printed.
|
||||||
|
never-returning-functions=sys.exit
|
||||||
|
|
||||||
|
|
||||||
|
[BASIC]
|
||||||
|
|
||||||
|
# Naming style matching correct argument names.
|
||||||
|
argument-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct argument names. Overrides argument-
|
||||||
|
# naming-style.
|
||||||
|
#argument-rgx=
|
||||||
|
|
||||||
|
# Naming style matching correct attribute names.
|
||||||
|
attr-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct attribute names. Overrides attr-naming-
|
||||||
|
# style.
|
||||||
|
#attr-rgx=
|
||||||
|
|
||||||
|
# Bad variable names which should always be refused, separated by a comma.
|
||||||
|
bad-names=foo,
|
||||||
|
bar,
|
||||||
|
baz,
|
||||||
|
toto,
|
||||||
|
tutu,
|
||||||
|
tata
|
||||||
|
|
||||||
|
# Naming style matching correct class attribute names.
|
||||||
|
class-attribute-naming-style=any
|
||||||
|
|
||||||
|
# Regular expression matching correct class attribute names. Overrides class-
|
||||||
|
# attribute-naming-style.
|
||||||
|
#class-attribute-rgx=
|
||||||
|
|
||||||
|
# Naming style matching correct class names.
|
||||||
|
class-naming-style=PascalCase
|
||||||
|
|
||||||
|
# Regular expression matching correct class names. Overrides class-naming-
|
||||||
|
# style.
|
||||||
|
#class-rgx=
|
||||||
|
|
||||||
|
# Naming style matching correct constant names.
|
||||||
|
const-naming-style=UPPER_CASE
|
||||||
|
|
||||||
|
# Regular expression matching correct constant names. Overrides const-naming-
|
||||||
|
# style.
|
||||||
|
#const-rgx=
|
||||||
|
|
||||||
|
# Minimum line length for functions/classes that require docstrings, shorter
|
||||||
|
# ones are exempt.
|
||||||
|
docstring-min-length=-1
|
||||||
|
|
||||||
|
# Naming style matching correct function names.
|
||||||
|
function-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct function names. Overrides function-
|
||||||
|
# naming-style.
|
||||||
|
#function-rgx=
|
||||||
|
|
||||||
|
# Good variable names which should always be accepted, separated by a comma.
|
||||||
|
good-names=i,
|
||||||
|
j,
|
||||||
|
k,
|
||||||
|
ex,
|
||||||
|
Run,
|
||||||
|
_
|
||||||
|
|
||||||
|
# Include a hint for the correct naming format with invalid-name.
|
||||||
|
include-naming-hint=no
|
||||||
|
|
||||||
|
# Naming style matching correct inline iteration names.
|
||||||
|
inlinevar-naming-style=any
|
||||||
|
|
||||||
|
# Regular expression matching correct inline iteration names. Overrides
|
||||||
|
# inlinevar-naming-style.
|
||||||
|
#inlinevar-rgx=
|
||||||
|
|
||||||
|
# Naming style matching correct method names.
|
||||||
|
method-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct method names. Overrides method-naming-
|
||||||
|
# style.
|
||||||
|
#method-rgx=
|
||||||
|
|
||||||
|
# Naming style matching correct module names.
|
||||||
|
module-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct module names. Overrides module-naming-
|
||||||
|
# style.
|
||||||
|
#module-rgx=
|
||||||
|
|
||||||
|
# Colon-delimited sets of names that determine each other's naming style when
|
||||||
|
# the name regexes allow several styles.
|
||||||
|
name-group=
|
||||||
|
|
||||||
|
# Regular expression which should only match function or class names that do
|
||||||
|
# not require a docstring.
|
||||||
|
no-docstring-rgx=^_
|
||||||
|
|
||||||
|
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||||
|
# to this list to register other decorators that produce valid properties.
|
||||||
|
# These decorators are taken in consideration only for invalid-name.
|
||||||
|
property-classes=abc.abstractproperty
|
||||||
|
|
||||||
|
# Naming style matching correct variable names.
|
||||||
|
variable-naming-style=snake_case
|
||||||
|
|
||||||
|
# Regular expression matching correct variable names. Overrides variable-
|
||||||
|
# naming-style.
|
||||||
|
#variable-rgx=
|
||||||
|
|
||||||
|
|
||||||
|
[LOGGING]
|
||||||
|
|
||||||
|
# Format style used to check logging format string. `old` means using %
|
||||||
|
# formatting, while `new` is for `{}` formatting.
|
||||||
|
logging-format-style=old
|
||||||
|
|
||||||
|
# Logging modules to check that the string format arguments are in logging
|
||||||
|
# function parameter format.
|
||||||
|
logging-modules=logging
|
||||||
|
|
||||||
|
|
||||||
|
[SIMILARITIES]
|
||||||
|
|
||||||
|
# Ignore comments when computing similarities.
|
||||||
|
ignore-comments=yes
|
||||||
|
|
||||||
|
# Ignore docstrings when computing similarities.
|
||||||
|
ignore-docstrings=yes
|
||||||
|
|
||||||
|
# Ignore imports when computing similarities.
|
||||||
|
ignore-imports=no
|
||||||
|
|
||||||
|
# Minimum lines number of a similarity.
|
||||||
|
min-similarity-lines=4
|
||||||
|
|
||||||
|
|
||||||
|
[SPELLING]
|
||||||
|
|
||||||
|
# Limits count of emitted suggestions for spelling mistakes.
|
||||||
|
max-spelling-suggestions=4
|
||||||
|
|
||||||
|
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||||
|
# install python-enchant package..
|
||||||
|
spelling-dict=
|
||||||
|
|
||||||
|
# List of comma separated words that should not be checked.
|
||||||
|
spelling-ignore-words=
|
||||||
|
|
||||||
|
# A path to a file that contains private dictionary; one word per line.
|
||||||
|
spelling-private-dict-file=
|
||||||
|
|
||||||
|
# Tells whether to store unknown words to indicated private dictionary in
|
||||||
|
# --spelling-private-dict-file option instead of raising a message.
|
||||||
|
spelling-store-unknown-words=no
|
||||||
|
|
||||||
|
|
||||||
|
[FORMAT]
|
||||||
|
|
||||||
|
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||||
|
expected-line-ending-format=
|
||||||
|
|
||||||
|
# Regexp for a line that is allowed to be longer than the limit.
|
||||||
|
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||||
|
|
||||||
|
# Number of spaces of indent required inside a hanging or continued line.
|
||||||
|
indent-after-paren=4
|
||||||
|
|
||||||
|
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||||
|
# tab).
|
||||||
|
indent-string=' '
|
||||||
|
|
||||||
|
# Maximum number of characters on a single line.
|
||||||
|
max-line-length=100
|
||||||
|
|
||||||
|
# Maximum number of lines in a module.
|
||||||
|
max-module-lines=1000
|
||||||
|
|
||||||
|
# List of optional constructs for which whitespace checking is disabled. `dict-
|
||||||
|
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
||||||
|
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
||||||
|
# `empty-line` allows space-only lines.
|
||||||
|
no-space-check=trailing-comma,
|
||||||
|
dict-separator
|
||||||
|
|
||||||
|
# Allow the body of a class to be on the same line as the declaration if body
|
||||||
|
# contains single statement.
|
||||||
|
single-line-class-stmt=no
|
||||||
|
|
||||||
|
# Allow the body of an if to be on the same line as the test if there is no
|
||||||
|
# else.
|
||||||
|
single-line-if-stmt=no
|
||||||
|
|
||||||
|
|
||||||
|
[STRING]
|
||||||
|
|
||||||
|
# This flag controls whether the implicit-str-concat-in-sequence should
|
||||||
|
# generate a warning on implicit string concatenation in sequences defined over
|
||||||
|
# several lines.
|
||||||
|
check-str-concat-over-line-jumps=no
|
||||||
|
|
||||||
|
|
||||||
|
[VARIABLES]
|
||||||
|
|
||||||
|
# List of additional names supposed to be defined in builtins. Remember that
|
||||||
|
# you should avoid defining new builtins when possible.
|
||||||
|
additional-builtins=
|
||||||
|
|
||||||
|
# Tells whether unused global variables should be treated as a violation.
|
||||||
|
allow-global-unused-variables=yes
|
||||||
|
|
||||||
|
# List of strings which can identify a callback function by name. A callback
|
||||||
|
# name must start or end with one of those strings.
|
||||||
|
callbacks=cb_,
|
||||||
|
_cb
|
||||||
|
|
||||||
|
# A regular expression matching the name of dummy variables (i.e. expected to
|
||||||
|
# not be used).
|
||||||
|
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
||||||
|
|
||||||
|
# Argument names that match this expression will be ignored. Default to name
|
||||||
|
# with leading underscore.
|
||||||
|
ignored-argument-names=_.*|^ignored_|^unused_
|
||||||
|
|
||||||
|
# Tells whether we should check for unused import in __init__ files.
|
||||||
|
init-import=no
|
||||||
|
|
||||||
|
# List of qualified module names which can have objects that can redefine
|
||||||
|
# builtins.
|
||||||
|
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
||||||
|
|
||||||
|
|
||||||
|
[TYPECHECK]
|
||||||
|
|
||||||
|
# List of decorators that produce context managers, such as
|
||||||
|
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||||
|
# produce valid context managers.
|
||||||
|
contextmanager-decorators=contextlib.contextmanager
|
||||||
|
|
||||||
|
# List of members which are set dynamically and missed by pylint inference
|
||||||
|
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||||
|
# expressions are accepted.
|
||||||
|
generated-members=
|
||||||
|
|
||||||
|
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||||
|
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||||
|
ignore-mixin-members=yes
|
||||||
|
|
||||||
|
# Tells whether to warn about missing members when the owner of the attribute
|
||||||
|
# is inferred to be None.
|
||||||
|
ignore-none=yes
|
||||||
|
|
||||||
|
# This flag controls whether pylint should warn about no-member and similar
|
||||||
|
# checks whenever an opaque object is returned when inferring. The inference
|
||||||
|
# can return multiple potential results while evaluating a Python object, but
|
||||||
|
# some branches might not be evaluated, which results in partial inference. In
|
||||||
|
# that case, it might be useful to still emit no-member and other checks for
|
||||||
|
# the rest of the inferred objects.
|
||||||
|
ignore-on-opaque-inference=yes
|
||||||
|
|
||||||
|
# List of class names for which member attributes should not be checked (useful
|
||||||
|
# for classes with dynamically set attributes). This supports the use of
|
||||||
|
# qualified names.
|
||||||
|
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||||
|
|
||||||
|
# List of module names for which member attributes should not be checked
|
||||||
|
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||||
|
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||||
|
# supports qualified module names, as well as Unix pattern matching.
|
||||||
|
ignored-modules=
|
||||||
|
|
||||||
|
# Show a hint with possible names when a member name was not found. The aspect
|
||||||
|
# of finding the hint is based on edit distance.
|
||||||
|
missing-member-hint=yes
|
||||||
|
|
||||||
|
# The minimum edit distance a name should have in order to be considered a
|
||||||
|
# similar match for a missing member name.
|
||||||
|
missing-member-hint-distance=1
|
||||||
|
|
||||||
|
# The total number of similar names that should be taken in consideration when
|
||||||
|
# showing a hint for a missing member.
|
||||||
|
missing-member-max-choices=1
|
||||||
|
|
||||||
|
|
||||||
|
[MISCELLANEOUS]
|
||||||
|
|
||||||
|
# List of note tags to take in consideration, separated by a comma.
|
||||||
|
notes=FIXME,
|
||||||
|
XXX,
|
||||||
|
TODO
|
||||||
|
|
||||||
|
|
||||||
|
[DESIGN]
|
||||||
|
|
||||||
|
# Maximum number of arguments for function / method.
|
||||||
|
max-args=5
|
||||||
|
|
||||||
|
# Maximum number of attributes for a class (see R0902).
|
||||||
|
max-attributes=7
|
||||||
|
|
||||||
|
# Maximum number of boolean expressions in an if statement.
|
||||||
|
max-bool-expr=5
|
||||||
|
|
||||||
|
# Maximum number of branch for function / method body.
|
||||||
|
max-branches=12
|
||||||
|
|
||||||
|
# Maximum number of locals for function / method body.
|
||||||
|
max-locals=15
|
||||||
|
|
||||||
|
# Maximum number of parents for a class (see R0901).
|
||||||
|
max-parents=7
|
||||||
|
|
||||||
|
# Maximum number of public methods for a class (see R0904).
|
||||||
|
max-public-methods=20
|
||||||
|
|
||||||
|
# Maximum number of return / yield for function / method body.
|
||||||
|
max-returns=6
|
||||||
|
|
||||||
|
# Maximum number of statements in function / method body.
|
||||||
|
max-statements=50
|
||||||
|
|
||||||
|
# Minimum number of public methods for a class (see R0903).
|
||||||
|
min-public-methods=2
|
||||||
|
|
||||||
|
|
||||||
|
[CLASSES]
|
||||||
|
|
||||||
|
# List of method names used to declare (i.e. assign) instance attributes.
|
||||||
|
defining-attr-methods=__init__,
|
||||||
|
__new__,
|
||||||
|
setUp
|
||||||
|
|
||||||
|
# List of member names, which should be excluded from the protected access
|
||||||
|
# warning.
|
||||||
|
exclude-protected=_asdict,
|
||||||
|
_fields,
|
||||||
|
_replace,
|
||||||
|
_source,
|
||||||
|
_make
|
||||||
|
|
||||||
|
# List of valid names for the first argument in a class method.
|
||||||
|
valid-classmethod-first-arg=cls
|
||||||
|
|
||||||
|
# List of valid names for the first argument in a metaclass class method.
|
||||||
|
valid-metaclass-classmethod-first-arg=cls
|
||||||
|
|
||||||
|
|
||||||
|
[IMPORTS]
|
||||||
|
|
||||||
|
# Allow wildcard imports from modules that define __all__.
|
||||||
|
allow-wildcard-with-all=no
|
||||||
|
|
||||||
|
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||||
|
# 3 compatible code, which means that the block might have code that exists
|
||||||
|
# only in one or another interpreter, leading to false positives when analysed.
|
||||||
|
analyse-fallback-blocks=no
|
||||||
|
|
||||||
|
# Deprecated modules which should not be used, separated by a comma.
|
||||||
|
deprecated-modules=optparse,tkinter.tix
|
||||||
|
|
||||||
|
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||||
|
# not be disabled).
|
||||||
|
ext-import-graph=
|
||||||
|
|
||||||
|
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||||
|
# given file (report RP0402 must not be disabled).
|
||||||
|
import-graph=
|
||||||
|
|
||||||
|
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||||
|
# not be disabled).
|
||||||
|
int-import-graph=
|
||||||
|
|
||||||
|
# Force import order to recognize a module as part of the standard
|
||||||
|
# compatibility libraries.
|
||||||
|
known-standard-library=
|
||||||
|
|
||||||
|
# Force import order to recognize a module as part of a third party library.
|
||||||
|
known-third-party=enchant
|
||||||
|
|
||||||
|
|
||||||
|
[EXCEPTIONS]
|
||||||
|
|
||||||
|
# Exceptions that will emit a warning when being caught. Defaults to
|
||||||
|
# "BaseException, Exception".
|
||||||
|
overgeneral-exceptions=BaseException,
|
||||||
|
Exception
|
7
docs/full/lib.image.rst
Normal file
7
docs/full/lib.image.rst
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
lib.image module
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. automodule:: lib.image
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
|
@ -8,6 +8,8 @@ Subpackages
|
||||||
|
|
||||||
lib.model
|
lib.model
|
||||||
lib.faces_detect
|
lib.faces_detect
|
||||||
|
lib.image
|
||||||
|
lib.training_data
|
||||||
|
|
||||||
Module contents
|
Module contents
|
||||||
---------------
|
---------------
|
||||||
|
|
7
docs/full/lib.training_data.rst
Normal file
7
docs/full/lib.training_data.rst
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
lib.training\_data module
|
||||||
|
=========================
|
||||||
|
|
||||||
|
.. automodule:: lib.training_data
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
|
@ -7,7 +7,7 @@ faceswap.dev Developer Documentation
|
||||||
====================================
|
====================================
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 4
|
:maxdepth: 2
|
||||||
:caption: Contents:
|
:caption: Contents:
|
||||||
|
|
||||||
full/modules
|
full/modules
|
||||||
|
|
|
@ -8,8 +8,8 @@ from datetime import datetime
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
from lib.faces_detect import rotate_landmarks
|
||||||
from lib import Serializer
|
from lib import Serializer
|
||||||
from lib.utils import rotate_landmarks
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from lib.vgg_face import VGGFace
|
from lib.vgg_face import VGGFace
|
||||||
from lib.utils import cv2_read_img
|
from lib.image import read_image
|
||||||
from plugins.extract.pipeline import Extractor
|
from plugins.extract.pipeline import Extractor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
@ -47,10 +47,10 @@ class FaceFilter():
|
||||||
""" Load the images """
|
""" Load the images """
|
||||||
retval = dict()
|
retval = dict()
|
||||||
for fpath in reference_file_paths:
|
for fpath in reference_file_paths:
|
||||||
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
|
retval[fpath] = {"image": read_image(fpath, raise_error=True),
|
||||||
"type": "filter"}
|
"type": "filter"}
|
||||||
for fpath in nreference_file_paths:
|
for fpath in nreference_file_paths:
|
||||||
retval[fpath] = {"image": cv2_read_img(fpath, raise_error=True),
|
retval[fpath] = {"image": read_image(fpath, raise_error=True),
|
||||||
"type": "nfilter"}
|
"type": "nfilter"}
|
||||||
logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()})
|
logger.debug("Loaded filter images: %s", {k: v["type"] for k, v in retval.items()})
|
||||||
return retval
|
return retval
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
""" Face and landmarks detection for faceswap.py """
|
""" Face and landmarks detection for faceswap.py """
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling
|
from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling
|
||||||
|
@ -399,3 +400,89 @@ class DetectedFace():
|
||||||
if not self.reference:
|
if not self.reference:
|
||||||
return None
|
return None
|
||||||
return get_matrix_scaling(self.reference_matrix)
|
return get_matrix_scaling(self.reference_matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_landmarks(face, rotation_matrix):
|
||||||
|
""" Rotates the 68 point landmarks and detection bounding box around the given rotation matrix.
|
||||||
|
|
||||||
|
Paramaters
|
||||||
|
----------
|
||||||
|
face: DetectedFace or dict
|
||||||
|
A :class:`DetectedFace` or an `alignments file` ``dict`` containing the 68 point landmarks
|
||||||
|
and the `x`, `w`, `y`, `h` detection bounding box points.
|
||||||
|
rotation_matrix: numpy.ndarray
|
||||||
|
The rotation matrix to rotate the given object by.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
DetectedFace or dict
|
||||||
|
The rotated :class:`DetectedFace` or `alignments file` ``dict`` with the landmarks and
|
||||||
|
detection bounding box points rotated by the given matrix. The return type is the same as
|
||||||
|
the input type for ``face``
|
||||||
|
"""
|
||||||
|
logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s",
|
||||||
|
rotation_matrix, type(face))
|
||||||
|
rotated_landmarks = None
|
||||||
|
# Detected Face Object
|
||||||
|
if isinstance(face, DetectedFace):
|
||||||
|
bounding_box = [[face.x, face.y],
|
||||||
|
[face.x + face.w, face.y],
|
||||||
|
[face.x + face.w, face.y + face.h],
|
||||||
|
[face.x, face.y + face.h]]
|
||||||
|
landmarks = face.landmarks_xy
|
||||||
|
|
||||||
|
# Alignments Dict
|
||||||
|
elif isinstance(face, dict) and "x" in face:
|
||||||
|
bounding_box = [[face.get("x", 0), face.get("y", 0)],
|
||||||
|
[face.get("x", 0) + face.get("w", 0),
|
||||||
|
face.get("y", 0)],
|
||||||
|
[face.get("x", 0) + face.get("w", 0),
|
||||||
|
face.get("y", 0) + face.get("h", 0)],
|
||||||
|
[face.get("x", 0),
|
||||||
|
face.get("y", 0) + face.get("h", 0)]]
|
||||||
|
landmarks = face.get("landmarks_xy", list())
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported face type")
|
||||||
|
|
||||||
|
logger.trace("Original landmarks: %s", landmarks)
|
||||||
|
|
||||||
|
rotation_matrix = cv2.invertAffineTransform(
|
||||||
|
rotation_matrix)
|
||||||
|
rotated = list()
|
||||||
|
for item in (bounding_box, landmarks):
|
||||||
|
if not item:
|
||||||
|
continue
|
||||||
|
points = np.array(item, np.int32)
|
||||||
|
points = np.expand_dims(points, axis=0)
|
||||||
|
transformed = cv2.transform(points,
|
||||||
|
rotation_matrix).astype(np.int32)
|
||||||
|
rotated.append(transformed.squeeze())
|
||||||
|
|
||||||
|
# Bounding box should follow x, y planes, so get min/max
|
||||||
|
# for non-90 degree rotations
|
||||||
|
pt_x = min([pnt[0] for pnt in rotated[0]])
|
||||||
|
pt_y = min([pnt[1] for pnt in rotated[0]])
|
||||||
|
pt_x1 = max([pnt[0] for pnt in rotated[0]])
|
||||||
|
pt_y1 = max([pnt[1] for pnt in rotated[0]])
|
||||||
|
width = pt_x1 - pt_x
|
||||||
|
height = pt_y1 - pt_y
|
||||||
|
|
||||||
|
if isinstance(face, DetectedFace):
|
||||||
|
face.x = int(pt_x)
|
||||||
|
face.y = int(pt_y)
|
||||||
|
face.w = int(width)
|
||||||
|
face.h = int(height)
|
||||||
|
face.r = 0
|
||||||
|
if len(rotated) > 1:
|
||||||
|
rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
|
||||||
|
face.landmarks_xy = rotated_landmarks
|
||||||
|
else:
|
||||||
|
face["left"] = int(pt_x)
|
||||||
|
face["top"] = int(pt_y)
|
||||||
|
face["right"] = int(pt_x1)
|
||||||
|
face["bottom"] = int(pt_y1)
|
||||||
|
rotated_landmarks = face
|
||||||
|
|
||||||
|
logger.trace("Rotated landmarks: %s", rotated_landmarks)
|
||||||
|
return face
|
||||||
|
|
302
lib/image.py
Normal file
302
lib/image.py
Normal file
|
@ -0,0 +1,302 @@
|
||||||
|
#!/usr/bin python3
|
||||||
|
""" Utilities for working with images and videos """
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from concurrent import futures
|
||||||
|
from hashlib import sha1
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import imageio_ffmpeg as im_ffm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lib.utils import convert_to_secs, FaceswapError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||||
|
|
||||||
|
# ################### #
|
||||||
|
# <<< IMAGE UTILS >>> #
|
||||||
|
# ################### #
|
||||||
|
|
||||||
|
|
||||||
|
# <<< IMAGE IO >>> #
|
||||||
|
|
||||||
|
def read_image(filename, raise_error=False):
|
||||||
|
""" Read an image file from a file location.
|
||||||
|
|
||||||
|
Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually
|
||||||
|
loaded. Errors can be logged and ignored so that the process can continue on an image load
|
||||||
|
failure.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename: str
|
||||||
|
Full path to the image to be loaded.
|
||||||
|
raise_error: bool, optional
|
||||||
|
If ``True``, then any failures (including the returned image being ``None``) will be
|
||||||
|
raised. If ``False`` then an error message will be logged, but the error will not be
|
||||||
|
raised. Default: ``False``
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
numpy.ndarray
|
||||||
|
The image in `BGR` channel order.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> image_file = "/path/to/image.png"
|
||||||
|
>>> try:
|
||||||
|
>>> image = read_image(image_file, raise_error=True)
|
||||||
|
>>> except:
|
||||||
|
>>> raise ValueError("There was an error")
|
||||||
|
"""
|
||||||
|
logger.trace("Requested image: '%s'", filename)
|
||||||
|
success = True
|
||||||
|
image = None
|
||||||
|
try:
|
||||||
|
image = cv2.imread(filename)
|
||||||
|
if image is None:
|
||||||
|
raise ValueError
|
||||||
|
except TypeError:
|
||||||
|
success = False
|
||||||
|
msg = "Error while reading image (TypeError): '{}'".format(filename)
|
||||||
|
logger.error(msg)
|
||||||
|
if raise_error:
|
||||||
|
raise Exception(msg)
|
||||||
|
except ValueError:
|
||||||
|
success = False
|
||||||
|
msg = ("Error while reading image. This is most likely caused by special characters in "
|
||||||
|
"the filename: '{}'".format(filename))
|
||||||
|
logger.error(msg)
|
||||||
|
if raise_error:
|
||||||
|
raise Exception(msg)
|
||||||
|
except Exception as err: # pylint:disable=broad-except
|
||||||
|
success = False
|
||||||
|
msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
|
||||||
|
logger.error(msg)
|
||||||
|
if raise_error:
|
||||||
|
raise Exception(msg)
|
||||||
|
logger.trace("Loaded image: '%s'. Success: %s", filename, success)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_batch(filenames):
|
||||||
|
""" Load a batch of images from the given file locations.
|
||||||
|
|
||||||
|
Leverages multi-threading to load multiple images from disk at the same time
|
||||||
|
leading to vastly reduced image read times.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filenames: list
|
||||||
|
A list of ``str`` full paths to the images to be loaded.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
numpy.ndarray
|
||||||
|
The batch of images in `BGR` channel order.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
As the images are compiled into a batch, they must be all of the same dimensions.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
|
||||||
|
>>> images = read_image_batch(image_filenames)
|
||||||
|
"""
|
||||||
|
logger.trace("Requested batch: '%s'", filenames)
|
||||||
|
executor = futures.ThreadPoolExecutor()
|
||||||
|
with executor:
|
||||||
|
images = [executor.submit(read_image, filename, raise_error=True)
|
||||||
|
for filename in filenames]
|
||||||
|
batch = np.array([future.result() for future in futures.as_completed(images)])
|
||||||
|
logger.trace("Returning images: %s", batch.shape)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_hash(filename):
|
||||||
|
""" Return the `sha1` hash of an image saved on disk.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename: str
|
||||||
|
Full path to the image to be loaded.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the given image.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> image_file = "/path/to/image.png"
|
||||||
|
>>> image_hash = read_image_hash(image_file)
|
||||||
|
"""
|
||||||
|
img = read_image(filename, raise_error=True)
|
||||||
|
image_hash = sha1(img).hexdigest()
|
||||||
|
logger.trace("filename: '%s', hash: %s", filename, image_hash)
|
||||||
|
return image_hash
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image_with_hash(image, extension):
|
||||||
|
""" Encode an image, and get the encoded image back with its `sha1` hash.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image: numpy.ndarray
|
||||||
|
The image to be encoded in `BGR` channel order.
|
||||||
|
extension: str
|
||||||
|
A compatible `cv2` image file extension that the final image is to be saved to.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
image_hash: str
|
||||||
|
The :func:`hashlib.hexdigest()` representation of the `sha1` hash of the encoded image
|
||||||
|
encoded_image: bytes
|
||||||
|
The image encoded into the correct file format
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> image_file = "/path/to/image.png"
|
||||||
|
>>> image = read_image(image_file)
|
||||||
|
>>> image_hash, encoded_image = encode_image_with_hash(image, ".jpg")
|
||||||
|
"""
|
||||||
|
encoded_image = cv2.imencode(extension, image)[1]
|
||||||
|
image_hash = sha1(cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)).hexdigest()
|
||||||
|
return image_hash, encoded_image
|
||||||
|
|
||||||
|
|
||||||
|
def batch_convert_color(batch, colorspace):
|
||||||
|
""" Convert a batch of images from one color space to another.
|
||||||
|
|
||||||
|
Converts a batch of images by reshaping the batch prior to conversion rather than iterating
|
||||||
|
over the images. This leads to a significant speed up in the convert process.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch: numpy.ndarray
|
||||||
|
A batch of images.
|
||||||
|
colorspace: str
|
||||||
|
The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be
|
||||||
|
``'BGR2LAB'``.
|
||||||
|
See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full
|
||||||
|
list of color codes.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
numpy.ndarray
|
||||||
|
The batch converted to the requested color space.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> images_bgr = numpy.array([image1, image2, image3])
|
||||||
|
>>> images_lab = batch_convert_color(images_bgr, "BGR2LAB")
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
This function is only compatible for color space conversions that have the same image shape
|
||||||
|
for source and destination color spaces.
|
||||||
|
|
||||||
|
If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some
|
||||||
|
information lost. For many cases, this will not be noticeable but it is recommended
|
||||||
|
to use 32-bit images in cases that need the full range of colors or that convert an image
|
||||||
|
before an operation and then convert back.
|
||||||
|
"""
|
||||||
|
logger.trace("Batch converting: (batch shape: %s, colorspace: %s)", batch.shape, colorspace)
|
||||||
|
original_shape = batch.shape
|
||||||
|
batch = batch.reshape((original_shape[0] * original_shape[1], *original_shape[2:]))
|
||||||
|
batch = cv2.cvtColor(batch, getattr(cv2, "COLOR_{}".format(colorspace)))
|
||||||
|
return batch.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# ################### #
|
||||||
|
# <<< VIDEO UTILS >>> #
|
||||||
|
# ################### #
|
||||||
|
|
||||||
|
def count_frames_and_secs(filename, timeout=60):
|
||||||
|
""" Count the number of frames and seconds in a video file.
|
||||||
|
|
||||||
|
Adapted From :mod:`ffmpeg_imageio` to handle the issue of ffmpeg occasionally hanging
|
||||||
|
inside a subprocess.
|
||||||
|
|
||||||
|
If the operation times out then the process will try to read the data again, up to a total
|
||||||
|
of 3 times. If the data still cannot be read then an exception will be raised.
|
||||||
|
|
||||||
|
Note that this operation can be quite slow for large files.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename: str
|
||||||
|
Full path to the video to be analyzed.
|
||||||
|
timeout: str, optional
|
||||||
|
The amount of time in seconds to wait for the video data before aborting.
|
||||||
|
Default: ``60``
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
nframes: int
|
||||||
|
The number of frames in the given video file.
|
||||||
|
nsecs: float
|
||||||
|
The duration, in seconds, of the given video file.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>>> video = "/path/to/video.mp4"
|
||||||
|
>>> frames, secs = count_frames_and_secs(video)
|
||||||
|
"""
|
||||||
|
# https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg
|
||||||
|
|
||||||
|
assert isinstance(filename, str), "Video path must be a string"
|
||||||
|
exe = im_ffm.get_ffmpeg_exe()
|
||||||
|
iswin = sys.platform.startswith("win")
|
||||||
|
logger.debug("iswin: '%s'", iswin)
|
||||||
|
cmd = [exe, "-i", filename, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
|
||||||
|
logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
|
||||||
|
attempts = 3
|
||||||
|
for attempt in range(attempts):
|
||||||
|
try:
|
||||||
|
logger.debug("attempt: %s of %s", attempt + 1, attempts)
|
||||||
|
out = subprocess.check_output(cmd,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
shell=iswin,
|
||||||
|
timeout=timeout)
|
||||||
|
logger.debug("Succesfully communicated with FFMPEG")
|
||||||
|
break
|
||||||
|
except subprocess.CalledProcessError as err:
|
||||||
|
out = err.output.decode(errors="ignore")
|
||||||
|
raise RuntimeError("FFMEG call failed with {}:\n{}".format(err.returncode, out))
|
||||||
|
except subprocess.TimeoutExpired as err:
|
||||||
|
this_attempt = attempt + 1
|
||||||
|
if this_attempt == attempts:
|
||||||
|
msg = ("FFMPEG hung while attempting to obtain the frame count. "
|
||||||
|
"Sometimes this issue resolves itself, so you can try running again. "
|
||||||
|
"Otherwise use the Effmpeg Tool to extract the frames from your video into "
|
||||||
|
"a folder, and then run the requested Faceswap process on that folder.")
|
||||||
|
raise FaceswapError(msg) from err
|
||||||
|
logger.warning("FFMPEG hung while attempting to obtain the frame count. "
|
||||||
|
"Retrying %s of %s", this_attempt + 1, attempts)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note that other than with the subprocess calls below, ffmpeg wont hang here.
|
||||||
|
# Worst case Python will stop/crash and ffmpeg will continue running until done.
|
||||||
|
|
||||||
|
nframes = nsecs = None
|
||||||
|
for line in reversed(out.splitlines()):
|
||||||
|
if not line.startswith(b"frame="):
|
||||||
|
continue
|
||||||
|
line = line.decode(errors="ignore")
|
||||||
|
logger.debug("frame line: '%s'", line)
|
||||||
|
idx = line.find("frame=")
|
||||||
|
if idx >= 0:
|
||||||
|
splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
|
||||||
|
nframes = int(splitframes)
|
||||||
|
idx = line.find("time=")
|
||||||
|
if idx >= 0:
|
||||||
|
splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
|
||||||
|
nsecs = convert_to_secs(*splittime.split(":"))
|
||||||
|
logger.debug("nframes: %s, nsecs: %s", nframes, nsecs)
|
||||||
|
return nframes, nsecs
|
||||||
|
|
||||||
|
raise RuntimeError("Could not get number of frames") # pragma: no cover
|
|
@ -43,6 +43,8 @@ class Mask():
|
||||||
self.__class__.__name__, face.shape, channels, landmarks)
|
self.__class__.__name__, face.shape, channels, landmarks)
|
||||||
self.landmarks = landmarks
|
self.landmarks = landmarks
|
||||||
self.face = face
|
self.face = face
|
||||||
|
self.dtype = face.dtype
|
||||||
|
self.threshold = 255 if self.dtype == "uint8" else 255.0
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
|
||||||
mask = self.build_mask()
|
mask = self.build_mask()
|
||||||
|
@ -73,7 +75,7 @@ class Mask():
|
||||||
class dfl_full(Mask): # pylint: disable=invalid-name
|
class dfl_full(Mask): # pylint: disable=invalid-name
|
||||||
""" DFL facial mask """
|
""" DFL facial mask """
|
||||||
def build_mask(self):
|
def build_mask(self):
|
||||||
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
|
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
|
||||||
|
|
||||||
nose_ridge = (self.landmarks[27:31], self.landmarks[33:34])
|
nose_ridge = (self.landmarks[27:31], self.landmarks[33:34])
|
||||||
jaw = (self.landmarks[0:17],
|
jaw = (self.landmarks[0:17],
|
||||||
|
@ -90,14 +92,14 @@ class dfl_full(Mask): # pylint: disable=invalid-name
|
||||||
|
|
||||||
for item in parts:
|
for item in parts:
|
||||||
merged = np.concatenate(item)
|
merged = np.concatenate(item)
|
||||||
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
|
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class components(Mask): # pylint: disable=invalid-name
|
class components(Mask): # pylint: disable=invalid-name
|
||||||
""" Component model mask """
|
""" Component model mask """
|
||||||
def build_mask(self):
|
def build_mask(self):
|
||||||
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
|
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
|
||||||
|
|
||||||
r_jaw = (self.landmarks[0:9], self.landmarks[17:18])
|
r_jaw = (self.landmarks[0:9], self.landmarks[17:18])
|
||||||
l_jaw = (self.landmarks[8:17], self.landmarks[26:27])
|
l_jaw = (self.landmarks[8:17], self.landmarks[26:27])
|
||||||
|
@ -117,7 +119,7 @@ class components(Mask): # pylint: disable=invalid-name
|
||||||
|
|
||||||
for item in parts:
|
for item in parts:
|
||||||
merged = np.concatenate(item)
|
merged = np.concatenate(item)
|
||||||
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
|
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -126,7 +128,7 @@ class extended(Mask): # pylint: disable=invalid-name
|
||||||
Based on components mask. Attempts to extend the eyebrow points up the forehead
|
Based on components mask. Attempts to extend the eyebrow points up the forehead
|
||||||
"""
|
"""
|
||||||
def build_mask(self):
|
def build_mask(self):
|
||||||
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
|
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
|
||||||
|
|
||||||
landmarks = self.landmarks.copy()
|
landmarks = self.landmarks.copy()
|
||||||
# mid points between the side of face and eye point
|
# mid points between the side of face and eye point
|
||||||
|
@ -161,15 +163,15 @@ class extended(Mask): # pylint: disable=invalid-name
|
||||||
|
|
||||||
for item in parts:
|
for item in parts:
|
||||||
merged = np.concatenate(item)
|
merged = np.concatenate(item)
|
||||||
cv2.fillConvexPoly(mask, cv2.convexHull(merged), 255.) # pylint: disable=no-member
|
cv2.fillConvexPoly(mask, cv2.convexHull(merged), self.threshold)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class facehull(Mask): # pylint: disable=invalid-name
|
class facehull(Mask): # pylint: disable=invalid-name
|
||||||
""" Basic face hull mask """
|
""" Basic face hull mask """
|
||||||
def build_mask(self):
|
def build_mask(self):
|
||||||
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=np.float32)
|
mask = np.zeros(self.face.shape[0:2] + (1, ), dtype=self.dtype)
|
||||||
hull = cv2.convexHull( # pylint: disable=no-member
|
hull = cv2.convexHull(
|
||||||
np.array(self.landmarks).reshape((-1, 2)))
|
np.array(self.landmarks).reshape((-1, 2)))
|
||||||
cv2.fillConvexPoly(mask, hull, 255.0, lineType=cv2.LINE_AA) # pylint: disable=no-member
|
cv2.fillConvexPoly(mask, hull, self.threshold, lineType=cv2.LINE_AA)
|
||||||
return mask
|
return mask
|
||||||
|
|
File diff suppressed because it is too large
Load diff
242
lib/utils.py
242
lib/utils.py
|
@ -4,27 +4,18 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
import zipfile
|
import zipfile
|
||||||
from hashlib import sha1
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from re import finditer
|
from re import finditer
|
||||||
from multiprocessing import current_process
|
from multiprocessing import current_process
|
||||||
from socket import timeout as socket_timeout, error as socket_error
|
from socket import timeout as socket_timeout, error as socket_error
|
||||||
|
|
||||||
import imageio_ffmpeg as im_ffm
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
|
|
||||||
from lib.faces_detect import DetectedFace
|
|
||||||
|
|
||||||
|
|
||||||
# Global variables
|
# Global variables
|
||||||
_image_extensions = [ # pylint:disable=invalid-name
|
_image_extensions = [ # pylint:disable=invalid-name
|
||||||
".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
|
".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
|
||||||
|
@ -132,6 +123,22 @@ def get_image_paths(directory):
|
||||||
return dir_contents
|
return dir_contents
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_secs(*args):
|
||||||
|
""" converts a time to second. Either convert_to_secs(min, secs) or
|
||||||
|
convert_to_secs(hours, mins, secs). """
|
||||||
|
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||||
|
logger.debug("from time: %s", args)
|
||||||
|
retval = 0.0
|
||||||
|
if len(args) == 1:
|
||||||
|
retval = float(args[0])
|
||||||
|
elif len(args) == 2:
|
||||||
|
retval = 60 * float(args[0]) + float(args[1])
|
||||||
|
elif len(args) == 3:
|
||||||
|
retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
|
||||||
|
logger.debug("to secs: %s", retval)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
|
||||||
def full_path_split(path):
|
def full_path_split(path):
|
||||||
""" Split a given path into all of it's separate components """
|
""" Split a given path into all of it's separate components """
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||||
|
@ -151,147 +158,6 @@ def full_path_split(path):
|
||||||
return allparts
|
return allparts
|
||||||
|
|
||||||
|
|
||||||
def cv2_read_img(filename, raise_error=False):
|
|
||||||
""" Read an image with cv2 and check that an image was actually loaded.
|
|
||||||
Logs an error if the image returned is None. or an error has occured.
|
|
||||||
|
|
||||||
Pass raise_error=True if error should be raised """
|
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|
||||||
logger.trace("Requested image: '%s'", filename)
|
|
||||||
success = True
|
|
||||||
image = None
|
|
||||||
try:
|
|
||||||
image = cv2.imread(filename) # pylint:disable=no-member,c-extension-no-member
|
|
||||||
if image is None:
|
|
||||||
raise ValueError
|
|
||||||
except TypeError:
|
|
||||||
success = False
|
|
||||||
msg = "Error while reading image (TypeError): '{}'".format(filename)
|
|
||||||
logger.error(msg)
|
|
||||||
if raise_error:
|
|
||||||
raise Exception(msg)
|
|
||||||
except ValueError:
|
|
||||||
success = False
|
|
||||||
msg = ("Error while reading image. This is most likely caused by special characters in "
|
|
||||||
"the filename: '{}'".format(filename))
|
|
||||||
logger.error(msg)
|
|
||||||
if raise_error:
|
|
||||||
raise Exception(msg)
|
|
||||||
except Exception as err: # pylint:disable=broad-except
|
|
||||||
success = False
|
|
||||||
msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
|
|
||||||
logger.error(msg)
|
|
||||||
if raise_error:
|
|
||||||
raise Exception(msg)
|
|
||||||
logger.trace("Loaded image: '%s'. Success: %s", filename, success)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def hash_image_file(filename):
|
|
||||||
""" Return an image file's sha1 hash """
|
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|
||||||
img = cv2_read_img(filename, raise_error=True)
|
|
||||||
img_hash = sha1(img).hexdigest()
|
|
||||||
logger.trace("filename: '%s', hash: %s", filename, img_hash)
|
|
||||||
return img_hash
|
|
||||||
|
|
||||||
|
|
||||||
def hash_encode_image(image, extension):
|
|
||||||
""" Encode the image, get the hash and return the hash with
|
|
||||||
encoded image """
|
|
||||||
img = cv2.imencode(extension, image)[1] # pylint:disable=no-member,c-extension-no-member
|
|
||||||
f_hash = sha1(
|
|
||||||
cv2.imdecode( # pylint:disable=no-member,c-extension-no-member
|
|
||||||
img,
|
|
||||||
cv2.IMREAD_UNCHANGED)).hexdigest() # pylint:disable=no-member,c-extension-no-member
|
|
||||||
return f_hash, img
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_secs(*args):
|
|
||||||
""" converts a time to second. Either convert_to_secs(min, secs) or
|
|
||||||
convert_to_secs(hours, mins, secs). """
|
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|
||||||
logger.debug("from time: %s", args)
|
|
||||||
retval = 0.0
|
|
||||||
if len(args) == 1:
|
|
||||||
retval = float(args[0])
|
|
||||||
elif len(args) == 2:
|
|
||||||
retval = 60 * float(args[0]) + float(args[1])
|
|
||||||
elif len(args) == 3:
|
|
||||||
retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
|
|
||||||
logger.debug("to secs: %s", retval)
|
|
||||||
return retval
|
|
||||||
|
|
||||||
|
|
||||||
def count_frames_and_secs(path, timeout=60):
|
|
||||||
"""
|
|
||||||
Adapted From ffmpeg_imageio, to handle occasional hanging issue:
|
|
||||||
https://github.com/imageio/imageio-ffmpeg
|
|
||||||
|
|
||||||
Get the number of frames and number of seconds for the given video
|
|
||||||
file. Note that this operation can be quite slow for large files.
|
|
||||||
|
|
||||||
Disclaimer: I've seen this produce different results from actually reading
|
|
||||||
the frames with older versions of ffmpeg (2.x). Therefore I cannot say
|
|
||||||
with 100% certainty that the returned values are always exact.
|
|
||||||
"""
|
|
||||||
# https://stackoverflow.com/questions/2017843/fetch-frame-count-with-ffmpeg
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|
||||||
assert isinstance(path, str), "Video path must be a string"
|
|
||||||
exe = im_ffm.get_ffmpeg_exe()
|
|
||||||
iswin = sys.platform.startswith("win")
|
|
||||||
logger.debug("iswin: '%s'", iswin)
|
|
||||||
cmd = [exe, "-i", path, "-map", "0:v:0", "-c", "copy", "-f", "null", "-"]
|
|
||||||
logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
|
|
||||||
attempts = 3
|
|
||||||
for attempt in range(attempts):
|
|
||||||
try:
|
|
||||||
logger.debug("attempt: %s of %s", attempt + 1, attempts)
|
|
||||||
out = subprocess.check_output(cmd,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
shell=iswin,
|
|
||||||
timeout=timeout)
|
|
||||||
logger.debug("Succesfully communicated with FFMPEG")
|
|
||||||
break
|
|
||||||
except subprocess.CalledProcessError as err:
|
|
||||||
out = err.output.decode(errors="ignore")
|
|
||||||
raise RuntimeError("FFMEG call failed with {}:\n{}".format(err.returncode, out))
|
|
||||||
except subprocess.TimeoutExpired as err:
|
|
||||||
this_attempt = attempt + 1
|
|
||||||
if this_attempt == attempts:
|
|
||||||
msg = ("FFMPEG hung while attempting to obtain the frame count. "
|
|
||||||
"Sometimes this issue resolves itself, so you can try running again. "
|
|
||||||
"Otherwise use the Effmpeg Tool to extract the frames from your video into "
|
|
||||||
"a folder, and then run the requested Faceswap process on that folder.")
|
|
||||||
raise FaceswapError(msg) from err
|
|
||||||
logger.warning("FFMPEG hung while attempting to obtain the frame count. "
|
|
||||||
"Retrying %s of %s", this_attempt + 1, attempts)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Note that other than with the subprocess calls below, ffmpeg wont hang here.
|
|
||||||
# Worst case Python will stop/crash and ffmpeg will continue running until done.
|
|
||||||
|
|
||||||
nframes = nsecs = None
|
|
||||||
for line in reversed(out.splitlines()):
|
|
||||||
if not line.startswith(b"frame="):
|
|
||||||
continue
|
|
||||||
line = line.decode(errors="ignore")
|
|
||||||
logger.debug("frame line: '%s'", line)
|
|
||||||
idx = line.find("frame=")
|
|
||||||
if idx >= 0:
|
|
||||||
splitframes = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
|
|
||||||
nframes = int(splitframes)
|
|
||||||
idx = line.find("time=")
|
|
||||||
if idx >= 0:
|
|
||||||
splittime = line[idx:].split("=", 1)[-1].lstrip().split(" ", 1)[0].strip()
|
|
||||||
nsecs = convert_to_secs(*splittime.split(":"))
|
|
||||||
logger.debug("nframes: %s, nsecs: %s", nframes, nsecs)
|
|
||||||
return nframes, nsecs
|
|
||||||
|
|
||||||
raise RuntimeError("Could not get number of frames") # pragma: no cover
|
|
||||||
|
|
||||||
|
|
||||||
def backup_file(directory, filename):
|
def backup_file(directory, filename):
|
||||||
""" Backup a given file by appending .bk to the end """
|
""" Backup a given file by appending .bk to the end """
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||||
|
@ -348,80 +214,6 @@ def deprecation_warning(func_name, additional_info=None):
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
|
|
||||||
|
|
||||||
def rotate_landmarks(face, rotation_matrix):
|
|
||||||
# pylint:disable=c-extension-no-member
|
|
||||||
""" Rotate the landmarks and bounding box for faces
|
|
||||||
found in rotated images.
|
|
||||||
Pass in a DetectedFace object or Alignments dict """
|
|
||||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
|
||||||
logger.trace("Rotating landmarks: (rotation_matrix: %s, type(face): %s",
|
|
||||||
rotation_matrix, type(face))
|
|
||||||
rotated_landmarks = None
|
|
||||||
# Detected Face Object
|
|
||||||
if isinstance(face, DetectedFace):
|
|
||||||
bounding_box = [[face.x, face.y],
|
|
||||||
[face.x + face.w, face.y],
|
|
||||||
[face.x + face.w, face.y + face.h],
|
|
||||||
[face.x, face.y + face.h]]
|
|
||||||
landmarks = face.landmarks_xy
|
|
||||||
|
|
||||||
# Alignments Dict
|
|
||||||
elif isinstance(face, dict) and "x" in face:
|
|
||||||
bounding_box = [[face.get("x", 0), face.get("y", 0)],
|
|
||||||
[face.get("x", 0) + face.get("w", 0),
|
|
||||||
face.get("y", 0)],
|
|
||||||
[face.get("x", 0) + face.get("w", 0),
|
|
||||||
face.get("y", 0) + face.get("h", 0)],
|
|
||||||
[face.get("x", 0),
|
|
||||||
face.get("y", 0) + face.get("h", 0)]]
|
|
||||||
landmarks = face.get("landmarks_xy", list())
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported face type")
|
|
||||||
|
|
||||||
logger.trace("Original landmarks: %s", landmarks)
|
|
||||||
|
|
||||||
rotation_matrix = cv2.invertAffineTransform( # pylint:disable=no-member
|
|
||||||
rotation_matrix)
|
|
||||||
rotated = list()
|
|
||||||
for item in (bounding_box, landmarks):
|
|
||||||
if not item:
|
|
||||||
continue
|
|
||||||
points = np.array(item, np.int32)
|
|
||||||
points = np.expand_dims(points, axis=0)
|
|
||||||
transformed = cv2.transform(points, # pylint:disable=no-member
|
|
||||||
rotation_matrix).astype(np.int32)
|
|
||||||
rotated.append(transformed.squeeze())
|
|
||||||
|
|
||||||
# Bounding box should follow x, y planes, so get min/max
|
|
||||||
# for non-90 degree rotations
|
|
||||||
pt_x = min([pnt[0] for pnt in rotated[0]])
|
|
||||||
pt_y = min([pnt[1] for pnt in rotated[0]])
|
|
||||||
pt_x1 = max([pnt[0] for pnt in rotated[0]])
|
|
||||||
pt_y1 = max([pnt[1] for pnt in rotated[0]])
|
|
||||||
width = pt_x1 - pt_x
|
|
||||||
height = pt_y1 - pt_y
|
|
||||||
|
|
||||||
if isinstance(face, DetectedFace):
|
|
||||||
face.x = int(pt_x)
|
|
||||||
face.y = int(pt_y)
|
|
||||||
face.w = int(width)
|
|
||||||
face.h = int(height)
|
|
||||||
face.r = 0
|
|
||||||
if len(rotated) > 1:
|
|
||||||
rotated_landmarks = [tuple(point) for point in rotated[1].tolist()]
|
|
||||||
face.landmarks_xy = rotated_landmarks
|
|
||||||
else:
|
|
||||||
face["left"] = int(pt_x)
|
|
||||||
face["top"] = int(pt_y)
|
|
||||||
face["right"] = int(pt_x1)
|
|
||||||
face["bottom"] = int(pt_y1)
|
|
||||||
rotated_landmarks = face
|
|
||||||
|
|
||||||
logger.trace("Rotated landmarks: %s", rotated_landmarks)
|
|
||||||
return face
|
|
||||||
|
|
||||||
|
|
||||||
def camel_case_split(identifier):
|
def camel_case_split(identifier):
|
||||||
""" Split a camel case name
|
""" Split a camel case name
|
||||||
from: https://stackoverflow.com/questions/29916065 """
|
from: https://stackoverflow.com/questions/29916065 """
|
||||||
|
|
|
@ -18,8 +18,7 @@ To get a :class:`~lib.faces_detect.DetectedFace` object use the function:
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lib.faces_detect import DetectedFace
|
from lib.faces_detect import DetectedFace, rotate_landmarks
|
||||||
from lib.utils import rotate_landmarks
|
|
||||||
from plugins.extract._base import Extractor, logger
|
from plugins.extract._base import Extractor, logger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from concurrent import futures
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
|
|
||||||
import keras
|
import keras
|
||||||
|
@ -24,7 +25,6 @@ from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask
|
||||||
generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
|
generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
|
||||||
from lib.model.nn_blocks import NNBlocks
|
from lib.model.nn_blocks import NNBlocks
|
||||||
from lib.model.optimizers import Adam
|
from lib.model.optimizers import Adam
|
||||||
from lib.multithreading import MultiThread
|
|
||||||
from lib.utils import deprecation_warning, FaceswapError
|
from lib.utils import deprecation_warning, FaceswapError
|
||||||
from plugins.train._config import Config
|
from plugins.train._config import Config
|
||||||
|
|
||||||
|
@ -466,21 +466,13 @@ class ModelBase():
|
||||||
backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
|
backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
|
||||||
if backup_func:
|
if backup_func:
|
||||||
logger.info("Backing up models...")
|
logger.info("Backing up models...")
|
||||||
save_threads = list()
|
executor = futures.ThreadPoolExecutor()
|
||||||
for network in self.networks.values():
|
save_threads = [executor.submit(network.save, backup_func=backup_func)
|
||||||
name = "save_{}".format(network.name)
|
for network in self.networks.values()]
|
||||||
save_threads.append(MultiThread(network.save,
|
save_threads.append(executor.submit(self.state.save, backup_func=backup_func))
|
||||||
name=name,
|
futures.wait(save_threads)
|
||||||
backup_func=backup_func))
|
# call result() to capture errors
|
||||||
save_threads.append(MultiThread(self.state.save,
|
_ = [thread.result() for thread in save_threads]
|
||||||
name="save_state",
|
|
||||||
backup_func=backup_func))
|
|
||||||
for thread in save_threads:
|
|
||||||
thread.start()
|
|
||||||
for thread in save_threads:
|
|
||||||
if thread.has_error:
|
|
||||||
logger.error(thread.errors[0])
|
|
||||||
thread.join()
|
|
||||||
msg = "[Saved models]"
|
msg = "[Saved models]"
|
||||||
if save_averages:
|
if save_averages:
|
||||||
lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],
|
lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],
|
||||||
|
|
|
@ -33,7 +33,7 @@ from tensorflow.python import errors_impl as tf_errors # pylint:disable=no-name
|
||||||
|
|
||||||
from lib.alignments import Alignments
|
from lib.alignments import Alignments
|
||||||
from lib.faces_detect import DetectedFace
|
from lib.faces_detect import DetectedFace
|
||||||
from lib.training_data import TrainingDataGenerator, stack_images
|
from lib.training_data import TrainingDataGenerator
|
||||||
from lib.utils import FaceswapError, get_folder, get_image_paths
|
from lib.utils import FaceswapError, get_folder, get_image_paths
|
||||||
from plugins.train._config import Config
|
from plugins.train._config import Config
|
||||||
|
|
||||||
|
@ -292,10 +292,10 @@ class Batcher():
|
||||||
""" Return the next batch from the generator
|
""" Return the next batch from the generator
|
||||||
Items should come out as: (warped, target [, mask]) """
|
Items should come out as: (warped, target [, mask]) """
|
||||||
batch = next(self.feed)
|
batch = next(self.feed)
|
||||||
feed = batch[1]
|
if self.use_mask:
|
||||||
batch = batch[2:] # Remove full size samples and feed from batch
|
batch = [[batch["feed"], batch["masks"]], batch["targets"] + [batch["masks"]]]
|
||||||
mask = batch[-1]
|
else:
|
||||||
batch = [[feed, mask], batch] if self.use_mask else [feed, batch]
|
batch = [batch["feed"], batch["targets"]]
|
||||||
self.generate_preview(do_preview)
|
self.generate_preview(do_preview)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@ -309,13 +309,10 @@ class Batcher():
|
||||||
if self.preview_feed is None:
|
if self.preview_feed is None:
|
||||||
self.set_preview_feed()
|
self.set_preview_feed()
|
||||||
batch = next(self.preview_feed)
|
batch = next(self.preview_feed)
|
||||||
self.samples, feed = batch[:2]
|
self.samples = batch["samples"]
|
||||||
batch = batch[2:] # Remove full size samples and feed from batch
|
self.target = [batch["targets"][self.model.largest_face_index]]
|
||||||
self.target = batch[self.model.largest_face_index]
|
|
||||||
if self.use_mask:
|
if self.use_mask:
|
||||||
mask = batch[-1]
|
self.target += [batch["masks"]]
|
||||||
batch = [[feed, mask], batch]
|
|
||||||
self.target = [self.target, mask]
|
|
||||||
|
|
||||||
def set_preview_feed(self):
|
def set_preview_feed(self):
|
||||||
""" Set the preview dictionary """
|
""" Set the preview dictionary """
|
||||||
|
@ -347,15 +344,11 @@ class Batcher():
|
||||||
def compile_timelapse_sample(self):
|
def compile_timelapse_sample(self):
|
||||||
""" Timelapse samples """
|
""" Timelapse samples """
|
||||||
batch = next(self.timelapse_feed)
|
batch = next(self.timelapse_feed)
|
||||||
samples, feed = batch[:2]
|
batchsize = len(batch["samples"])
|
||||||
batchsize = len(samples)
|
images = [batch["targets"][self.model.largest_face_index]]
|
||||||
batch = batch[2:] # Remove full size samples and feed from batch
|
|
||||||
images = batch[self.model.largest_face_index]
|
|
||||||
if self.use_mask:
|
if self.use_mask:
|
||||||
mask = batch[-1]
|
images = images + [batch["masks"]]
|
||||||
batch = [[feed, mask], batch]
|
sample = self.compile_sample(batchsize, samples=batch["samples"], images=images)
|
||||||
images = [images, mask]
|
|
||||||
sample = self.compile_sample(batchsize, samples=samples, images=images)
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def set_timelapse_feed(self, images, batchsize):
|
def set_timelapse_feed(self, images, batchsize):
|
||||||
|
@ -405,10 +398,10 @@ class Samples():
|
||||||
|
|
||||||
for side, samples in self.images.items():
|
for side, samples in self.images.items():
|
||||||
other_side = "a" if side == "b" else "b"
|
other_side = "a" if side == "b" else "b"
|
||||||
predictions = [preds["{}_{}".format(side, side)],
|
predictions = [preds["{0}_{0}".format(side)],
|
||||||
preds["{}_{}".format(other_side, side)]]
|
preds["{}_{}".format(other_side, side)]]
|
||||||
display = self.to_full_frame(side, samples, predictions)
|
display = self.to_full_frame(side, samples, predictions)
|
||||||
headers[side] = self.get_headers(side, other_side, display[0].shape[1])
|
headers[side] = self.get_headers(side, display[0].shape[1])
|
||||||
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
|
figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
|
||||||
if self.images[side][0].shape[0] % 2 == 1:
|
if self.images[side][0].shape[0] % 2 == 1:
|
||||||
figures[side] = np.concatenate([figures[side],
|
figures[side] = np.concatenate([figures[side],
|
||||||
|
@ -547,22 +540,22 @@ class Samples():
|
||||||
logger.debug("Overlayed foreground. Shape: %s", retval.shape)
|
logger.debug("Overlayed foreground. Shape: %s", retval.shape)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
def get_headers(self, side, other_side, width):
|
def get_headers(self, side, width):
|
||||||
""" Set headers for images """
|
""" Set headers for images """
|
||||||
logger.debug("side: '%s', other_side: '%s', width: %s",
|
logger.debug("side: '%s', width: %s",
|
||||||
side, other_side, width)
|
side, width)
|
||||||
|
titles = ("Original", "Swap") if side == "a" else ("Swap", "Original")
|
||||||
side = side.upper()
|
side = side.upper()
|
||||||
other_side = other_side.upper()
|
|
||||||
height = int(64 * self.scaling)
|
height = int(64 * self.scaling)
|
||||||
total_width = width * 3
|
total_width = width * 3
|
||||||
logger.debug("height: %s, total_width: %s", height, total_width)
|
logger.debug("height: %s, total_width: %s", height, total_width)
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member
|
font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member
|
||||||
texts = ["Target {}".format(side),
|
texts = ["{} ({})".format(titles[0], side),
|
||||||
"{} > {}".format(side, side),
|
"{0} > {0}".format(titles[0]),
|
||||||
"{} > {}".format(side, other_side)]
|
"{} > {}".format(titles[0], titles[1])]
|
||||||
text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member
|
text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member
|
||||||
font,
|
font,
|
||||||
self.scaling,
|
self.scaling * 0.8,
|
||||||
1)[0]
|
1)[0]
|
||||||
for idx in range(len(texts))]
|
for idx in range(len(texts))]
|
||||||
text_y = int((height + text_sizes[0][1]) / 2)
|
text_y = int((height + text_sizes[0][1]) / 2)
|
||||||
|
@ -576,7 +569,7 @@ class Samples():
|
||||||
text,
|
text,
|
||||||
(text_x[idx], text_y),
|
(text_x[idx], text_y),
|
||||||
font,
|
font,
|
||||||
self.scaling,
|
self.scaling * 0.8,
|
||||||
(0, 0, 0),
|
(0, 0, 0),
|
||||||
1,
|
1,
|
||||||
lineType=cv2.LINE_AA) # pylint: disable=no-member
|
lineType=cv2.LINE_AA) # pylint: disable=no-member
|
||||||
|
@ -703,3 +696,25 @@ class Landmarks():
|
||||||
detected_face.load_aligned(None, size=self.size)
|
detected_face.load_aligned(None, size=self.size)
|
||||||
landmarks[detected_face.hash] = detected_face.aligned_landmarks
|
landmarks[detected_face.hash] = detected_face.aligned_landmarks
|
||||||
return landmarks
|
return landmarks
|
||||||
|
|
||||||
|
|
||||||
|
def stack_images(images):
|
||||||
|
""" Stack images """
|
||||||
|
logger.debug("Stack images")
|
||||||
|
|
||||||
|
def get_transpose_axes(num):
|
||||||
|
if num % 2 == 0:
|
||||||
|
logger.debug("Even number of images to stack")
|
||||||
|
y_axes = list(range(1, num - 1, 2))
|
||||||
|
x_axes = list(range(0, num - 1, 2))
|
||||||
|
else:
|
||||||
|
logger.debug("Odd number of images to stack")
|
||||||
|
y_axes = list(range(0, num - 1, 2))
|
||||||
|
x_axes = list(range(1, num - 1, 2))
|
||||||
|
return y_axes, x_axes, [num - 1]
|
||||||
|
|
||||||
|
images_shape = np.array(images.shape)
|
||||||
|
new_axes = get_transpose_axes(len(images_shape))
|
||||||
|
new_shape = [np.prod(images_shape[x]) for x in new_axes]
|
||||||
|
logger.debug("Stacked images")
|
||||||
|
return np.transpose(images, axes=np.concatenate(new_axes)).reshape(new_shape)
|
||||||
|
|
|
@ -17,9 +17,10 @@ from lib import Serializer
|
||||||
from lib.convert import Converter
|
from lib.convert import Converter
|
||||||
from lib.faces_detect import DetectedFace
|
from lib.faces_detect import DetectedFace
|
||||||
from lib.gpu_stats import GPUStats
|
from lib.gpu_stats import GPUStats
|
||||||
|
from lib.image import read_image_hash
|
||||||
from lib.multithreading import MultiThread, total_cpus
|
from lib.multithreading import MultiThread, total_cpus
|
||||||
from lib.queue_manager import queue_manager
|
from lib.queue_manager import queue_manager
|
||||||
from lib.utils import FaceswapError, get_folder, get_image_paths, hash_image_file
|
from lib.utils import FaceswapError, get_folder, get_image_paths
|
||||||
from plugins.extract.pipeline import Extractor
|
from plugins.extract.pipeline import Extractor
|
||||||
from plugins.plugin_loader import PluginLoader
|
from plugins.plugin_loader import PluginLoader
|
||||||
|
|
||||||
|
@ -682,7 +683,7 @@ class OptionalActions():
|
||||||
file_list = [path for path in get_image_paths(input_aligned_dir)]
|
file_list = [path for path in get_image_paths(input_aligned_dir)]
|
||||||
logger.info("Getting Face Hashes for selected Aligned Images")
|
logger.info("Getting Face Hashes for selected Aligned Images")
|
||||||
for face in tqdm(file_list, desc="Hashing Faces"):
|
for face in tqdm(file_list, desc="Hashing Faces"):
|
||||||
face_hashes.append(hash_image_file(face))
|
face_hashes.append(read_image_hash(face))
|
||||||
logger.debug("Face Hashes: %s", (len(face_hashes)))
|
logger.debug("Face Hashes: %s", (len(face_hashes)))
|
||||||
if not face_hashes:
|
if not face_hashes:
|
||||||
raise FaceswapError("Aligned directory is empty, no faces will be converted!")
|
raise FaceswapError("Aligned directory is empty, no faces will be converted!")
|
||||||
|
@ -746,5 +747,5 @@ class Legacy():
|
||||||
continue
|
continue
|
||||||
hash_faces = all_faces[frame]
|
hash_faces = all_faces[frame]
|
||||||
for index, face_path in hash_faces.items():
|
for index, face_path in hash_faces.items():
|
||||||
hash_faces[index] = hash_image_file(face_path)
|
hash_faces[index] = read_image_hash(face_path)
|
||||||
self.alignments.add_face_hashes(frame, hash_faces)
|
self.alignments.add_face_hashes(frame, hash_faces)
|
||||||
|
|
|
@ -8,9 +8,10 @@ from pathlib import Path
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lib.image import encode_image_with_hash
|
||||||
from lib.multithreading import MultiThread
|
from lib.multithreading import MultiThread
|
||||||
from lib.queue_manager import queue_manager
|
from lib.queue_manager import queue_manager
|
||||||
from lib.utils import get_folder, hash_encode_image, deprecation_warning
|
from lib.utils import get_folder, deprecation_warning
|
||||||
from plugins.extract.pipeline import Extractor
|
from plugins.extract.pipeline import Extractor
|
||||||
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
|
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
|
||||||
|
|
||||||
|
@ -255,7 +256,7 @@ class Extract():
|
||||||
face = detected_face["face"]
|
face = detected_face["face"]
|
||||||
resized_face = face.aligned_face
|
resized_face = face.aligned_face
|
||||||
|
|
||||||
face.hash, img = hash_encode_image(resized_face, extension)
|
face.hash, img = encode_image_with_hash(resized_face, extension)
|
||||||
self.save_queue.put((out_filename, img))
|
self.save_queue.put((out_filename, img))
|
||||||
final_faces.append(face.to_alignment())
|
final_faces.append(face.to_alignment())
|
||||||
self.alignments.data[os.path.basename(filename)] = final_faces
|
self.alignments.data[os.path.basename(filename)] = final_faces
|
||||||
|
|
|
@ -16,8 +16,9 @@ import numpy as np
|
||||||
from lib.aligner import Extract as AlignerExtract
|
from lib.aligner import Extract as AlignerExtract
|
||||||
from lib.alignments import Alignments as AlignmentsBase
|
from lib.alignments import Alignments as AlignmentsBase
|
||||||
from lib.face_filter import FaceFilter as FilterFunc
|
from lib.face_filter import FaceFilter as FilterFunc
|
||||||
from lib.utils import (camel_case_split, count_frames_and_secs, cv2_read_img, get_folder,
|
from lib.image import count_frames_and_secs, read_image
|
||||||
get_image_paths, set_system_verbosity, _video_extensions)
|
from lib.utils import (camel_case_split, get_folder, get_image_paths, set_system_verbosity,
|
||||||
|
_video_extensions)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
@ -183,7 +184,7 @@ class Images():
|
||||||
""" Load frames from disk """
|
""" Load frames from disk """
|
||||||
logger.debug("Input is separate Frames. Loading images")
|
logger.debug("Input is separate Frames. Loading images")
|
||||||
for filename in self.input_images:
|
for filename in self.input_images:
|
||||||
image = cv2_read_img(filename, raise_error=False)
|
image = read_image(filename, raise_error=False)
|
||||||
if image is None:
|
if image is None:
|
||||||
continue
|
continue
|
||||||
yield filename, image
|
yield filename, image
|
||||||
|
@ -212,7 +213,7 @@ class Images():
|
||||||
logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
|
logger.trace("Extracted frame_no %s from filename '%s'", frame_no, filename)
|
||||||
retval = self.load_one_video_frame(int(frame_no))
|
retval = self.load_one_video_frame(int(frame_no))
|
||||||
else:
|
else:
|
||||||
retval = cv2_read_img(filename, raise_error=True)
|
retval = read_image(filename, raise_error=True)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
def load_one_video_frame(self, frame_no):
|
def load_one_video_frame(self, frame_no):
|
||||||
|
|
|
@ -12,10 +12,11 @@ import cv2
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from keras.backend.tensorflow_backend import set_session
|
from keras.backend.tensorflow_backend import set_session
|
||||||
|
|
||||||
|
from lib.image import read_image
|
||||||
from lib.keypress import KBHit
|
from lib.keypress import KBHit
|
||||||
from lib.multithreading import MultiThread
|
from lib.multithreading import MultiThread
|
||||||
from lib.queue_manager import queue_manager
|
from lib.queue_manager import queue_manager # noqa pylint:disable=unused-import
|
||||||
from lib.utils import cv2_read_img, get_folder, get_image_paths, set_system_verbosity
|
from lib.utils import get_folder, get_image_paths, set_system_verbosity
|
||||||
from plugins.plugin_loader import PluginLoader
|
from plugins.plugin_loader import PluginLoader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
@ -176,7 +177,7 @@ class Train():
|
||||||
@property
|
@property
|
||||||
def image_size(self):
|
def image_size(self):
|
||||||
""" Get the training set image size for storing in model data """
|
""" Get the training set image size for storing in model data """
|
||||||
image = cv2_read_img(self.images["a"][0], raise_error=True)
|
image = read_image(self.images["a"][0], raise_error=True)
|
||||||
size = image.shape[0]
|
size = image.shape[0]
|
||||||
logger.debug("Training image size: %s", size)
|
logger.debug("Training image size: %s", size)
|
||||||
return size
|
return size
|
||||||
|
|
|
@ -14,8 +14,8 @@ from tqdm import tqdm
|
||||||
from lib.aligner import Extract as AlignerExtract
|
from lib.aligner import Extract as AlignerExtract
|
||||||
from lib.alignments import Alignments
|
from lib.alignments import Alignments
|
||||||
from lib.faces_detect import DetectedFace
|
from lib.faces_detect import DetectedFace
|
||||||
from lib.utils import (_image_extensions, _video_extensions, count_frames_and_secs, cv2_read_img,
|
from lib.image import count_frames_and_secs, encode_image_with_hash, read_image, read_image_hash
|
||||||
hash_image_file, hash_encode_image)
|
from lib.utils import _image_extensions, _video_extensions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
@ -175,7 +175,7 @@ class MediaLoader():
|
||||||
else:
|
else:
|
||||||
src = os.path.join(self.folder, filename)
|
src = os.path.join(self.folder, filename)
|
||||||
logger.trace("Loading image: '%s'", src)
|
logger.trace("Loading image: '%s'", src)
|
||||||
image = cv2_read_img(src, raise_error=True)
|
image = read_image(src, raise_error=True)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def load_video_frame(self, filename):
|
def load_video_frame(self, filename):
|
||||||
|
@ -210,7 +210,7 @@ class Faces(MediaLoader):
|
||||||
continue
|
continue
|
||||||
filename = os.path.splitext(face)[0]
|
filename = os.path.splitext(face)[0]
|
||||||
file_extension = os.path.splitext(face)[1]
|
file_extension = os.path.splitext(face)[1]
|
||||||
face_hash = hash_image_file(os.path.join(self.folder, face))
|
face_hash = read_image_hash(os.path.join(self.folder, face))
|
||||||
retval = {"face_fullname": face,
|
retval = {"face_fullname": face,
|
||||||
"face_name": filename,
|
"face_name": filename,
|
||||||
"face_extension": file_extension,
|
"face_extension": file_extension,
|
||||||
|
@ -358,7 +358,7 @@ class ExtractedFaces():
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_face_with_hash(filename, extension, face):
|
def save_face_with_hash(filename, extension, face):
|
||||||
""" Save a face and return it's hash """
|
""" Save a face and return it's hash """
|
||||||
f_hash, img = hash_encode_image(face, extension)
|
f_hash, img = encode_image_with_hash(face, extension)
|
||||||
logger.trace("Saving face: '%s'", filename)
|
logger.trace("Saving face: '%s'", filename)
|
||||||
with open(filename, "wb") as out_file:
|
with open(filename, "wb") as out_file:
|
||||||
out_file.write(img)
|
out_file.write(img)
|
||||||
|
|
|
@ -16,8 +16,8 @@ from tqdm import tqdm
|
||||||
from lib.cli import FullHelpArgumentParser
|
from lib.cli import FullHelpArgumentParser
|
||||||
from lib import Serializer
|
from lib import Serializer
|
||||||
from lib.faces_detect import DetectedFace
|
from lib.faces_detect import DetectedFace
|
||||||
|
from lib.image import read_image
|
||||||
from lib.queue_manager import queue_manager
|
from lib.queue_manager import queue_manager
|
||||||
from lib.utils import cv2_read_img
|
|
||||||
from lib.vgg_face2_keras import VGGFace2 as VGGFace
|
from lib.vgg_face2_keras import VGGFace2 as VGGFace
|
||||||
from plugins.plugin_loader import PluginLoader
|
from plugins.plugin_loader import PluginLoader
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ class Sort():
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_landmarks(filename):
|
def get_landmarks(filename):
|
||||||
""" Extract the face from a frame (If not alignments file found) """
|
""" Extract the face from a frame (If not alignments file found) """
|
||||||
image = cv2_read_img(filename, raise_error=True)
|
image = read_image(filename, raise_error=True)
|
||||||
feed = Sort.alignment_dict(image)
|
feed = Sort.alignment_dict(image)
|
||||||
feed["filename"] = filename
|
feed["filename"] = filename
|
||||||
queue_manager.get_queue("in").put(feed)
|
queue_manager.get_queue("in").put(feed)
|
||||||
|
@ -161,7 +161,7 @@ class Sort():
|
||||||
logger.info("Sorting by face similarity...")
|
logger.info("Sorting by face similarity...")
|
||||||
|
|
||||||
images = np.array(self.find_images(input_dir))
|
images = np.array(self.find_images(input_dir))
|
||||||
preds = np.array([self.vgg_face.predict(cv2_read_img(img, raise_error=True))
|
preds = np.array([self.vgg_face.predict(read_image(img, raise_error=True))
|
||||||
for img in tqdm(images, desc="loading", file=sys.stdout)])
|
for img in tqdm(images, desc="loading", file=sys.stdout)])
|
||||||
logger.info("Sorting. Depending on ths size of your dataset, this may take a few "
|
logger.info("Sorting. Depending on ths size of your dataset, this may take a few "
|
||||||
"minutes...")
|
"minutes...")
|
||||||
|
@ -264,7 +264,7 @@ class Sort():
|
||||||
logger.info("Sorting by histogram similarity...")
|
logger.info("Sorting by histogram similarity...")
|
||||||
|
|
||||||
img_list = [
|
img_list = [
|
||||||
[img, cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
|
[img, cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])]
|
||||||
for img in
|
for img in
|
||||||
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
|
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
|
||||||
]
|
]
|
||||||
|
@ -294,7 +294,7 @@ class Sort():
|
||||||
|
|
||||||
img_list = [
|
img_list = [
|
||||||
[img,
|
[img,
|
||||||
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256]), 0]
|
cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256]), 0]
|
||||||
for img in
|
for img in
|
||||||
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
|
tqdm(self.find_images(input_dir), desc="Loading", file=sys.stdout)
|
||||||
]
|
]
|
||||||
|
@ -548,7 +548,7 @@ class Sort():
|
||||||
input_dir = self.args.input_dir
|
input_dir = self.args.input_dir
|
||||||
logger.info("Preparing to group...")
|
logger.info("Preparing to group...")
|
||||||
if group_method == 'group_blur':
|
if group_method == 'group_blur':
|
||||||
temp_list = [[img, self.estimate_blur(cv2_read_img(img, raise_error=True))]
|
temp_list = [[img, self.estimate_blur(read_image(img, raise_error=True))]
|
||||||
for img in
|
for img in
|
||||||
tqdm(self.find_images(input_dir),
|
tqdm(self.find_images(input_dir),
|
||||||
desc="Reloading",
|
desc="Reloading",
|
||||||
|
@ -576,7 +576,7 @@ class Sort():
|
||||||
elif group_method == 'group_hist':
|
elif group_method == 'group_hist':
|
||||||
temp_list = [
|
temp_list = [
|
||||||
[img,
|
[img,
|
||||||
cv2.calcHist([cv2_read_img(img, raise_error=True)], [0], None, [256], [0, 256])]
|
cv2.calcHist([read_image(img, raise_error=True)], [0], None, [256], [0, 256])]
|
||||||
for img in
|
for img in
|
||||||
tqdm(self.find_images(input_dir),
|
tqdm(self.find_images(input_dir),
|
||||||
desc="Reloading",
|
desc="Reloading",
|
||||||
|
@ -632,7 +632,7 @@ class Sort():
|
||||||
Estimate the amount of blur an image has with the variance of the Laplacian.
|
Estimate the amount of blur an image has with the variance of the Laplacian.
|
||||||
Normalize by pixel number to offset the effect of image size on pixel gradients & variance
|
Normalize by pixel number to offset the effect of image size on pixel gradients & variance
|
||||||
"""
|
"""
|
||||||
image = cv2_read_img(image_file, raise_error=True)
|
image = read_image(image_file, raise_error=True)
|
||||||
if image.ndim == 3:
|
if image.ndim == 3:
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
blur_map = cv2.Laplacian(image, cv2.CV_32F)
|
blur_map = cv2.Laplacian(image, cv2.CV_32F)
|
||||||
|
|
Loading…
Add table
Reference in a new issue