1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/lib/image.py
2024-04-03 14:03:54 +01:00

1603 lines
61 KiB
Python

#!/usr/bin python3
""" Utilities for working with images and videos """
from __future__ import annotations
import json
import logging
import re
import subprocess
import os
import struct
import sys
import typing as T
from ast import literal_eval
from bisect import bisect
from concurrent import futures
from zlib import crc32
import cv2
import imageio
import imageio_ffmpeg as im_ffm
import numpy as np
from tqdm import tqdm
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import convert_to_secs, FaceswapError, _video_extensions, get_image_paths
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderDict
logger = logging.getLogger(__name__)
# ################### #
# <<< IMAGE UTILS >>> #
# ################### #
# <<< IMAGE IO >>> #
class FfmpegReader(imageio.plugins.ffmpeg.FfmpegFormat.Reader):  # type:ignore
    """ Monkey patch imageio ffmpeg to use keyframes whilst seeking

    The keyframe based seek is only used when :attr:`use_patch` is ``True``. Otherwise seeking
    uses the combined fast/slow ``-ss`` arguments.

    Parameters
    ----------
    format:
        Passed straight through to the parent reader
    request:
        Passed straight through to the parent reader
    """

    def __init__(self, format, request):
        super().__init__(format, request)
        self._frame_pts = None  # pts_time for each frame, populated by get_frame_info
        self._keyframes = None  # frame numbers of the key frames, populated by get_frame_info
        self.use_patch = False

    def get_frame_info(self, frame_pts=None, keyframes=None):
        """ Store the source video's pts times and keyframes for the current video for use in
        :func:`_initialize`.

        Parameters
        ----------
        frame_pts: list, optional
            A list corresponding to the video frame count of the pts_time per frame. If this and
            `keyframes` are provided, then analyzing the video is skipped and the values from the
            given lists are used. Default: ``None``
        keyframes: list, optional
            A list containing the frame numbers of each key frame. if this and `frame_pts` are
            provided, then analyzing the video is skipped and the values from the given lists are
            used. Default: ``None``

        Returns
        -------
        tuple
            (frame count (`int`), `dict` with the keys `pts_time` and `keyframes`)
        """
        if frame_pts is not None and keyframes is not None:
            logger.debug("Video meta information provided. Not analyzing video")
            self._frame_pts = frame_pts
            self._keyframes = keyframes
            return len(frame_pts), dict(pts_time=self._frame_pts, keyframes=self._keyframes)

        assert isinstance(self._filename, str), "Video path must be a string"

        # NB: The below video filter applies the detected frame rate prior to showinfo. This
        # appears to help prevent an issue where the number of timestamp entries generated by
        # showinfo does not correspond to the number of frames that the video file generates.
        # This is because the demuxer will duplicate frames to meet the required frame rate.
        # This **may** cause issues so be aware.

        # Also, drop frame rates (i.e 23.98, 29.97 and 59.94) will introduce rounding errors which
        # means sync will drift on generated pts. These **should** be the only 'drop-frame rates'
        # that appear in video files, but this is video files, and nothing is guaranteed.
        # (The actual values for these should be 24000/1001, 30000/1001 and 60000/1001
        # respectively). The solutions to round these values is hacky at best, so:
        # TODO find a more robust method for extracting/handling drop-frame rates.

        fps = self._meta["fps"]
        rounded_fps = round(fps, 0)
        if 0.01 < rounded_fps - fps < 0.10:  # 0.90 - 0.99
            new_fps = f"{int(rounded_fps * 1000)}/1001"
            logger.debug("Adjusting drop-frame fps: %s to %s", fps, new_fps)
            fps = new_fps

        cmd = [im_ffm.get_ffmpeg_exe(),
               "-hide_banner",
               "-copyts",
               "-i", self._filename,
               "-vf", f"fps=fps={fps},showinfo",
               "-start_number", "0",
               "-an",
               "-f", "null",
               "-"]
        logger.debug("FFMPEG Command: '%s'", " ".join(cmd))
        process = subprocess.Popen(cmd,
                                   stderr=subprocess.STDOUT,
                                   stdout=subprocess.PIPE,
                                   universal_newlines=True)
        frame_pts = []
        key_frames = []
        last_update = 0
        pbar = tqdm(desc="Analyzing Video",
                    leave=False,
                    total=int(self._meta["duration"]),
                    unit="secs")
        while True:
            output = process.stdout.readline().strip()
            if output == "" and process.poll() is not None:
                break
            if "iskey" not in output:
                # Only the showinfo lines (which carry the iskey flag) are of interest
                continue
            logger.trace("Keyframe line: %s", output)
            line = re.split(r"\s+|:\s*", output)
            pts_time = float(line[line.index("pts_time") + 1])
            frame_no = int(line[line.index("n") + 1])
            frame_pts.append(pts_time)
            if "iskey:1" in output:
                key_frames.append(frame_no)

            logger.trace("pts_time: %s, frame_no: %s", pts_time, frame_no)
            if int(pts_time) == last_update:
                # Floating points make TQDM display poorly, so only update on full
                # second increments
                continue
            pbar.update(int(pts_time) - last_update)
            last_update = int(pts_time)
        pbar.close()
        return_code = process.poll()
        frame_count = len(frame_pts)
        logger.debug("Return code: %s, frame_pts: %s, keyframes: %s, frame_count: %s",
                     return_code, frame_pts, key_frames, frame_count)

        self._frame_pts = frame_pts
        self._keyframes = key_frames
        return frame_count, dict(pts_time=self._frame_pts, keyframes=self._keyframes)

    def _previous_keyframe_info(self, index=0):
        """ Return the previous keyframe's pts_time and frame number

        Parameters
        ----------
        index: int, optional
            The frame number to locate the preceding key frame for. Default: `0`

        Returns
        -------
        tuple
            (pts_time of the key frame, frame number of the key frame)
        """
        prev_keyframe_idx = bisect(self._keyframes, index) - 1
        prev_keyframe = self._keyframes[prev_keyframe_idx]
        prev_pts_time = self._frame_pts[prev_keyframe]
        logger.trace("keyframe pts_time: %s, keyframe: %s", prev_pts_time, prev_keyframe)
        return prev_pts_time, prev_keyframe

    def _initialize(self, index=0):  # noqa:C901
        """ Replace ImageIO _initialize with a version that explictly uses keyframes.

        Parameters
        ----------
        index: int, optional
            The frame number to (re)initialize the reader at. Default: `0`

        Raises
        ------
        IndexError
            If reading from a camera and the first read raises an :class:`IOError`

        Notes
        -----
        This introduces a minor change by seeking fast to the previous keyframe and then discarding
        subsequent frames until the desired frame is reached. In testing, setting -ss flag either
        prior to input, or both prior (fast) and after (slow) would not always bring back the
        correct frame for all videos. Navigating to the previous keyframe then discarding frames
        until the correct frame is reached appears to work well.
        """
        # pylint: disable-all
        if self._read_gen is not None:
            self._read_gen.close()

        iargs = []
        oargs = []
        skip_frames = 0

        # Create input args
        iargs += self._arg_input_params
        if self.request._video:
            # CAM_FORMAT lives in the imageio plugin module, so reference it explicitly rather
            # than relying on a bare name that does not exist in this module
            iargs += ["-f", imageio.plugins.ffmpeg.CAM_FORMAT]
            if self._arg_pixelformat:
                iargs += ["-pix_fmt", self._arg_pixelformat]
            if self._arg_size:
                iargs += ["-s", self._arg_size]
        elif index > 0:  # re-initialize / seek
            # Note: only works if we initialized earlier, and now have meta. Some info here:
            # https://trac.ffmpeg.org/wiki/Seeking
            # There are two ways to seek, one before -i (input_params) and after (output_params).
            # The former is fast, because it uses keyframes, the latter is slow but accurate.
            # According to the article above, the fast method should also be accurate from ffmpeg
            # version 2.1, however in version 4.1 our tests start failing again. Not sure why, but
            # we can solve this by combining slow and fast.
            # Further note: The old method would go back 10 seconds and then seek slow. This was
            # still somewhat unresponsive and did not always land on the correct frame. This monkey
            # patched version goes to the previous keyframe then discards frames until the correct
            # frame is landed on.
            if self.use_patch and self._frame_pts is None:
                self.get_frame_info()

            if self.use_patch:
                keyframe_pts, keyframe = self._previous_keyframe_info(index)
                seek_fast = keyframe_pts
                skip_frames = index - keyframe
            else:
                starttime = index / self._meta["fps"]
                seek_slow = min(10, starttime)
                seek_fast = starttime - seek_slow

            # We used to have this epsilon earlier, when we did not use
            # the slow seek. I don't think we need it anymore.
            # epsilon = -1 / self._meta["fps"] * 0.1
            iargs += ["-ss", "%.06f" % (seek_fast)]
            if not self.use_patch:
                oargs += ["-ss", "%.06f" % (seek_slow)]

        # Output args, for writing to pipe
        if self._arg_size:
            oargs += ["-s", self._arg_size]
        if self.request.kwargs.get("fps", None):
            fps = float(self.request.kwargs["fps"])
            oargs += ["-r", "%.02f" % fps]
        oargs += self._arg_output_params

        # Get pixelformat and bytes per pixel
        pix_fmt = self._pix_fmt
        bpp = self._depth * self._bytes_per_channel

        # Create generator
        rf = self._ffmpeg_api.read_frames
        self._read_gen = rf(
            self._filename, pix_fmt, bpp, input_params=iargs, output_params=oargs
        )

        # Read meta data. This start the generator (and ffmpeg subprocess)
        if self.request._video:
            # With cameras, catch error and turn into IndexError
            try:
                meta = self._read_gen.__next__()
            except IOError as err:
                err_text = str(err)
                if "darwin" in sys.platform:
                    if "Unknown input format: 'avfoundation'" in err_text:
                        err_text += (
                            "Try installing FFMPEG using "
                            "home brew to get a version with "
                            "support for cameras."
                        )
                raise IndexError(
                    "No camera at {}.\n\n{}".format(self.request._video, err_text)
                ) from err
            else:
                self._meta.update(meta)
        elif index == 0:
            self._meta.update(self._read_gen.__next__())
        else:
            if self.use_patch:
                frames_skipped = 0
                while skip_frames != frames_skipped:
                    # Skip frames that are not the desired frame
                    _ = self._read_gen.__next__()
                    frames_skipped += 1
            self._read_gen.__next__()  # we already have meta data


