mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-09 04:36:50 -04:00
* Separate predict and implement pool * Add check and raise error to multithreading. Box functions to config. Add crop box option. * All masks to mask module. Refactor convert masks. Predicted mask passed from model. Cli update * Intersect box with mask and fixes * Use raw NN output for convert. Use raw mask for face adjustments. Split adjustments to pre and post warp * Separate out adjustments. Add unmask sharpen * Set sensible defaults. Pre PR Testing * Fix queue sizes. Move masked.py to lib * Fix Skip Frames. Fix GUI Config popup * Sensible queue limits. Add a less resource intensive single processing mode * Fix predicted mask. Amend smooth box defaults * Deterministic ordering for video output * Video to Video convert implemented * Fixups - Remove defaults from folders across all stages - Move match-hist and aca into color adjustments selectable - Handle crashes properly for pooled processes - Fix output directory does not exist error when creating a new output folder - Force output to frames if input is not a video * Add Color Transfer adjustment method. Wrap info text in GUI plugin configure popup * Refactor image adjustments. Easier to create plugins. Start implementing config options for video encoding * Add video encoding config options. Allow video encoding for frames input (must pass in a reference video). Move video and image output writers to plugins * Image writers config options. Move scaling to cli. Move draw_transparent to images config. Add config options for cv2 writer. Add Pillow image writer * Add gif filetype to Pillow. Fix draw transparent for Pillow * Add Animated GIF writer. Standardize opencv/pillow defaults * [speedup] Pre-encode supported writers in the convert pool (opencv, pillow). Move scaling to convert pool. Remove dfaker mask * Fix default writer * Bugfixes * Better custom argparse formatting
683 lines
27 KiB
Python
683 lines
27 KiB
Python
#!/usr/bin/env python3
""" The script to run the convert process of faceswap """
|
|
|
|
import logging
|
|
import re
|
|
import os
|
|
import sys
|
|
from time import sleep
|
|
from threading import Event
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from scripts.fsmedia import Alignments, Images, PostProcess, Utils
|
|
from lib import Serializer
|
|
from lib.convert import Converter
|
|
from lib.faces_detect import DetectedFace
|
|
from lib.multithreading import MultiThread, PoolProcess, total_cpus
|
|
from lib.queue_manager import queue_manager, QueueEmpty
|
|
from lib.utils import get_folder, get_image_paths, hash_image_file
|
|
from plugins.plugin_loader import PluginLoader
|
|
|
|
from .extract import Plugins as Extractor
|
|
|
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|
|
|
|
|
class Convert():
    """ The convert process.

    Wires together the disk loader/saver (:class:`DiskIO`), the model predictor
    (:class:`Predict`) and a pool of converter processes, then runs them until the save thread
    reports that every frame has been written.

    Parameters
    ----------
    arguments: object
        The convert arguments (attributes such as ``output_dir``, ``writer``, ``loglevel`` and
        ``singleprocess`` are read from it)
    """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s)", self.__class__.__name__, arguments)
        self.args = arguments
        Utils.set_verbosity(self.args.loglevel)

        self.images = Images(self.args)
        self.validate()
        self.alignments = Alignments(self.args, False, self.images.is_video)
        # Update Legacy alignments
        Legacy(self.alignments, self.images.input_images, arguments.input_aligned_dir)
        self.opts = OptionalActions(self.args, self.images.input_images, self.alignments)

        # Create the sized queues before DiskIO and Predict retrieve them by name
        self.add_queues()
        self.disk_io = DiskIO(self.alignments, self.images, arguments)
        self.extractor = None
        self.predictor = Predict(self.disk_io.load_queue, self.queue_size, arguments)
        self.converter = Converter(get_folder(self.args.output_dir),
                                   self.predictor.output_size,
                                   self.predictor.has_predicted_mask,
                                   self.disk_io.draw_transparent,
                                   self.disk_io.pre_encode,
                                   arguments)

        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def queue_size(self):
        """ int: The maximum size of the convert queues. 2 in single process mode, otherwise
        double the number of available cpus """
        if self.args.singleprocess:
            retval = 2
        else:
            retval = total_cpus() * 2
        logger.debug(retval)
        return retval

    @property
    def pool_processes(self):
        """ int: The maximum number of pooled converter processes to use. 1 in single process
        mode, otherwise the number of available cpus capped at the number of input images """
        if self.args.singleprocess:
            retval = 1
        else:
            retval = min(total_cpus(), self.images.images_found)
        logger.debug(retval)
        return retval

    def validate(self):
        """ Check that the selected writer is valid for the input and make the output folder if
        it does not already exist.

        Exits the process if the ffmpeg writer is selected for frames input without a reference
        video having been provided. """
        if (self.args.writer == "ffmpeg" and
                not self.images.is_video and
                self.args.reference_video is None):
            logger.error("Output as video selected, but using frames as input. You must provide a "
                         "reference video ('-ref', '--reference-video').")
            # sys.exit rather than the ``exit`` builtin: the latter is injected by the site
            # module for interactive use and is not guaranteed to exist
            sys.exit(1)
        output_dir = get_folder(self.args.output_dir)
        logger.info("Output Directory: %s", output_dir)

    def add_queues(self):
        """ Add the queues for convert to the queue manager, each sized to :attr:`queue_size` """
        logger.debug("Adding queues. Queue size: %s", self.queue_size)
        for qname in ("convert_in", "save", "patch"):
            queue_manager.add_queue(qname, self.queue_size)

    def process(self):
        """ Process the conversion: convert every frame, wait for the save thread to finish,
        tear down the queues and report the final counts """
        logger.debug("Starting Conversion")
        # queue_manager.debug_monitor(2)
        self.convert_images()
        self.disk_io.save_thread.join()
        queue_manager.terminate_queues()

        Utils.finalize(self.images.images_found,
                       self.predictor.faces_count,
                       self.predictor.verify_output)
        logger.debug("Completed Conversion")

    def convert_images(self):
        """ Run the converter pool until the save thread signals completion.

        The pool runs :func:`Converter.process` with the ``patch`` and ``save`` queues. The
        background threads are polled for errors once per second while waiting. """
        logger.debug("Converting images")
        save_queue = queue_manager.get_queue("save")
        patch_queue = queue_manager.get_queue("patch")
        pool = PoolProcess(self.converter.process, patch_queue, save_queue,
                           processes=self.pool_processes)
        pool.start()
        while True:
            self.check_thread_error()
            if self.disk_io.completion_event.is_set():
                break
            sleep(1)
        pool.join()

        save_queue.put("EOF")
        logger.debug("Converted images")

    def check_thread_error(self):
        """ Check the predictor, load and save threads, re-raising any error they hit """
        for thread in (self.predictor.thread, self.disk_io.load_thread, self.disk_io.save_thread):
            thread.check_and_raise_error()

    def patch_iterator(self, processes):
        """ Yield items from the ``out`` queue until an "EOF" has been received from each of
        ``processes`` producers. Thread errors are checked after every get (and on timeout).

        NOTE(review): :func:`add_queues` does not create an ``out`` queue and nothing in this
        file appears to call this method — confirm whether it is still used. """
        out_queue = queue_manager.get_queue("out")
        completed = 0

        while True:
            try:
                item = out_queue.get(True, 1)
            except QueueEmpty:
                self.check_thread_error()
                continue
            self.check_thread_error()

            if item == "EOF":
                completed += 1
                logger.debug("Got EOF %s of %s", completed, processes)
                if completed == processes:
                    break
                continue

            logger.trace("Yielding: '%s'", item[0])
            yield item
        logger.debug("iterator exhausted")
        return "EOF"