# Replace imageio's stock reader with the keyframe aware version
imageio.plugins.ffmpeg.FfmpegFormat.Reader = FfmpegReader  # type: ignore
def read_image(filename, raise_error=False, with_metadata=False):
    """ Read an image file from a file location.

    Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually
    loaded. Errors can be logged and ignored so that the process can continue on an image load
    failure.

    Parameters
    ----------
    filename: str
        Full path to the image to be loaded.
    raise_error: bool, optional
        If ``True`` then any failures (including the returned image being ``None``) will be
        raised. If ``False`` then an error message will be logged, but the error will not be
        raised. Default: ``False``
    with_metadata: bool, optional
        Only returns a value if the images loaded are extracted Faceswap faces. If ``True`` then
        returns the Faceswap metadata stored with in a Face images .png exif header.
        Default: ``False``

    Returns
    -------
    numpy.ndarray or tuple
        If :attr:`with_metadata` is ``False`` then returns a `numpy.ndarray` of the image in `BGR`
        channel order. If :attr:`with_metadata` is ``True`` then returns a `tuple` of
        (`numpy.ndarray` of the image in `BGR`, `dict` of face's Faceswap metadata). If loading
        fails and :attr:`raise_error` is ``False`` then ``None`` is returned in place of the
        image (and in place of the metadata when :attr:`with_metadata` is ``True``).

    Raises
    ------
    Exception
        If the image could not be loaded and :attr:`raise_error` is ``True``

    Example
    -------
    >>> image_file = "/path/to/image.png"
    >>> try:
    >>>    image = read_image(image_file, raise_error=True, with_metadata=False)
    >>> except Exception:
    >>>    raise ValueError("There was an error")
    """
    logger.trace("Requested image: '%s'", filename)
    success = True
    image = None
    # Pre-populate the return value so a logged (not raised) failure does not end in an
    # UnboundLocalError at the return statement
    retval = (None, None) if with_metadata else None
    try:
        with open(filename, "rb") as infile:
            raw_file = infile.read()
            image = cv2.imdecode(np.frombuffer(raw_file, dtype="uint8"), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Image is None")
        if with_metadata:
            metadata = png_read_meta(raw_file)
            retval = (image, metadata)
        else:
            retval = image
    except TypeError as err:
        success = False
        msg = "Error while reading image (TypeError): '{}'".format(filename)
        msg += ". Original error message: {}".format(str(err))
        logger.error(msg)
        if raise_error:
            raise Exception(msg) from err
    except ValueError as err:
        success = False
        msg = ("Error while reading image. This can be caused by special characters in the "
               "filename or a corrupt image file: '{}'".format(filename))
        msg += ". Original error message: {}".format(str(err))
        logger.error(msg)
        if raise_error:
            raise Exception(msg) from err
    except Exception as err:  # pylint:disable=broad-except
        success = False
        msg = "Failed to load image '{}'. Original Error: {}".format(filename, str(err))
        logger.error(msg)
        if raise_error:
            raise Exception(msg) from err
    logger.trace("Loaded image: '%s'. Success: %s", filename, success)
    return retval
def read_image_batch(filenames, with_metadata=False):
    """ Load a batch of images from the given file locations.

    The images are read concurrently in a thread pool, then stacked into a single array in the
    same order as :attr:`filenames`.

    Parameters
    ----------
    filenames: list
        A list of ``str`` full paths to the images to be loaded.
    with_metadata: bool, optional
        If ``True`` then the Faceswap metadata stored in each face's .png header is returned
        along with the images. Default: ``False``

    Returns
    -------
    numpy.ndarray or tuple
        The batch of images in `BGR` channel order returned in the order of :attr:`filenames`. If
        :attr:`with_metadata` is ``True`` then a tuple of (batch, `list` of metadata) is returned

    Notes
    -----
    As the images are compiled into a batch, they must be all of the same dimensions.

    Example
    -------
    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
    >>> images = read_image_batch(image_filenames)
    """
    logger.trace("Requested batch: '%s'", filenames)
    count = len(filenames)
    loaded = [None] * count
    meta = [None] * count if with_metadata else None

    with futures.ThreadPoolExecutor() as pool:
        # Map each submitted job back to its position so that output order is deterministic
        jobs = {}
        for position, path in enumerate(filenames):
            job = pool.submit(read_image, path, raise_error=True, with_metadata=with_metadata)
            jobs[job] = position

        for finished in futures.as_completed(jobs):
            position = jobs[finished]
            if not with_metadata:
                loaded[position] = finished.result()
                continue
            image, info = finished.result()
            loaded[position] = image
            meta[position] = info

    batch = np.array(loaded)
    logger.trace("Returning images: (filenames: %s, batch shape: %s, with_metadata: %s)",
                 filenames, batch.shape, with_metadata)
    if with_metadata:
        return batch, meta
    return batch
def read_image_meta(filename):
    """ Read the Faceswap metadata stored in an extracted face's exif header.

    Parameters
    ----------
    filename: str
        Full path to the image to be retrieve the meta information for.

    Returns
    -------
    dict
        The output dictionary will contain the `width` and `height` of the png image as well as any
        `itxt` information. For non-png files only `width` and `height` are populated.

    Raises
    ------
    PermissionError
        If the png header could not be read because of file permissions
    ValueError
        If the file has a .png extension but does not start with the png signature

    Example
    -------
    >>> image_file = "/path/to/image.png"
    >>> metadata = read_image_meta(image_file)
    >>> width = metadata["width"]
    >>> height = metadata["height"]
    >>> faceswap_info = metadata["itxt"]
    """
    retval = dict()
    if os.path.splitext(filename)[-1].lower() != ".png":
        # Get the dimensions directly from the image for non-pngs
        logger.trace("Non png found. Loading file for dimensions: '%s'", filename)
        img = cv2.imread(filename)
        retval["height"], retval["width"] = img.shape[:2]
        return retval
    with open(filename, "rb") as infile:
        try:
            chunk = infile.read(8)
        except PermissionError as err:
            raise PermissionError(f"PermissionError while reading: {filename}") from err
        if chunk != b"\x89PNG\r\n\x1a\n":
            raise ValueError(f"Invalid header found in png: {filename}")

        while True:
            chunk = infile.read(8)
            if len(chunk) < 8:
                # End of file (or a truncated chunk header). Stop before attempting to unpack,
                # as unpacking a short header would raise a struct.error
                break
            length, field = struct.unpack(">I4s", chunk)
            logger.trace("Read chunk: (chunk: %s, length: %s, field: %s)", chunk, length, field)
            if field == b"IDAT":
                # Image data reached. No header information exists beyond this point
                break
            if field == b"IHDR":
                # Get dimensions
                chunk = infile.read(8)
                retval["width"], retval["height"] = struct.unpack(">II", chunk)
                length -= 8  # 8 bytes of this chunk's data have already been consumed
            elif field == b"iTXt":
                keyword, value = infile.read(length).split(b"\0", 1)
                if keyword == b"faceswap":
                    # Skip the 4 remaining iTXt header bytes that precede the text
                    retval["itxt"] = literal_eval(value[4:].decode("utf-8", errors="replace"))
                    break
                else:
                    logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1",
                                                                             errors="ignore"))
                    length = 0  # Reset marker for next chunk
            infile.seek(length + 4, 1)  # Skip the rest of the chunk data plus its CRC
    logger.trace("filename: %s, metadata: %s", filename, retval)
    return retval