|
|
|
|
|
|
class DiskIO():
    """ Background threads to load images from disk (attaching their detected faces) and to
    save converted images back to disk.

    Parameters
    ----------
    alignments: :class:`scripts.fsmedia.Alignments`
        The alignments for the input frames
    images: :class:`scripts.fsmedia.Images`
        The source images or video
    arguments: object
        The convert arguments
    """
    def __init__(self, alignments, images, arguments):
        logger.debug("Initializing %s: (alignments: %s, images: %s, arguments: %s)",
                     self.__class__.__name__, alignments, images, arguments)
        self.alignments = alignments
        self.images = images
        self.args = arguments
        self.completion_event = Event()
        self.frame_ranges = self.get_frame_ranges()
        self.writer = self.get_writer()

        # For frame skipping: captures the run of digits directly before the file extension
        self.imageidxre = re.compile(r"(\d+)(?!.*\d\.)(?=\.\w+$)")

        # Extractor for on the fly detection
        self.extractor = None
        if not self.alignments.have_alignments_file:
            self.load_extractor()

        self.load_queue = None
        self.save_queue = None
        self.load_thread = None
        self.save_thread = None
        self.init_threads()
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def draw_transparent(self):
        """ bool: Draw transparent is an image writer only parameter. Return the value here for
        easy access for the predictor (``False`` if the writer's config does not set it) """
        return self.writer.config.get("draw_transparent", False)

    @property
    def pre_encode(self):
        """ The writer's pre-encode function, or ``None`` if the writer returns ``None`` when
        asked to pre-encode a small dummy image (i.e. it does not support pre-encoding) """
        dummy = np.zeros((20, 20, 3), dtype="uint8")
        test = self.writer.pre_encode(dummy)
        retval = None if test is None else self.writer.pre_encode
        logger.debug("Writer pre_encode function: %s", retval)
        return retval

    @property
    def total_count(self):
        """ int: The total number of frames to be converted (and saved).

        When frame ranges are set and unchanged frames are discarded, this is the number of
        frames covered by the (inclusive) ranges, capped at the number of frames found.
        Otherwise it is every frame found in the input. """
        if self.frame_ranges and not self.args.keep_unchanged:
            retval = self._count_frames_in_ranges(self.frame_ranges, self.images.images_found)
        else:
            retval = self.images.images_found
        logger.debug(retval)
        return retval

    @staticmethod
    def _count_frames_in_ranges(frame_ranges, images_found):
        """ Count the frames covered by inclusive ``(start, end)`` ranges as an ``int``,
        counting overlapping frames once and capping the result at ``images_found``.

        The cap also turns an open ended ``max`` range (infinity) into a usable integer.
        NOTE(review): the cap assumes ranges do not extend past the available frame numbers;
        e.g. ``5-max`` on frames numbered from 1 still counts every frame. """
        spans = sorted((rng[0], rng[1]) for rng in frame_ranges if rng[0] <= rng[1])
        total = 0
        span_start = span_end = None
        for start, end in spans:
            if span_end is not None and start <= span_end:
                # Overlapping ranges: extend the current span rather than double counting
                span_end = max(span_end, end)
                continue
            if span_end is not None:
                total += span_end - span_start + 1
            span_start, span_end = start, end
        if span_end is not None:
            total += span_end - span_start + 1
        return int(min(total, images_found))

    # Initialization
    def get_writer(self):
        """ Return the writer plugin.

        The writer always receives the output folder. The ``ffmpeg`` and ``gif`` writers also
        receive :attr:`total_count`, and ``ffmpeg`` additionally receives the source video (the
        input video, or the reference video when converting from frames). """
        args = [self.args.output_dir]
        if self.args.writer in ("ffmpeg", "gif"):
            args.append(self.total_count)
        if self.args.writer == "ffmpeg":
            if self.images.is_video:
                args.append(self.args.input_dir)
            else:
                args.append(self.args.reference_video)
        logger.debug("Writer args: %s", args)
        return PluginLoader.get_converter("writer", self.args.writer)(*args)

    def get_frame_ranges(self):
        """ Split out the frame ranges and parse out 'min' and 'max' values.

        Returns
        -------
        list or None
            One tuple per ``"start-end"`` string in the frame ranges argument, with ``"min"``
            mapped to 0 and ``"max"`` to infinity. ``None`` if no frame ranges were given.
        """
        if not self.args.frame_ranges:
            logger.debug("No frame range set")
            return None

        minmax = {"min": 0,  # never any frames less than 0
                  "max": float("inf")}
        retval = [tuple(minmax[value] if value in minmax else int(value)
                        for value in frame_range.split("-"))
                  for frame_range in self.args.frame_ranges]
        logger.debug("frame ranges: %s", retval)
        return retval

    def load_extractor(self):
        """ Set up on the fly extraction with the dlib-hog detector and dlib aligner """
        logger.debug("Loading extractor")
        logger.warning("No Alignments file found. Extracting on the fly.")
        logger.warning("NB: This will use the inferior dlib-hog for extraction "
                       "and dlib pose predictor for landmarks. It is recommended "
                       "to perform Extract first for superior results")
        extract_args = {"detector": "dlib-hog",
                        "aligner": "dlib",
                        "loglevel": self.args.loglevel}
        self.extractor = Extractor(None, extract_args)
        self.extractor.launch_detector()
        self.extractor.launch_aligner()
        logger.debug("Loaded extractor")

    def init_threads(self):
        """ Initialize the load and save queues and start their threads """
        logger.debug("Initializing DiskIO Threads")
        for task in ("load", "save"):
            self.add_queue(task)
            self.start_thread(task)
        logger.debug("Initialized DiskIO Threads")

    def add_queue(self, task):
        """ Get the queue for ``task`` from queue_manager and store it as ``<task>_queue``.
        The load task uses the ``convert_in`` queue """
        logger.debug("Adding queue for task: '%s'", task)
        q_name = "convert_in" if task == "load" else task
        setattr(self, "{}_queue".format(task), queue_manager.get_queue(q_name))
        logger.debug("Added queue for task: '%s'", task)

    def start_thread(self, task):
        """ Start the DiskIO thread for ``task`` and store it as ``<task>_thread``. The save
        thread is given the completion event """
        logger.debug("Starting thread: '%s'", task)
        args = self.completion_event if task == "save" else None
        func = getattr(self, task)
        io_thread = MultiThread(func, args, thread_count=1)
        io_thread.start()
        setattr(self, "{}_thread".format(task), io_thread)
        logger.debug("Started thread: '%s'", task)

    # Loading tasks
    def load(self, *args):  # pylint: disable=unused-argument
        """ Load the images with detected_faces.

        Frames outside of the requested frame ranges are either passed straight to the save
        queue (``keep_unchanged``) or discarded. Every other frame is put to the load queue as a
        dict of ``filename``, ``image`` and ``detected_faces``, followed by a final "EOF". """
        logger.debug("Load Images: Start")
        extract_queue = queue_manager.get_queue("extract_in") if self.extractor else None
        for filename, image in self.images.load():
            if self.load_queue.shutdown.is_set():
                logger.debug("Load Queue: Stop signal received. Terminating")
                break
            if self._is_unreadable(image):
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            if self.check_skipframe(filename):
                if self.args.keep_unchanged:
                    logger.trace("Saving unchanged frame: %s", filename)
                    out_file = os.path.join(self.args.output_dir, os.path.basename(filename))
                    self.save_queue.put((out_file, image))
                else:
                    logger.trace("Discarding frame: '%s'", filename)
                continue

            detected_faces = self.get_detected_faces(filename, image, extract_queue)
            item = dict(filename=filename, image=image, detected_faces=detected_faces)
            self.load_queue.put(item)

        self.load_queue.put("EOF")
        logger.debug("Load Images: Complete")

    @staticmethod
    def _is_unreadable(image):
        """ Return ``True`` if ``image`` failed to load.

        An entirely black frame legitimately has no non-zero values, so an all-zero image is
        only treated as unreadable when it does not have 2 or 3 dimensions. """
        return image is None or (not image.any() and image.ndim not in (2, 3))

    def check_skipframe(self, filename):
        """ Check whether frame is to be skipped.

        Returns ``None`` when no frame ranges are set. Returns ``False`` if the frame number
        cannot be read from ``filename`` (the frame is converted) or if the number falls inside
        any of the inclusive frame ranges, otherwise ``True``. """
        if not self.frame_ranges:
            return None
        indices = self.imageidxre.findall(filename)
        if not indices:
            logger.warning("Could not determine frame number. Frame will be converted: '%s'",
                           filename)
            return False
        idx = int(indices[0])
        skipframe = not any(rng[0] <= idx <= rng[1] for rng in self.frame_ranges)
        return skipframe

    def get_detected_faces(self, filename, image, extract_queue):
        """ Return detected faces from the alignments file or, if an on the fly extractor is
        loaded, from the detector """
        logger.trace("Getting faces for: '%s'", filename)
        if not self.extractor:
            detected_faces = self.alignments_faces(os.path.basename(filename), image)
        else:
            detected_faces = self.detect_faces(extract_queue, filename, image)
        logger.trace("Got %s faces for: '%s'", len(detected_faces), filename)
        return detected_faces

    def alignments_faces(self, frame, image):
        """ Get the faces for ``frame`` from the alignments file as :class:`DetectedFace`
        objects (an empty list if the frame has no alignments) """
        if not self.check_alignments(frame):
            return list()

        faces = self.alignments.get_faces_in_frame(frame)
        detected_faces = list()

        for rawface in faces:
            face = DetectedFace()
            face.from_alignment(rawface, image=image)
            detected_faces.append(face)
        return detected_faces

    def check_alignments(self, frame):
        """ Return whether ``frame`` exists in the alignments, reporting it if it does not """
        have_alignments = self.alignments.frame_exists(frame)
        if not have_alignments:
            tqdm.write("No alignment found for {}, "
                       "skipping".format(frame))
        return have_alignments

    def detect_faces(self, load_queue, filename, image):
        """ Extract the faces from a frame (if no alignments file was found).

        ``load_queue`` is the extractor's input queue: the frame is put to it and the next
        result from the extractor provides the bounding boxes and landmarks. """
        inp = {"filename": filename,
               "image": image}
        load_queue.put(inp)
        faces = next(self.extractor.detect_faces())

        landmarks = faces["landmarks"]
        detected_faces = faces["detected_faces"]
        final_faces = list()

        for idx, face in enumerate(detected_faces):
            detected_face = DetectedFace()
            detected_face.from_dlib_rect(face)
            detected_face.landmarksXY = landmarks[idx]
            final_faces.append(detected_face)
        return final_faces

    # Saving tasks
    def save(self, completion_event):
        """ Save the converted images.

        Writes up to :attr:`total_count` items from the save queue (stopping early on "EOF" or
        a shutdown signal), closes the writer and then sets ``completion_event``. """
        logger.debug("Save Images: Start")
        for _ in tqdm(range(self.total_count), desc="Converting", file=sys.stdout):
            if self.save_queue.shutdown.is_set():
                logger.debug("Save Queue: Stop signal received. Terminating")
                break
            item = self.save_queue.get()
            if item == "EOF":
                break
            filename, image = item
            self.writer.write(filename, image)
        self.writer.close()
        completion_event.set()
        logger.debug("Save Images: Complete")