def read_image_meta_batch(filenames):
    """ Read the Faceswap metadata stored in a batch extracted faces' exif headers.

    Each file is handed to :func:`read_image_meta` in a thread pool and the results are yielded
    as each read completes.

    Notes
    -----
    The order of returned values is non-deterministic so will most likely not be returned in the
    same order as the filenames

    Parameters
    ----------
    filenames: list
        A list of ``str`` full paths to the images to be loaded.

    Yields
    -------
    tuple
        (**filename** (`str`), **metadata** (`dict`) )

    Example
    -------
    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
    >>> for filename, meta in read_image_meta_batch(image_filenames):
    >>>     ...
    """
    logger.trace("Requested batch: '%s'", filenames)
    with futures.ThreadPoolExecutor() as pool:
        logger.debug("Submitting %s items to executor", len(filenames))
        pending = {}
        for path in filenames:
            pending[pool.submit(read_image_meta, path)] = path
        logger.debug("Succesfully submitted %s items to executor", len(filenames))
        for done in futures.as_completed(pending):
            item = (pending[done], done.result())
            logger.trace("Yielding: %s", item)
            yield item
def pack_to_itxt(metadata):
    """ Pack the given metadata dictionary to a PNG iTXt header field.

    Parameters
    ----------
    metadata: dict or bytes
        The dictionary to write to the header. Can be pre-encoded as utf-8.

    Returns
    -------
    bytes
        A byte encoded PNG iTXt field, including chunk header and CRC
    """
    if isinstance(metadata, bytes):
        payload = metadata
    else:
        payload = str(metadata).encode("utf-8", "strict")
    keyword = "faceswap".encode("latin-1", "strict")
    # keyword, then its null terminator and the 4 remaining (empty) iTXt header bytes
    body = keyword + b"\0\0\0\0\0" + payload
    # The CRC covers the chunk type followed by the chunk data
    checksum = crc32(body, crc32(b"iTXt")) & 0xFFFFFFFF
    return b"".join((struct.pack(">I", len(body)),
                     b"iTXt",
                     body,
                     struct.pack(">I", checksum)))
def update_existing_metadata(filename, metadata):
    """ Update the png header metadata for an existing .png extracted face file on the filesystem.

    The file is re-written to a temporary file (the original filename with a ``~`` suffix) which
    then replaces the original. If anything goes wrong whilst re-writing, the temporary file is
    removed and the original file is left untouched.

    Parameters
    ----------
    filename: str
        The full path to the face to be updated
    metadata: dict or bytes
        The dictionary to write to the header. Can be pre-encoded as utf-8.

    Raises
    ------
    ValueError
        If the file does not start with the png signature
    """
    tmp_filename = filename + "~"
    with open(filename, "rb") as png:
        try:
            with open(tmp_filename, "wb") as tmp:
                chunk = png.read(8)
                if chunk != b"\x89PNG\r\n\x1a\n":
                    raise ValueError(f"Invalid header found in png: {filename}")
                tmp.write(chunk)

                while True:
                    chunk = png.read(8)
                    length, field = struct.unpack(">I4s", chunk)
                    logger.trace("Read chunk: (chunk: %s, length: %s, field: %s)",
                                 chunk, length, field)

                    if field == b"IDAT":  # Write out all remaining data
                        logger.trace("Writing image data and closing png")
                        tmp.write(chunk + png.read())
                        break

                    if field != b"iTXt":  # Write non iTXt chunk straight out
                        logger.trace("Copying existing chunk")
                        tmp.write(chunk + png.read(length + 4))  # Header + CRC
                        continue

                    keyword, value = png.read(length).split(b"\0", 1)
                    if keyword != b"faceswap":
                        # Write existing non fs-iTXt data + CRC
                        logger.trace("Copying non-faceswap iTXt chunk: %s", keyword)
                        tmp.write(keyword + b"\0" + value + png.read(4))
                        continue

                    logger.trace("Updating faceswap iTXt chunk")
                    tmp.write(pack_to_itxt(metadata))
                    png.seek(4, 1)  # Skip old CRC
        except Exception:
            # Do not leave a partially written temporary file behind on failure
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise

    os.replace(tmp_filename, filename)
def encode_image(image: np.ndarray,
                 extension: str,
                 encoding_args: tuple[int, ...] | None = None,
                 metadata: PNGHeaderDict | dict[str, T.Any] | bytes | None = None) -> bytes:
    """ Encode an image.

    Parameters
    ----------
    image: numpy.ndarray
        The image to be encoded in `BGR` channel order.
    extension: str
        A compatible `cv2` image file extension that the final image is to be saved to. The
        check for metadata support is case-insensitive.
    encoding_args: tuple[int, ...], optional
        Any encoding arguments to pass to cv2's imencode function
    metadata: dict or bytes, optional
        Metadata for the image. If provided, and the extension is png or tiff, this information
        will be written to the PNG itxt header. Default:``None`` Can be provided as a python dict
        or pre-encoded

    Returns
    -------
    encoded_image: bytes
        The image encoded into the correct file format

    Raises
    ------
    ValueError
        If metadata is provided for an extension other than .png or .tif

    Example
    -------
    >>> image_file = "/path/to/image.png"
    >>> image = read_image(image_file)
    >>> encoded_image = encode_image(image, ".jpg")
    """
    # Normalize once so that validation and the writer lookup agree on the casing
    ext = extension.lower()
    if metadata and ext not in (".png", ".tif"):
        raise ValueError("Metadata is only supported for .png and .tif images")
    args = tuple() if encoding_args is None else encoding_args

    retval = cv2.imencode(extension, image, args)[1]
    if metadata:
        func = {".png": png_write_meta, ".tif": tiff_write_meta}[ext]
        retval = func(retval.tobytes(), metadata)  # type:ignore[arg-type]
    return retval
def png_write_meta(image: bytes, data: PNGHeaderDict | dict[str, T.Any] | bytes) -> bytes:
    """ Write Faceswap information to a png's iTXt field.

    Parameters
    ----------
    image: bytes
        The bytes encoded png file to write header data to
    data: dict or bytes
        The dictionary to write to the header. Can be pre-encoded as utf-8.

    Returns
    -------
    bytes
        The png file with the iTXt chunk inserted immediately before the first IDAT chunk

    Raises
    ------
    ValueError
        If no IDAT chunk can be located in the given png data

    Notes
    -----
    This is a fairly stripped down and non-robust header writer to fit a very specific task. OpenCV
    will not write any iTXt headers to the PNG file, so we make the assumption that the only iTXt
    header that exists is the one that we created for storing alignments.

    References
    ----------
    PNG Specification: https://www.w3.org/TR/2003/REC-PNG-20031110/
    """
    idat = image.find(b"IDAT")
    if idat < 4:
        # Without this guard a failed search (-1) would splice the chunk into the wrong location
        raise ValueError("No IDAT chunk found in png data")
    split = idat - 4  # Step back over the IDAT chunk's 4 byte length field
    retval = image[:split] + pack_to_itxt(data) + image[split:]
    return retval
def tiff_write_meta(image: bytes, data: dict[str, T.Any] | bytes) -> bytes:
    """ Write Faceswap information to a tiff's image_description field.

    Parameters
    ----------
    image: bytes
        The bytes encoded tiff file to write header data to
    data: dict or bytes
        The data to write to the image-description field. If provided as a dict, then it should be
        a json serializable object, otherwise it should be data encoded as ascii bytes

    Returns
    -------
    bytes
        The tiff file with an ImageDescription (270) entry added to the IFD and the data appended
        to the end of the file

    Notes
    -----
    This handles a very specific task of adding, and populating, an ImageDescription field in a
    Tiff file generated by OpenCV. For any other usecases it will likely fail
    """
    if not isinstance(data, bytes):
        data = json.dumps(data, ensure_ascii=True).encode("ascii")

    assert image[:2] == b"II", "Not a supported TIFF file"
    assert struct.unpack("<H", image[2:4])[0] == 42, "Only version 42 Tiff files are supported"
    ptr = struct.unpack("<I", image[4:8])[0]
    rendered = image[:ptr]  # Pack up to IFD

    num_tags = struct.unpack("<H", image[ptr: ptr + 2])[0]
    ptr += 2
    rendered += struct.pack("<H", num_tags + 1)  # Pack new IFD field count

    remainder = image[ptr + num_tags * 12:]  # Hold the data from after the IFD
    assert struct.unpack("<I", remainder[:4])[0] == 0, "Multi-page TIFF files not supported"

    dtypes = {2: "1s", 3: "1H", 4: "1I", 7: '1B'}
    ifd = b""
    insert_idx = -1
    for i in range(num_tags):
        tag = image[ptr + i * 12:ptr + (1 + i) * 12]
        tag_id = struct.unpack("<H", tag[0:2])[0]
        assert tag_id != 270, "Not a supported TIFF file"

        tag_count = struct.unpack("<I", tag[4:8])[0]
        tag_type = dtypes[struct.unpack("<H", tag[2:4])[0]]
        size = tag_count * struct.calcsize(tag_type)

        if insert_idx < 0 and tag_id > 270:
            insert_idx = i  # Log insert location of image description

        if size <= 4:  # value in offset column
            ifd += tag
            continue

        ifd += tag[:8]
        tag_offset = struct.unpack("<I", tag[8:12])[0]
        # NOTE(review): assumes all out-of-line tag data sits after the IFD - confirm
        new_offset = struct.pack("<I", tag_offset + 12)  # Increment by length of new ifd entry
        ifd += new_offset

    if insert_idx < 0:
        # Every existing tag sorts before 270, so the new entry belongs at the end of the IFD
        insert_idx = num_tags

    end = len(rendered) + len(ifd) + 12 + len(remainder)
    # The file is little-endian ("II"), so pack the new entry explicitly little-endian rather
    # than in the byte order of the machine we are running on
    desc = struct.pack("<HHII", 270, 2, len(data), end)

    # TODO confirm no extra pages in end of IFD
    rendered += ifd[:insert_idx * 12] + desc + ifd[insert_idx * 12:] + remainder + data
    return rendered
def tiff_read_meta(image: bytes) -> dict[str, T.Any]:
    """ Read information stored in a Tiff's Image Description field

    Parameters
    ----------
    image: bytes
        The bytes encoded tiff file to read the ImageDescription (270) field from

    Returns
    -------
    dict
        The json decoded contents of the ImageDescription field
    """
    assert image[:2] == b"II", "Not a supported TIFF file"
    (version, ) = struct.unpack("<H", image[2:4])
    assert version == 42, "Only version 42 Tiff files are supported"

    (ifd_offset, ) = struct.unpack("<I", image[4:8])
    (entry_count, ) = struct.unpack("<H", image[ifd_offset: ifd_offset + 2])
    table_start = ifd_offset + 2
    table_end = table_start + entry_count * 12  # Each IFD entry is 12 bytes
    table = image[table_start: table_end]

    (following_ifd, ) = struct.unpack("<I", image[table_end:table_end + 4])
    assert following_ifd == 0, "Multi-page TIFF files not supported"

    type_formats = {2: "1s", 3: "1H", 4: "1I", 7: '1B'}
    payload = None
    for start in range(0, entry_count * 12, 12):
        entry = table[start:start + 12]
        (entry_id, ) = struct.unpack("<H", entry[0:2])
        if entry_id == 270:
            (count, ) = struct.unpack("<I", entry[4:8])
            (type_code, ) = struct.unpack("<H", entry[2:4])
            nbytes = count * struct.calcsize(type_formats[type_code])
            (location, ) = struct.unpack("<I", entry[8:12])
            payload = image[location: location + nbytes]

    assert payload is not None, "No Metadata found in Tiff File"
    return json.loads(payload.decode("ascii"))