|
|
|
|
|
|
class Predict():
    """ Predict faces from incoming queue.

    Runs in a background thread: frames are read from ``in_queue``, their detected faces are
    batched through the model and each frame is put to the ``patch`` queue with its swapped
    faces attached.

    Parameters
    ----------
    in_queue: queue
        The queue that loaded frames (dicts with ``filename``, ``image`` and
        ``detected_faces`` keys) are read from, terminated by "EOF"
    queue_size: int
        The convert queue size. The prediction batch size is this value capped at 16
    arguments: object
        The convert arguments
    """
    def __init__(self, in_queue, queue_size, arguments):
        logger.debug("Initializing %s: (args: %s, queue_size: %s, in_queue: %s)",
                     self.__class__.__name__, arguments, queue_size, in_queue)
        self.batchsize = min(queue_size, 16)
        self.args = arguments
        self.in_queue = in_queue
        self.out_queue = queue_manager.get_queue("patch")
        self.serializer = Serializer.get_serializer("json")
        self.faces_count = 0
        self.verify_output = False
        self.pre_process = PostProcess(arguments)
        self.model = self.load_model()
        self.predictor = self.model.converter(self.args.swap_model)
        self.queues = dict()

        self.thread = MultiThread(self.predict_faces, thread_count=1)
        self.thread.start()
        logger.debug("Initialized %s: (out_queue: %s)", self.__class__.__name__, self.out_queue)

    @property
    def coverage_ratio(self):
        """ Return coverage ratio from training options """
        return self.model.training_opts["coverage_ratio"]

    @property
    def input_size(self):
        """ int: The model input size (first dimension of the model's input shape) """
        return self.model.input_shape[0]

    @property
    def output_size(self):
        """ int: The model output size (first dimension of the model's output shape) """
        return self.model.output_shape[0]

    @property
    def input_mask(self):
        """ numpy.ndarray: An all-zero float32 mask of the model's first mask shape, with a
        leading batch dimension of 1 """
        mask = np.zeros(self.model.state.mask_shapes[0], dtype="float32")
        retval = np.expand_dims(mask, 0)
        return retval

    @property
    def has_predicted_mask(self):
        """ bool: ``True`` if the model's state lists any mask shapes """
        return bool(self.model.state.mask_shapes)

    def load_model(self):
        """ Load the model requested for conversion.

        Exits the process if the model folder does not exist. """
        logger.debug("Loading Model")
        model_dir = get_folder(self.args.model_dir, make_folder=False)
        if not model_dir:
            logger.error("%s does not exist.", self.args.model_dir)
            sys.exit(1)
        trainer = self.get_trainer(model_dir)
        model = PluginLoader.get_model(trainer)(model_dir, self.args.gpus, predict=True)
        logger.debug("Loaded Model")
        return model

    def get_trainer(self, model_dir):
        """ Return the trainer name if provided, or read it from the model's state file.

        Exits the process if there is not exactly one ``*_state.json`` file in ``model_dir``
        or the state file does not contain a trainer name. """
        if self.args.trainer:
            logger.debug("Trainer name provided: '%s'", self.args.trainer)
            return self.args.trainer

        statefile = [fname for fname in os.listdir(str(model_dir))
                     if fname.endswith("_state.json")]
        if len(statefile) != 1:
            # The placeholder previously had no argument, so "%s" was logged literally
            logger.error("There should be 1 state file in your model folder. %s were found. "
                         "Specify a trainer with the '-t', '--trainer' option.", len(statefile))
            sys.exit(1)
        statefile = os.path.join(str(model_dir), statefile[0])

        with open(statefile, "rb") as inp:
            state = self.serializer.unmarshal(inp.read().decode("utf-8"))
            trainer = state.get("name", None)

        if not trainer:
            logger.error("Trainer name could not be read from state file. "
                         "Specify a trainer with the '-t', '--trainer' option.")
            sys.exit(1)
        logger.debug("Trainer from state file: '%s'", trainer)
        return trainer

    def predict_faces(self):
        """ Read frames from the in queue, batch their faces through the model and put each
        frame (with its swapped faces) to the out queue.

        A batch is flushed once it holds at least :attr:`batchsize` faces, or when "EOF" is
        received, after which "EOF" is passed on to the out queue. """
        faces_seen = 0
        batch = list()
        while True:
            item = self.in_queue.get()
            if item != "EOF":
                logger.trace("Got from queue: '%s'", item["filename"])
                faces_count = len(item["detected_faces"])
                if faces_count != 0:
                    self.pre_process.do_actions(item)
                    self.faces_count += faces_count
                if faces_count > 1:
                    self.verify_output = True
                    logger.verbose("Found more than one face in an image! '%s'",
                                   os.path.basename(item["filename"]))

                self.load_aligned(item)

                faces_seen += faces_count
                batch.append(item)

            if faces_seen < self.batchsize and item != "EOF":
                logger.trace("Continuing. Current batchsize: %s", faces_seen)
                continue

            if batch:
                predicted = self._predict_batch(batch)
                self.queue_out_frames(batch, predicted)

            faces_seen = 0
            batch = list()
            if item == "EOF":
                logger.debug("Load queue complete")
                break
        self.out_queue.put("EOF")

    def _predict_batch(self, batch):
        """ Return the swapped faces for every detected face in ``batch``, or an empty list when
        the batch holds no faces (e.g. the final frames before "EOF" contain no faces) """
        detected_batch = [detected_face for item in batch
                          for detected_face in item["detected_faces"]]
        if not detected_batch:
            # np.stack raises on an empty sequence, so do not call the model at all
            return list()
        feed_faces = self.compile_feed_faces(detected_batch)
        return self.predict(feed_faces)

    def load_aligned(self, item):
        """ Load the feed faces and reference output faces for each detected face in ``item``.
        The reference face is the feed face itself when input and output sizes match """
        logger.trace("Loading aligned faces: '%s'", item["filename"])
        for detected_face in item["detected_faces"]:
            detected_face.load_feed_face(item["image"],
                                         size=self.input_size,
                                         coverage_ratio=self.coverage_ratio,
                                         dtype="float32")
            if self.input_size == self.output_size:
                detected_face.reference = detected_face.feed
            else:
                detected_face.load_reference_face(item["image"],
                                                  size=self.output_size,
                                                  coverage_ratio=self.coverage_ratio,
                                                  dtype="float32")
        logger.trace("Loaded aligned faces: '%s'", item["filename"])

    @staticmethod
    def compile_feed_faces(detected_faces):
        """ Stack the feed faces of ``detected_faces`` into a single array for the predictor.
        ``detected_faces`` must not be empty """
        logger.trace("Compiling feed face. Batchsize: %s", len(detected_faces))
        feed_faces = np.stack([detected_face.feed_face for detected_face in detected_faces])
        logger.trace("Compiled Feed faces. Shape: %s", feed_faces.shape)
        return feed_faces

    def predict(self, feed_faces):
        """ Perform inference on the feed.

        When the model has a predicted mask, an all-zero input mask is fed alongside the faces.
        If the model returns two outputs they are concatenated on the last axis, otherwise the
        single output is used. The result is cast to float32. """
        logger.trace("Predicting: Batchsize: %s", len(feed_faces))
        feed = [feed_faces]
        if self.has_predicted_mask:
            feed.append(np.repeat(self.input_mask, feed_faces.shape[0], axis=0))
        logger.trace("Input shape(s): %s", [item.shape for item in feed])

        predicted = self.predictor(feed)
        predicted = predicted if isinstance(predicted, list) else [predicted]
        logger.trace("Output shape(s): %s", [predict.shape for predict in predicted])

        # Compile masks into alpha channel or keep raw faces
        predicted = np.concatenate(predicted, axis=-1) if len(predicted) == 2 else predicted[0]
        predicted = predicted.astype("float32")

        logger.trace("Final shape: %s", predicted.shape)
        return predicted

    def queue_out_frames(self, batch, swapped_faces):
        """ Compile the batch back to original frames and put them to the out queue.

        ``swapped_faces`` holds one entry per detected face, in the same order as the faces of
        the frames in ``batch``; each frame receives its slice as ``swapped_faces`` (an empty
        array for frames with no faces). """
        logger.trace("Queueing out batch. Batchsize: %s", len(batch))
        pointer = 0
        for item in batch:
            num_faces = len(item["detected_faces"])
            if num_faces == 0:
                item["swapped_faces"] = np.array(list())
            else:
                item["swapped_faces"] = swapped_faces[pointer:pointer + num_faces]

            logger.trace("Putting to queue. ('%s', detected_faces: %s, swapped_faces: %s)",
                         item["filename"], len(item["detected_faces"]),
                         item["swapped_faces"].shape[0])
            self.out_queue.put(item)
            pointer += num_faces
        logger.trace("Queued out batch. Batchsize: %s", len(batch))