def png_read_meta(image):
    """ Read the Faceswap information stored in a png's iTXt field.

    Parameters
    ----------
    image: bytes
        The bytes encoded png file to read header data from

    Returns
    -------
    dict
        The Faceswap information stored in the PNG header, or ``None`` if no faceswap iTXt chunk
        is found

    Notes
    -----
    This is a very stripped down, non-robust and non-secure header reader to fit a very specific
    task. OpenCV will not write any iTXt headers to the PNG file, so we make the assumption that
    the only iTXt header that exists is the one that Faceswap created for storing alignments.
    """
    cursor = 0
    while True:
        # The 4 byte chunk length sits immediately before the chunk type
        length_at = image.find(b"iTXt", cursor) - 4
        if length_at < 0:
            logger.trace("No metadata in png")
            return None
        (size, ) = struct.unpack(">I", image[length_at:length_at + 4])
        data_at = length_at + 8
        keyword, value = image[data_at:data_at + size].split(b"\0", 1)
        if keyword == b"faceswap":
            return literal_eval(value[4:].decode("utf-8", errors="ignore"))
        logger.trace("Skipping iTXt chunk: '%s'", keyword.decode("latin-1", errors="ignore"))
        cursor = data_at + size + 4  # Move past this chunk's data and CRC
def generate_thumbnail(image, size=96, quality=60):
    """ Generate a jpg thumbnail for the given image.

    Parameters
    ----------
    image: :class:`numpy.ndarray`
        Three channel BGR image to convert to a jpg thumbnail
    size: int
        The width and height, in pixels, that the thumbnail should be generated at
    quality: int
        The jpg quality setting to use

    Returns
    -------
    :class:`numpy.ndarray`
        The given image encoded to a jpg at the given size and quality settings
    """
    logger.trace("Input shape: %s, size: %s, quality: %s", image.shape, size, quality)
    height = image.shape[0]
    if height != size:
        # Area interpolation when shrinking, cubic when enlarging
        if height > size:
            method = cv2.INTER_AREA
        else:
            method = cv2.INTER_CUBIC
        image = cv2.resize(image, (size, size), interpolation=method)
    encoded = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
    thumbnail = encoded[1]
    logger.trace("Output shape: %s", thumbnail.shape)
    return thumbnail
def batch_convert_color(batch, colorspace):
    """ Convert a batch of images from one color space to another.

    The first two dimensions of the batch are merged so that the whole batch can be converted in
    a single OpenCV call, after which the original shape is restored.

    Parameters
    ----------
    batch: numpy.ndarray
        A batch of images.
    colorspace: str
        The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be
        ``'BGR2LAB'``.
        See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full
        list of color codes.

    Returns
    -------
    numpy.ndarray
        The batch converted to the requested color space.

    Example
    -------
    >>> images_bgr = numpy.array([image1, image2, image3])
    >>> images_lab = batch_convert_color(images_bgr, "BGR2LAB")

    Notes
    -----
    This function is only compatible for color space conversions that have the same image shape
    for source and destination color spaces.

    If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some
    information lost. For many cases, this will not be noticeable but it is recommended
    to use 32-bit images in cases that need the full range of colors or that convert an image
    before an operation and then convert back.
    """
    logger.trace("Batch converting: (batch shape: %s, colorspace: %s)", batch.shape, colorspace)
    shape = batch.shape
    merged_shape = (shape[0] * shape[1], ) + tuple(shape[2:])
    conversion_code = getattr(cv2, f"COLOR_{colorspace}")
    converted = cv2.cvtColor(batch.reshape(merged_shape), conversion_code)
    return converted.reshape(shape)
def hex_to_rgb(hexcode):
    """ Convert a hex number to it's RGB counterpart.

    The hex digits (with any leading ``#`` removed) are split into 3 equally sized groups, each
    of which is parsed as a base 16 integer.

    Parameters
    ----------
    hexcode: str
        The hex code to convert (e.g. `"#0d25ac"`)

    Returns
    -------
    tuple
        The hex code as a 3 integer (`R`, `G`, `B`) tuple
    """
    digits = hexcode.lstrip("#")
    total = len(digits)
    width = total // 3
    channels = []
    for start in range(0, total, width):
        channels.append(int(digits[start:start + width], 16))
    return tuple(channels)
def rgb_to_hex(rgb):
    """ Convert an RGB tuple to it's hex counterpart.

    Parameters
    ----------
    rgb: tuple
        The (`R`, `G`, `B`) integer values to convert (e.g. `(0, 255, 255)`)

    Returns
    -------
    str:
        The 6 digit hex code with leading `#` applied
    """
    # One zero padded, 2 digit, lower case hex field per channel
    template = "#" + "{:02x}" * 3
    return template.format(*rgb)
# ################### #
# <<< VIDEO UTILS >>> #
# ################### #
def count_frames(filename, fast=False):
    """ Count the number of frames in a video file

    The video is passed through ffmpeg to a null output and the frame count is taken from the
    last progress line that ffmpeg reports. A progress bar is displayed whilst this happens.

    Parameters
    ----------
    filename: str
        Full path to the video to return the frame count from.
    fast: bool, optional
        Whether to count the frames without decoding them (the stream is copied rather than
        decoded). This is significantly faster but accuracy is not guaranteed.
        Default: ``False``.

    Returns
    -------
    int:
        The number of frames in the given video file.

    Example
    -------
    >>> filename = "/path/to/video.mp4"
    >>> frame_count = count_frames(filename)
    """
    logger.debug("filename: %s, fast: %s", filename, fast)
    assert isinstance(filename, str), "Video path must be a string"

    codec_args = ["-c", "copy"] if fast else []
    cmd = [im_ffm.get_ffmpeg_exe(), "-i", filename, "-map", "0:v:0", *codec_args,
           "-f", "null", "-"]
    logger.debug("FFMPEG Command: '%s'", " ".join(cmd))

    process = subprocess.Popen(cmd,
                               stderr=subprocess.STDOUT,
                               stdout=subprocess.PIPE,
                               universal_newlines=True, encoding="utf8")
    progress = None
    duration = None
    elapsed = 0
    frames = 0
    while True:
        line = process.stdout.readline().strip()
        if not line and process.poll() is not None:
            break

        if line.startswith("Duration:"):
            logger.debug("Duration line: %s", line)
            timestamp = line[len("Duration:"):].split(",", 1)[0].strip()
            duration = int(convert_to_secs(*timestamp.split(":")))
            logger.debug("duration: %s", duration)

        if not line.startswith("frame="):
            continue

        logger.debug("frame line: %s", line)
        if progress is None:
            # The progress bar is created lazily on the first progress line
            logger.debug("Initializing tqdm")
            progress = tqdm(desc="Analyzing Video", leave=False, total=duration, unit="secs")
        time_start = line.find("time=") + len("time=")
        frames = int(line[len("frame="):].strip().split(" ")[0].strip())
        timestamp = line[time_start:].split(" ")[0].strip()
        vid_time = int(convert_to_secs(*timestamp.split(":")))
        logger.debug("frames: %s, vid_time: %s", frames, vid_time)
        progress.update(vid_time - elapsed)
        elapsed = vid_time

    if progress is not None:
        progress.close()
    return_code = process.poll()
    logger.debug("Return code: %s, frames: %s", return_code, frames)
    return frames
class ImageIO():
    """ Perform disk IO for images or videos in a background thread.

    Base class for :class:`ImagesLoader` and :class:`ImagesSaver`. It should not be used
    directly, as the thread's work is defined by the subclass overriding :func:`_process`.

    Parameters
    ----------
    path: str or list
        The path to load or save images to/from. For loading this can be a folder which contains
        images, video file or a list of image files. For saving this must be an existing folder.
    queue_size: int
        The amount of images to hold in the internal buffer.
    args: tuple, optional
        The arguments to be passed to the loader or saver thread. Default: ``None``

    See Also
    --------
    lib.image.ImagesLoader : Background Image Loader inheriting from this class.
    lib.image.ImagesSaver : Background Image Saver inheriting from this class.
    """

    def __init__(self, path, queue_size, args=None):
        class_name = self.__class__.__name__
        logger.debug("Initializing %s: (path: %s, queue_size: %s, args: %s)",
                     class_name, path, queue_size, args)

        self._args = args if args is not None else tuple()
        self._location = path
        self._check_location_exists()

        # Always request a fresh queue, named after the concrete class
        queue_name = queue_manager.add_queue(name=class_name,
                                             maxsize=queue_size,
                                             create_new=True)
        self._queue = queue_manager.get_queue(queue_name)
        self._thread = None

    @property
    def location(self):
        """ str: The folder or video that was passed in as the :attr:`path` parameter. """
        return self._location

    def _check_location_exists(self):
        """ Check whether the input location exists.

        Raises
        ------
        FaceswapError
            If the given location does not exist
        """
        location = self.location
        if isinstance(location, str):
            if not os.path.exists(location):
                raise FaceswapError(f"The location '{location}' does not exist")
        elif isinstance(location, (list, tuple)):
            if not all(os.path.exists(item) for item in location):
                raise FaceswapError("Not all locations in the input list exist")

    def _set_thread(self):
        """ Set the background thread for the load and save iterators and launch it. """
        logger.trace("Setting thread")  # type:ignore[attr-defined]
        existing = self._thread
        if existing is not None and existing.is_alive():
            logger.trace("Thread pre-exists and is alive: %s",  # type:ignore[attr-defined]
                         existing)
            return

        self._thread = MultiThread(self._process,
                                   self._queue,
                                   name=self.__class__.__name__,
                                   thread_count=1)
        logger.debug("Set thread: %s", self._thread)
        self._thread.start()

    def _process(self, queue):
        """ Image IO process to be run in a thread. Override for loader/saver process.

        Parameters
        ----------
        queue: queue.Queue()
            The ImageIO Queue
        """
        raise NotImplementedError

    def close(self):
        """ Closes down and joins the internal threads """
        logger.debug("Received Close")
        worker = self._thread
        if worker is not None:
            worker.join()
        del worker
        del self._thread
        self._thread = None
        logger.debug("Closed")
class ImagesLoader(ImageIO):
    """ Perform image loading from a folder of images or a video.

    Images will be loaded and returned in the order that they appear in the folder, or in the
    video, to ensure deterministic ordering. Loading occurs in a background thread, buffering up
    to `queue_size` images at a time so that other processes do not need to wait on disk reads.

    See also :class:`ImageIO` for additional attributes.

    Parameters
    ----------
    path: str or list
        The path to load images from. This can be a folder which contains images, a video file or
        a list of image files.
    queue_size: int, optional
        The amount of images to hold in the internal buffer. Default: 8.
    fast_count: bool, optional
        When loading from video, the video needs to be parsed frame by frame to get an accurate
        count. This can be done quite quickly without guaranteed accuracy, or slower with
        guaranteed accuracy. Set to ``True`` to count quickly, or ``False`` to count slower
        but accurately. Default: ``True``.
    skip_list: list, optional
        Optional list of frame/image indices to not load. Any indices provided here will be skipped
        when executing the :func:`load` function from the given location. Default: ``None``
    count: int, optional
        If the number of images that the loader will encounter is already known, it can be passed
        in here to skip the image counting step, which can save time at launch. Set to ``None`` if
        the count is not already known. Default: ``None``

    Examples
    --------
    Loading from a video file:

    >>> loader = ImagesLoader('/path/to/video.mp4')
    >>> for filename, image in loader.load():
    >>>     <do processing>
    """

    def __init__(self,
                 path,
                 queue_size=8,
                 fast_count=True,
                 skip_list=None,
                 count=None):
        logger.debug("Initializing %s: (path: %s, queue_size: %s, fast_count: %s, skip_list: %s, "
                     "count: %s)", self.__class__.__name__, path, queue_size, fast_count,
                     skip_list, count)
        super().__init__(path, queue_size=queue_size)
        # Held as a set so that the per-frame membership test in the load generators is cheap
        self._skip_list = set() if skip_list is None else set(skip_list)
        self._is_video = self._check_for_video()
        self._fps = self._get_fps()
        self._count = None
        self._file_list = None
        self._get_count_and_filelist(fast_count, count)

    @property
    def count(self):
        """ int: The number of images or video frames in the source location. This count includes
        any files that will ultimately be skipped if a :attr:`skip_list` has been provided. See
        also: :attr:`process_count`"""
        return self._count

    @property
    def process_count(self):
        """ int: The number of images or video frames to be processed (IE the total count less
        items that are to be skipped from the :attr:`skip_list`)"""
        return self._count - len(self._skip_list)

    @property
    def is_video(self):
        """ bool: ``True`` if the input is a video, ``False`` if it is not """
        return self._is_video

    @property
    def fps(self):
        """ float: For an input folder of images, this will always return 25fps. If the input is a
        video, then the fps of the video will be returned. """
        return self._fps

    @property
    def file_list(self):
        """ list: A full list of files in the source location. This includes any files that will
        ultimately be skipped if a :attr:`skip_list` has been provided. If the input is a video
        then this is a list of dummy filenames as corresponding to an alignments file """
        return self._file_list

    def add_skip_list(self, skip_list):
        """ Add a skip list to this :class:`ImagesLoader`. Any previously set skip list is
        replaced.

        Parameters
        ----------
        skip_list: list
            A list of indices corresponding to the frame indices that should be skipped by the
            :func:`load` function.
        """
        logger.debug(skip_list)
        self._skip_list = set(skip_list)

    def _check_for_video(self):
        """ Check whether the input is a video

        Returns
        -------
        bool
            ``True`` if the input is a video, ``False`` if it is a folder or a list of files.

        Raises
        ------
        FaceswapError
            If the given location is a file and does not have a valid video extension.
        """
        if not isinstance(self.location, str) or os.path.isdir(self.location):
            retval = False
        elif os.path.splitext(self.location)[1].lower() in _video_extensions:
            retval = True
        else:
            raise FaceswapError("The input file '{}' is not a valid video".format(self.location))
        logger.debug("Input '%s' is_video: %s", self.location, retval)
        return retval

    def _get_fps(self):
        """ Get the Frames per Second.

        If the input is a folder of images then 25.0 will be returned, as it is not possible to
        calculate the fps just from frames alone. For video files the fps reported by the video's
        meta data is returned.

        Returns
        -------
        float
            The Frames per Second of the input source
        """
        if self._is_video:
            reader = imageio.get_reader(self.location, "ffmpeg")
            try:
                retval = reader.get_meta_data()["fps"]
            finally:
                # Release the reader even if the meta data lookup fails
                reader.close()
        else:
            retval = 25.0
        logger.debug(retval)
        return retval

    def _get_count_and_filelist(self, fast_count, count):
        """ Set the count of images to be processed and set the file list

        If the input is a video, a dummy file list is created for checking against an
        alignments file, otherwise it will be a list of full filenames.

        Parameters
        ----------
        fast_count: bool
            When loading from video, the video needs to be parsed frame by frame to get an accurate
            count. This can be done quite quickly without guaranteed accuracy, or slower with
            guaranteed accuracy. Set to ``True`` to count quickly, or ``False`` to count slower
            but accurately.
        count: int
            The number of images that the loader will encounter if already known, otherwise
            ``None``
        """
        if self._is_video:
            # Only scan the video when the caller has not supplied the count
            self._count = int(count_frames(self.location,
                                           fast=fast_count)) if count is None else count
            self._file_list = [self._dummy_video_framename(i) for i in range(self.count)]
        else:
            if isinstance(self.location, (list, tuple)):
                self._file_list = self.location
            else:
                self._file_list = get_image_paths(self.location)
            self._count = len(self.file_list) if count is None else count

        logger.debug("count: %s", self.count)
        logger.trace("filelist: %s", self.file_list)  # type:ignore[attr-defined]

    def _process(self, queue):
        """ The load thread.

        Loads from a folder of images or from a video and puts to a queue. The string ``"EOF"``
        is put to the queue once the source has been exhausted.

        Parameters
        ----------
        queue: queue.Queue()
            The ImageIO Queue
        """
        iterator = self._from_video if self._is_video else self._from_folder
        logger.debug("Load iterator: %s", iterator)
        for retval in iterator():
            # Subclasses may yield additional items, so only look at the first two here
            filename, image = retval[:2]
            if image is None or (not image.any() and image.ndim not in (2, 3)):
                # All black frames will return not numpy.any() so check dims too
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            logger.trace("Putting to queue: %s",  # type:ignore[attr-defined]
                         [v.shape if isinstance(v, np.ndarray) else v for v in retval])
            queue.put(retval)
        logger.trace("Putting EOF")  # type:ignore[attr-defined]
        queue.put("EOF")

    def _from_video(self):
        """ Generator for loading frames from a video

        The video reader is closed when the generator finishes, is closed early or raises.

        Yields
        ------
        filename: str
            The dummy filename of the loaded video frame.
        image: numpy.ndarray
            The loaded video frame.
        """
        logger.debug("Loading frames from video: '%s'", self.location)
        reader = imageio.get_reader(self.location, "ffmpeg")
        try:
            for idx, frame in enumerate(reader):
                if idx in self._skip_list:
                    logger.trace("Skipping frame %s due to skip list",  # type:ignore[attr-defined]
                                 idx)
                    continue
                # Convert to BGR for cv2 compatibility
                frame = frame[:, :, ::-1]
                filename = self._dummy_video_framename(idx)
                logger.trace("Loading video frame: '%s'", filename)  # type:ignore[attr-defined]
                yield filename, frame
        finally:
            # Runs on normal exhaustion, on error and when the generator is abandoned, so the
            # reader is never left open
            reader.close()

    def _dummy_video_framename(self, index):
        """ Return a dummy filename for video files

        Parameters
        ----------
        index: int
            The index number for the frame in the video file

        Notes
        -----
        Indexes start at 0, frame numbers start at 1, so index is incremented by 1
        when creating the filename

        Returns
        -------
        str
            A dummied filename for a video frame
        """
        vidname = os.path.splitext(os.path.basename(self.location))[0]
        return "{}_{:06d}.png".format(vidname, index + 1)

    def _from_folder(self):
        """ Generator for loading images from a folder

        Images which fail to load are logged and skipped rather than yielded.

        Yields
        ------
        filename: str
            The filename of the loaded image.
        image: numpy.ndarray
            The loaded image.
        """
        logger.debug("Loading frames from folder: '%s'", self.location)
        for idx, filename in enumerate(self.file_list):
            if idx in self._skip_list:
                logger.trace("Skipping frame %s due to skip list",  # type:ignore[attr-defined]
                             idx)
                continue
            image_read = read_image(filename, raise_error=False)
            retval = filename, image_read
            if retval[1] is None:
                logger.warning("Frame not loaded: '%s'", filename)
                continue
            yield retval

    def load(self):
        """ Generator for loading images from the given :attr:`location`

        Starts the background load thread if it is not already running and yields items from the
        internal queue until the end of the source is reached, at which point the loader is
        closed.

        If :class:`FacesLoader` is in use then the Faceswap metadata of the image stored in the
        image exif file is added as the final item in the output `tuple`.

        Yields
        ------
        filename: str
            The filename of the loaded image.
        image: numpy.ndarray
            The loaded image.
        metadata: dict, (:class:`FacesLoader` only)
            The Faceswap metadata associated with the loaded image.
        """
        logger.debug("Initializing Load Generator")
        self._set_thread()
        while True:
            # Surface any error from the background thread rather than polling forever
            self._thread.check_and_raise_error()
            try:
                retval = self._queue.get(True, 1)
            except QueueEmpty:
                continue
            if retval == "EOF":
                logger.trace("Got EOF")  # type:ignore[attr-defined]
                break
            logger.trace("Yielding: %s",  # type:ignore[attr-defined]
                         [v.shape if isinstance(v, np.ndarray) else v for v in retval])
            yield retval
        logger.debug("Closing Load Generator")
        self.close()