|
|
|
|
|
|
class OptionalActions():
    """ Process the optional actions for convert.

    Currently this filters the loaded alignments down to the faces present in the aligned
    faces folder (if one was given).

    Parameters
    ----------
    args: object
        The convert arguments (``input_aligned_dir`` is read from it)
    input_images: list
        The input images, used only to sanity check the number of aligned faces
    alignments: :class:`scripts.fsmedia.Alignments`
        The alignments to filter in place
    """

    def __init__(self, args, input_images, alignments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = args
        self.input_images = input_images
        self.alignments = alignments

        self.remove_skipped_faces()
        logger.debug("Initialized %s", self.__class__.__name__)

    # SKIP FACES #
    def remove_skipped_faces(self):
        """ Remove deleted faces from the loaded alignments, keeping only faces whose hash is
        found in the aligned faces folder. Does nothing if no face hashes were collected """
        logger.debug("Filtering Faces")
        face_hashes = self.get_face_hashes()
        if not face_hashes:
            logger.debug("No face hashes. Not skipping any faces")
            return
        pre_face_count = self.alignments.faces_count
        self.alignments.filter_hashes(face_hashes, filter_out=False)
        logger.info("Faces filtered out: %s", pre_face_count - self.alignments.faces_count)

    def get_face_hashes(self):
        """ Check for the existence of an aligned directory for identifying which faces in the
        target frames should be swapped. If it exists, obtain the hashes of the faces in it.

        Returns an empty list when no aligned directory was given or it does not exist. Exits
        the process if the aligned directory exists but contains no images. """
        face_hashes = list()
        input_aligned_dir = self.args.input_aligned_dir

        if input_aligned_dir is None:
            logger.verbose("Aligned directory not specified. All faces listed in the "
                           "alignments file will be converted")
        elif not os.path.isdir(input_aligned_dir):
            logger.warning("Aligned directory not found. All faces listed in the "
                           "alignments file will be converted")
        else:
            file_list = list(get_image_paths(input_aligned_dir))
            logger.info("Getting Face Hashes for selected Aligned Images")
            face_hashes = [hash_image_file(face)
                           for face in tqdm(file_list, desc="Hashing Faces")]
            logger.debug("Face Hashes: %s", len(face_hashes))
            if not face_hashes:
                logger.error("Aligned directory is empty, no faces will be converted!")
                # sys.exit rather than the site-module ``exit`` builtin, which is meant for
                # interactive use only
                sys.exit(1)
            elif len(face_hashes) <= len(self.input_images) / 3:
                logger.warning("Aligned directory contains far fewer images than the input "
                               "directory, are you sure this is the right folder?")
        return face_hashes
|
|
|
|
|
|
class Legacy():
    """ Bring a legacy alignments file up to date.

    Two upgrades may be applied, each followed by saving the alignments file:
        - For frames reported by ``get_legacy_rotation``, rotate the stored landmarks and
          bounding boxes (removing the old 'r' parameter)
        - For frames reported by ``get_legacy_no_hashes``, add face hashes computed from the
          extracted faces in ``faces_dir``

    Parameters
    ----------
    alignments: :class:`scripts.fsmedia.Alignments`
        The alignments to upgrade
    frames: iterable
        Paths to the source frames
    faces_dir: str or None
        Folder of extracted faces named ``<frame>_<index>.<ext>``. Hashes are only added when
        this is provided
    """
    def __init__(self, alignments, frames, faces_dir):
        self.alignments = alignments
        # Look up each frame's full path from its base filename
        self.frames = dict((os.path.basename(frame_path), frame_path) for frame_path in frames)
        self.process(faces_dir)

    def process(self, faces_dir):
        """ Run whichever legacy upgrades the alignments file needs """
        rotated = self.alignments.get_legacy_rotation()
        hashes = self.alignments.get_legacy_no_hashes()
        if not (rotated or hashes):
            return
        if rotated:
            logger.info("Legacy rotated frames found. Converting...")
            self.rotate_landmarks(rotated)
            self.alignments.save()
        if hashes and faces_dir:
            logger.info("Legacy alignments found. Adding Face Hashes...")
            self.add_hashes(hashes, faces_dir)
            self.alignments.save()

    def rotate_landmarks(self, rotated):
        """ Rotate the landmarks of each frame in ``rotated`` that is present in the input """
        for frame_name in tqdm(rotated, desc="Rotating Landmarks"):
            frame_path = self.frames.get(frame_name, None)
            if frame_path is None:
                logger.debug("Skipping missing frame: '%s'", frame_name)
            else:
                self.alignments.rotate_existing_landmarks(frame_name, frame_path)

    def add_hashes(self, hashes, faces_dir):
        """ Add Face Hashes to the alignments file for each frame in ``hashes``.

        Face files are matched to their source frame by stripping the trailing ``_<index>``
        from the face's name; files without a numeric index are ignored. """
        faces_by_frame = dict()
        candidates = sorted(name for name in os.listdir(faces_dir) if "_" in name)
        for face_name in candidates:
            stem, ext = os.path.splitext(face_name)
            split_at = stem.rfind("_")
            face_index = stem[split_at + 1:]
            if face_index.isdigit():
                source_frame = stem[:split_at] + ext
                frame_faces = faces_by_frame.setdefault(source_frame, dict())
                frame_faces[int(face_index)] = os.path.join(faces_dir, face_name)

        for frame in tqdm(hashes):
            if frame not in faces_by_frame:
                logger.warning("Skipping missing frame: '%s'", frame)
                continue
            face_hashes = {index: hash_image_file(face_path)
                           for index, face_path in faces_by_frame[frame].items()}
            self.alignments.add_face_hashes(frame, face_hashes)
|