class FacesLoader(ImagesLoader):
    """ Loads faces from a faces folder along with the face's Faceswap metadata.

    Parameters
    ----------
    path: str or list
        The folder containing the faces, or a list of face image files.
    skip_list: list, optional
        Optional list of image indices to not load. Default: ``None``
    count: int, optional
        The number of faces that the loader will encounter, if already known. Default: ``None``

    Examples
    --------
    Loading faces with their Faceswap metadata:

    >>> loader = FacesLoader('/path/to/faces/folder')
    >>> for filename, face, metadata in loader.load():
    >>>     <do processing>
    """

    def __init__(self, path, skip_list=None, count=None):
        logger.debug("Initializing %s: (path: %s, count: %s)", self.__class__.__name__,
                     path, count)
        super().__init__(path, queue_size=8, skip_list=skip_list, count=count)

    def _get_count_and_filelist(self, fast_count, count):
        """ Override default implementation to only return png files from the source folder

        Parameters
        ----------
        fast_count: bool
            Not used for faces loader
        count: int
            The number of images that the loader will encounter if already known, otherwise
            ``None``
        """
        if isinstance(self.location, (list, tuple)):
            file_list = self.location
        else:
            file_list = get_image_paths(self.location)

        # Anything that does not have a .png extension (case insensitive) is dropped
        self._file_list = [fname for fname in file_list
                           if os.path.splitext(fname)[-1].lower() == ".png"]
        self._count = len(self.file_list) if count is None else count

        logger.debug("count: %s", self.count)
        logger.trace("filelist: %s", self.file_list)  # type:ignore[attr-defined]

    def _from_folder(self):
        """ Generator for loading images from a folder

        Faces will only ever be loaded from a folder, so this is the only function requiring
        an override. Faces which fail to load are logged and skipped rather than yielded.

        Yields
        ------
        filename: str
            The filename of the loaded image.
        image: numpy.ndarray
            The loaded image.
        metadata: dict
            The Faceswap metadata associated with the loaded image.
        """
        logger.debug("Loading images from folder: '%s'", self.location)
        for idx, filename in enumerate(self.file_list):
            if idx in self._skip_list:
                logger.trace("Skipping face %s due to skip list",  # type:ignore[attr-defined]
                             idx)
                continue
            image_read = read_image(filename, raise_error=False, with_metadata=True)
            # The items returned from read_image are flattened in directly after the filename
            retval = filename, *image_read
            if retval[1] is None:
                logger.warning("Face not loaded: '%s'", filename)
                continue
            yield retval
class SingleFrameLoader(ImagesLoader):
    """ Provides direct access to a single frame, addressed by frame index.

    Frames are fetched on demand in the calling thread. No background thread is used, since the
    caller has to wait for the requested frame to be read either way.

    Parameters
    ----------
    path: str or list
        The location to load from. Passed through to :class:`ImagesLoader`.
    video_meta_data: dict, optional
        Previously collected meta information for the given video. The `pts_time` and `keyframes`
        entries, when present, are handed to the video reader when collecting frame information.
        Set to ``None`` if no meta information is available. Default: ``None``
    """

    def __init__(self, path, video_meta_data=None):
        logger.debug("Initializing %s: (path: %s, video_meta_data: %s)",
                     self.__class__.__name__, path, video_meta_data)
        # Both attributes must exist before the parent initializer runs, as it calls
        # :func:`_get_count_and_filelist` which reads and writes them
        self._video_meta_data = video_meta_data if video_meta_data is not None else dict()
        self._reader = None
        super().__init__(path, queue_size=1, fast_count=False)

    @property
    def video_meta_data(self):
        """ dict: The video meta information held by this loader.

        Notes
        -----
        For video input this is the meta information returned by the video reader when the frame
        information was collected. For any other input it is the dictionary given at
        initialization, or an empty dictionary if none was given.
        """
        return self._video_meta_data

    def _get_count_and_filelist(self, fast_count, count):
        """ Open the persistent video reader (video input only), then set the count and file list.

        For video input the frame count is taken from the reader's frame information rather than
        from the given `count`, and the stored video meta information is replaced with the meta
        information that the reader returns.

        Parameters
        ----------
        fast_count: bool
            Passed through to the parent implementation
        count: int
            The number of images that the loader will encounter if already known, otherwise
            ``None``. Overridden for video input.
        """
        if self._is_video:
            self._reader = imageio.get_reader(self.location, "ffmpeg")
            self._reader.use_patch = True
            known = self._video_meta_data
            count, self._video_meta_data = self._reader.get_frame_info(
                frame_pts=known.get("pts_time"),
                keyframes=known.get("keyframes"))
        super()._get_count_and_filelist(fast_count, count)

    def image_from_index(self, index):
        """ Return a single image from :attr:`file_list` for the given index.

        Parameters
        ----------
        index: int
            The index number (frame number) of the frame to retrieve. NB: The first frame is
            index `0`

        Returns
        -------
        filename: str
            The dummy frame name for video input, otherwise the base name of the image file
        image: :class:`numpy.ndarray`
            The image for the given index

        Notes
        -----
        Video frames are requested from the reader that was opened at initialization and have
        their final axis reversed before being returned. Image files are read from disk, with an
        error being raised if the read fails.

        No background thread is used for this task, as it is assumed that requesting an image by
        index will be done when required.
        """
        if not self.is_video:
            full_path = self.file_list[index]
            image = read_image(full_path, raise_error=True)
            filename = os.path.basename(full_path)
        else:
            image = self._reader.get_data(index)[..., ::-1]
            filename = self._dummy_video_framename(index)
        logger.trace("index: %s, filename: %s image shape: %s",  # type:ignore[attr-defined]
                     index, filename, image.shape)
        return filename, image
class ImagesSaver(ImageIO):
    """ Perform image saving to a destination folder.

    Images are saved in a background ThreadPoolExecutor to allow for concurrent saving.
    See also :class:`ImageIO` for additional attributes.

    Parameters
    ----------
    path: str
        The folder to save images to. This must be an existing folder.
    queue_size: int, optional
        The amount of images to hold in the internal buffer. Default: 8.
    as_bytes: bool, optional
        ``True`` if the image is already encoded to bytes, ``False`` if the image is a
        :class:`numpy.ndarray`. Default: ``False``.

    Examples
    --------

    >>> saver = ImagesSaver('/path/to/save/folder')
    >>> for filename, image in <image_iterator>:
    >>>     saver.save(filename, image)
    >>> saver.close()
    """

    def __init__(self, path, queue_size=8, as_bytes=False):
        logger.debug("Initializing %s: (path: %s, queue_size: %s, as_bytes: %s)",
                     self.__class__.__name__, path, queue_size, as_bytes)
        super().__init__(path, queue_size=queue_size)
        self._as_bytes = as_bytes

    def _check_location_exists(self):
        """ Check whether the output location exists and is a folder

        Raises
        ------
        FaceswapError
            If the given location is not a string, does not exist or is not a folder
        """
        if not isinstance(self.location, str):
            raise FaceswapError("The output location must be a string not a "
                                "{}".format(type(self.location)))
        super()._check_location_exists()
        if not os.path.isdir(self.location):
            raise FaceswapError("The output location '{}' is not a folder".format(self.location))

    def _process(self, queue):
        """ Saves images from the save queue to the given :attr:`location` inside a thread.

        Items are taken from the queue and handed to a thread pool until the string ``"EOF"`` is
        received. The pool is always shut down on exit, waiting for submitted saves to finish.

        Parameters
        ----------
        queue: queue.Queue()
            The ImageIO Queue
        """
        # The context manager shuts the pool down (waiting on pending saves) on both the normal
        # EOF path and if an error escapes the loop
        with futures.ThreadPoolExecutor(thread_name_prefix=self.__class__.__name__) as executor:
            while True:
                item = queue.get()
                if item == "EOF":
                    logger.debug("EOF received")
                    break
                logger.trace("Submitting: '%s'", item[0])  # type:ignore[attr-defined]
                executor.submit(self._save, *item)

    def _save(self,
              filename: str,
              image: bytes | np.ndarray,
              sub_folder: str | None) -> None:
        """ Save a single image inside a ThreadPoolExecutor

        Any failure, including a failure to create the sub-folder, is logged rather than raised.

        Parameters
        ----------
        filename: str
            The filename of the image to be saved. NB: Any folders passed in with the filename
            will be stripped and replaced with :attr:`location`.
        image: bytes or :class:`numpy.ndarray`
            The encoded image or numpy array to be saved
        sub_folder: str or ``None``
            If the file should be saved in a subfolder in the output location, the subfolder should
            be provided here. ``None`` for no subfolder.
        """
        location = os.path.join(self.location, sub_folder) if sub_folder else self._location
        filename = os.path.join(location, os.path.basename(filename))
        try:
            if sub_folder:
                # Several pool threads may need the same new folder at the same time. exist_ok
                # removes the check-then-create race that would otherwise raise in one of them
                # and silently drop its image.
                os.makedirs(location, exist_ok=True)
            if self._as_bytes:
                # Narrows the type for static checkers. A failure here is logged below.
                assert isinstance(image, bytes)
                with open(filename, "wb") as out_file:
                    out_file.write(image)
            else:
                cv2.imwrite(filename, image)
            logger.trace("Saved image: '%s'", filename)  # type:ignore
        except Exception as err:  # pylint: disable=broad-except
            # Exceptions raised inside the pool are never retrieved, so log them here
            logger.error("Failed to save image '%s'. Original Error: %s", filename, str(err))

    def save(self,
             filename: str,
             image: bytes | np.ndarray,
             sub_folder: str | None = None) -> None:
        """ Save the given image in the background thread

        Ensure that :func:`close` is called once all save operations are complete.

        Parameters
        ----------
        filename: str
            The filename of the image to be saved. NB: Any folders passed in with the filename
            will be stripped and replaced with :attr:`location`.
        image: bytes or :class:`numpy.ndarray`
            The encoded image or numpy array to be saved
        sub_folder: str, optional
            If the file should be saved in a subfolder in the output location, the subfolder should
            be provided here. ``None`` for no subfolder. Default: ``None``
        """
        # Starts the save thread on first use. Does nothing if it is already running.
        self._set_thread()
        logger.trace("Putting to save queue: '%s'", filename)  # type:ignore
        self._queue.put((filename, image, sub_folder))

    def close(self):
        """ Signal to the Save Threads that they should be closed and cleanly shutdown
        the saver """
        logger.debug("Putting EOF to save queue")
        self._queue.put("EOF")
        super().close()