1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 20:13:52 -04:00
faceswap/plugins/train/trainer/_base.py
torzdf 66ed005ef3
Optimize Data Augmentation (#881)
* Move image utils to lib.image
* Add .pylintrc file
* Remove some cv2 pylint ignores
* TrainingData: Load images from disk in batches
* TrainingData: get_landmarks to batch
* TrainingData: transform and flip to batches
* TrainingData: Optimize color augmentation
* TrainingData: Optimize target and random_warp
* TrainingData - Convert _get_closest_match for batching
* TrainingData: Warp To Landmarks optimized
* Save models to threadpoolexecutor
* Move stack_images, Rename ImageManipulation. ImageAugmentation Docstrings
* Masks: Set dtype and threshold for lib.masks based on input face
* Docstrings and Documentation
2019-09-24 12:16:05 +01:00

720 lines
31 KiB
Python

#!/usr/bin/env python3
""" Base Trainer Class for Faceswap
Trainers should be inherited from this class.
A training_opts dictionary can be set in the corresponding model.
Accepted values:
alignments: dict containing paths to alignments files for keys 'a' and 'b'
preview_scaling: How much to scale the preview out by
training_size: Size of the training images
coverage_ratio: Ratio of face to be cropped out for training
mask_type: Type of mask to use. See lib.model.masks for valid mask names.
Set to None for not used
no_logs: Disable tensorboard logging
snapshot_interval: Interval for saving model snapshots
warp_to_landmarks: Use random_warp_landmarks instead of random_warp
augment_color: Perform random shifting of L*a*b* colors
no_flip: Don't perform a random flip on the image
    pingpong: Train each side separately per save iteration rather than together
"""
import logging
import os
import time
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python import errors_impl as tf_errors # pylint:disable=no-name-in-module
from lib.alignments import Alignments
from lib.faces_detect import DetectedFace
from lib.training_data import TrainingDataGenerator
from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.train._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_config(plugin_name, configfile=None):
    """ Return the config for the requested model """
    config = Config(plugin_name, configfile=configfile)
    return config.config_dict
class TrainerBase():
    """ Base class for all trainer plugins.

    Runs the training loop for a model: feeds batches for each side, tracks and
    prints loss, logs to TensorBoard, and generates previews, time-lapses and
    snapshots at the requested intervals.

    Parameters
    ----------
    model: plugin
        The model plugin that is to be trained
    images: dict
        Lists of training image paths for keys 'a' and 'b'
    batch_size: int
        The number of images to feed through the model in a single iteration
    configfile: str or ``None``
        Optional path to a custom configuration ini file
    """
    def __init__(self, model, images, batch_size, configfile):
        logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
                     self.__class__.__name__, model, batch_size)
        self.config = get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
        self.batch_size = batch_size
        self.model = model
        self.model.state.add_session_batchsize(batch_size)
        self.images = images
        # Iterating a dict yields its keys, so no explicit .keys() is needed
        self.sides = sorted(self.images)
        self.process_training_opts()
        self.pingpong = PingPong(model, self.sides)
        self.batchers = {side: Batcher(side,
                                       images[side],
                                       self.model,
                                       self.use_mask,
                                       batch_size,
                                       self.config)
                         for side in self.sides}
        self.tensorboard = self.set_tensorboard()
        self.samples = Samples(self.model,
                               self.use_mask,
                               self.model.training_opts["coverage_ratio"],
                               self.model.training_opts["preview_scaling"])
        self.timelapse = Timelapse(self.model,
                                   self.use_mask,
                                   self.model.training_opts["coverage_ratio"],
                                   self.config.get("preview_images", 14),
                                   self.batchers)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def timestamp(self):
        """ str: Standardised timestamp (``HH:MM:SS``) for loss reporting """
        return time.strftime("%H:%M:%S")

    @property
    def landmarks_required(self):
        """ bool: ``True`` if landmarks are required (masking or warp-to-landmarks) """
        opts = self.model.training_opts
        retval = bool(opts.get("mask_type", None) or opts["warp_to_landmarks"])
        logger.debug(retval)
        return retval

    @property
    def use_mask(self):
        """ bool: ``True`` if a mask type has been requested """
        retval = bool(self.model.training_opts.get("mask_type", None))
        logger.debug(retval)
        return retval

    def process_training_opts(self):
        """ Override for processing model specific training options.

        Base implementation loads landmarks into the training options when required """
        logger.debug(self.model.training_opts)
        if self.landmarks_required:
            landmarks = Landmarks(self.model.training_opts).landmarks
            self.model.training_opts["landmarks"] = landmarks

    def set_tensorboard(self):
        """ Set up the TensorBoard callback for each side.

        Returns
        -------
        dict or ``None``
            TensorBoard callback per side, or ``None`` when logging is disabled
            or ping-pong training is active
        """
        if self.model.training_opts["no_logs"]:
            logger.verbose("TensorBoard logging disabled")
            return None
        if self.pingpong.active:
            # Currently TensorBoard uses the tf.session, meaning that VRAM does not
            # get cleared when model switching
            # TODO find a fix for this
            logger.warning("Currently TensorBoard logging is not supported for Ping-Pong "
                           "training. Session stats and graphing will not be available for this "
                           "training session.")
            return None
        logger.debug("Enabling TensorBoard Logging")
        tensorboard = dict()
        for side in self.sides:
            logger.debug("Setting up TensorBoard Logging. Side: %s", side)
            log_dir = os.path.join(str(self.model.model_dir),
                                   "{}_logs".format(self.model.name),
                                   side,
                                   "session_{}".format(self.model.state.session_id))
            tbs = tf.keras.callbacks.TensorBoard(log_dir=log_dir, **self.tensorboard_kwargs)
            tbs.set_model(self.model.predictors[side])
            tensorboard[side] = tbs
        logger.info("Enabled TensorBoard Logging")
        return tensorboard

    @property
    def tensorboard_kwargs(self):
        """ dict: Version dependent keyword arguments for the TensorBoard callback.

        TF 1.13+ needs additional kwargs which are not valid for earlier versions """
        kwargs = dict(histogram_freq=0,  # Must be 0 or hangs
                      batch_size=64,
                      write_graph=True,
                      write_grads=True)
        tf_version = [int(ver) for ver in tf.__version__.split(".") if ver.isdigit()]
        logger.debug("Tensorflow version: %s", tf_version)
        if tf_version[0] > 1 or (tf_version[0] == 1 and tf_version[1] > 12):
            kwargs["update_freq"] = "batch"
        if tf_version[0] > 1 or (tf_version[0] == 1 and tf_version[1] > 13):
            kwargs["profile_batch"] = 0
        logger.debug(kwargs)
        return kwargs

    def print_loss(self, loss):
        """ Output the current loss to the console. Override for specific model
        loss formatting.

        Parameters
        ----------
        loss: dict
            Loss list per side, keyed on 'a' and/or 'b'
        """
        logger.trace(loss)
        output = ["Loss {}: {:.5f}".format(side.capitalize(), loss[side][0])
                  for side in sorted(loss.keys())]
        output = ", ".join(output)
        print("[{}] [#{:05d}] {}".format(self.timestamp, self.model.iterations, output), end='\r')

    def train_one_step(self, viewer, timelapse_kwargs):
        """ Train one batch per (active) side, then handle loss reporting, preview,
        time-lapse and snapshot generation.

        Parameters
        ----------
        viewer: callable or ``None``
            Callback for displaying a preview image, or ``None`` for no preview
        timelapse_kwargs: dict or ``None``
            Keyword arguments for time-lapse generation, or ``None`` for no time-lapse
        """
        logger.trace("Training one step: (iteration: %s)", self.model.iterations)
        do_preview = viewer is not None
        do_timelapse = timelapse_kwargs is not None
        snapshot_interval = self.model.training_opts.get("snapshot_interval", 0)
        do_snapshot = (snapshot_interval != 0 and
                       self.model.iterations >= snapshot_interval and
                       self.model.iterations % snapshot_interval == 0)
        loss = dict()
        # NB: the original code wrapped this loop in a `try/except Exception: raise err`
        # which added nothing; exceptions now propagate naturally
        for side, batcher in self.batchers.items():
            if self.pingpong.active and side != self.pingpong.side:
                continue  # Only train the currently active side when ping-ponging
            loss[side] = batcher.train_one_batch(do_preview)
            if not do_preview and not do_timelapse:
                continue
            if do_preview:
                self.samples.images[side] = batcher.compile_sample(None)
            if do_timelapse:
                self.timelapse.get_sample(side, timelapse_kwargs)
        self.model.state.increment_iterations()
        for side, side_loss in loss.items():
            self.store_history(side, side_loss)
            self.log_tensorboard(side, side_loss)
        if not self.pingpong.active:
            self.print_loss(loss)
        else:
            # Merge this side's loss into the stored ping-pong loss so both sides display
            for key, val in loss.items():
                self.pingpong.loss[key] = val
            self.print_loss(self.pingpong.loss)
        if do_preview:
            samples = self.samples.show_sample()
            if samples is not None:
                viewer(samples, "Training - 'S': Save Now. 'ENTER': Save and Quit")
        if do_timelapse:
            self.timelapse.output_timelapse()
        if do_snapshot:
            self.model.do_snapshot()

    def store_history(self, side, loss):
        """ Store the loss history of this step for the given side """
        logger.trace("Updating loss history: '%s'", side)
        self.model.history[side].append(loss[0])  # Either only loss or total loss
        logger.trace("Updated loss history: '%s'", side)

    def log_tensorboard(self, side, loss):
        """ Log loss to the TensorBoard log for the given side """
        if not self.tensorboard:
            return
        logger.trace("Updating TensorBoard log: '%s'", side)
        logs = dict(zip(self.model.state.loss_names[side], loss))
        self.tensorboard[side].on_batch_end(self.model.state.iterations, logs)
        logger.trace("Updated TensorBoard log: '%s'", side)

    def clear_tensorboard(self):
        """ Indicate training end to Tensorboard """
        if not self.tensorboard:
            return
        for side, tensorboard in self.tensorboard.items():
            logger.debug("Ending Tensorboard. Side: '%s'", side)
            tensorboard.on_train_end(None)
class Batcher():
    """ Handle the feeding of training batches to the model for a single side """
    def __init__(self, side, images, model, use_mask, batch_size, config):
        logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s, config: %s)",
                     self.__class__.__name__, side, len(images), batch_size, config)
        self.model = model
        self.use_mask = use_mask
        self.side = side
        self.images = images
        self.config = config
        self.target = None
        self.samples = None
        self.mask = None
        self.feed = self.load_generator().minibatch_ab(images, batch_size, self.side)
        self.preview_feed = None
        self.timelapse_feed = None

    def load_generator(self):
        """ Pass arguments to TrainingDataGenerator and return object """
        logger.debug("Loading generator: %s", self.side)
        input_size = self.model.input_shape[0]
        output_shapes = self.model.output_shapes
        logger.debug("input_size: %s, output_shapes: %s", input_size, output_shapes)
        return TrainingDataGenerator(input_size,
                                     output_shapes,
                                     self.model.training_opts,
                                     self.config)

    def train_one_batch(self, do_preview):
        """ Run a single batch through the predictor and return the loss list """
        logger.trace("Training one step: (side: %s)", self.side)
        model_inputs = self.get_next(do_preview)
        try:
            loss = self.model.predictors[self.side].train_on_batch(*model_inputs)
        except tf_errors.ResourceExhaustedError as err:
            msg = ("You do not have enough GPU memory available to train the selected model at "
                   "the selected settings. You can try a number of things:"
                   "\n1) Close any other application that is using your GPU (web browsers are "
                   "particularly bad for this)."
                   "\n2) Lower the batchsize (the amount of images fed into the model each "
                   "iteration)."
                   "\n3) Try 'Memory Saving Gradients' and/or 'Optimizer Savings' and/or 'Ping "
                   "Pong Training'."
                   "\n4) Use a more lightweight model, or select the model's 'LowMem' option "
                   "(in config) if it has one.")
            raise FaceswapError(msg) from err
        if not isinstance(loss, list):
            loss = [loss]
        return loss

    def get_next(self, do_preview):
        """ Return the next batch from the generator.
        Items come out as: (warped, target [, mask]) """
        batch = next(self.feed)
        if self.use_mask:
            retval = [[batch["feed"], batch["masks"]], batch["targets"] + [batch["masks"]]]
        else:
            retval = [batch["feed"], batch["targets"]]
        self.generate_preview(do_preview)
        return retval

    def generate_preview(self, do_preview):
        """ Pull a batch from the preview feed if this is a preview iteration """
        if not do_preview:
            self.samples = None
            self.target = None
            return
        logger.debug("Generating preview")
        if self.preview_feed is None:
            self.set_preview_feed()
        batch = next(self.preview_feed)
        self.samples = batch["samples"]
        self.target = [batch["targets"][self.model.largest_face_index]]
        if self.use_mask:
            self.target += [batch["masks"]]

    def set_preview_feed(self):
        """ Set the preview dictionary """
        logger.debug("Setting preview feed: (side: '%s')", self.side)
        # Clamp the requested preview count to between 2 and 16 images
        num_previews = min(max(self.config.get("preview_images", 14), 2), 16)
        batchsize = min(len(self.images), num_previews)
        self.preview_feed = self.load_generator().minibatch_ab(self.images,
                                                               batchsize,
                                                               self.side,
                                                               do_shuffle=True,
                                                               is_preview=True)
        logger.debug("Set preview feed. Batchsize: %s", batchsize)

    def compile_sample(self, batch_size, samples=None, images=None):
        """ Training samples to display in the viewer """
        num_images = self.config.get("preview_images", 14)
        if batch_size is not None:
            num_images = min(batch_size, num_images)
        logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images)
        if images is None:
            images = self.target
        if samples is None:
            samples = self.samples
        retval = [samples[0:num_images]]
        if self.use_mask:
            retval += [tgt[0:num_images] for tgt in images]
        else:
            retval += [images[0:num_images]]
        return retval

    def compile_timelapse_sample(self):
        """ Timelapse samples """
        batch = next(self.timelapse_feed)
        batchsize = len(batch["samples"])
        images = [batch["targets"][self.model.largest_face_index]]
        if self.use_mask:
            images = images + [batch["masks"]]
        return self.compile_sample(batchsize, samples=batch["samples"], images=images)

    def set_timelapse_feed(self, images, batchsize):
        """ Set the timelapse dictionary """
        logger.debug("Setting timelapse feed: (side: '%s', input_images: '%s', batchsize: %s)",
                     self.side, images, batchsize)
        self.timelapse_feed = self.load_generator().minibatch_ab(images[:batchsize],
                                                                 batchsize, self.side,
                                                                 do_shuffle=False,
                                                                 is_timelapse=True)
        logger.debug("Set timelapse feed")
class Samples():
    """ Compile sample preview images for display in the GUI and for time-lapses.

    Parameters
    ----------
    model: plugin
        The model plugin that is being trained
    use_mask: bool
        ``True`` if a mask is being applied to the samples
    coverage_ratio: float
        Ratio of the full frame that the training face is cropped from
    scaling: float, optional
        Amount to scale the final preview image by. Default: ``1.0``
    """
    def __init__(self, model, use_mask, coverage_ratio, scaling=1.0):
        logger.debug("Initializing %s: model: '%s', use_mask: %s, coverage_ratio: %s)",
                     self.__class__.__name__, model, use_mask, coverage_ratio)
        self.model = model
        self.use_mask = use_mask
        self.images = dict()  # Populated externally with sample arrays per side
        self.coverage_ratio = coverage_ratio
        self.scaling = scaling
        logger.debug("Initialized %s", self.__class__.__name__)

    def show_sample(self):
        """ Compile the preview image.

        Returns
        -------
        :class:`numpy.ndarray` or ``None``
            The compiled preview as a uint8 image, or ``None`` if samples are not
            yet available for both sides (e.g. during ping-pong training)
        """
        if len(self.images) != 2:
            logger.debug("Ping Pong training - Only one side trained. Aborting preview")
            return None
        logger.debug("Showing sample")
        feeds = dict()
        figures = dict()
        headers = dict()
        for side, samples in self.images.items():
            faces = samples[1]
            # Resize faces to the model's input size when they differ
            # (integer compare rather than float-division compare)
            if self.model.input_shape[0] != faces.shape[1]:
                feeds[side] = self.resize_sample(side, faces, self.model.input_shape[0])
                feeds[side] = feeds[side].reshape((-1, ) + self.model.input_shape)
            else:
                feeds[side] = faces
            if self.use_mask:
                mask = samples[-1]
                feeds[side] = [feeds[side], mask]
        preds = self.get_predictions(feeds["a"], feeds["b"])
        for side, samples in self.images.items():
            other_side = "a" if side == "b" else "b"
            # Reconstruction (a->a / b->b) and swap (b->a / a->b) predictions
            predictions = [preds["{0}_{0}".format(side)],
                           preds["{}_{}".format(other_side, side)]]
            display = self.to_full_frame(side, samples, predictions)
            headers[side] = self.get_headers(side, display[0].shape[1])
            figures[side] = np.stack([display[0], display[1], display[2], ], axis=1)
            if self.images[side][0].shape[0] % 2 == 1:
                # Pad an odd-sized batch with a duplicate of the first row
                figures[side] = np.concatenate([figures[side],
                                                np.expand_dims(figures[side][0], 0)])
        width = 4
        side_cols = width // 2
        if side_cols != 1:
            headers = self.duplicate_headers(headers, side_cols)
        header = np.concatenate([headers["a"], headers["b"]], axis=1)
        figure = np.concatenate([figures["a"], figures["b"]], axis=0)
        height = int(figure.shape[0] / width)
        figure = figure.reshape((width, height) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.vstack((header, figure))
        logger.debug("Compiled sample")
        return np.clip(figure * 255, 0, 255).astype('uint8')

    @staticmethod
    def resize_sample(side, sample, target_size):
        """ Resize samples where predictor expects different shape from processed image.

        Parameters
        ----------
        side: str
            The side ('a' or 'b') being resized (for logging only)
        sample: :class:`numpy.ndarray`
            The batch of images to be resized
        target_size: int
            The size to resize the samples to

        Returns
        -------
        :class:`numpy.ndarray`
            The resized batch (returned unchanged if already at `target_size`)
        """
        scale = target_size / sample.shape[1]
        if scale == 1.0:
            return sample
        logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)",
                     side, sample.shape, target_size, scale)
        interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA  # pylint: disable=no-member
        # interpolation must be passed by keyword: the 3rd positional argument of
        # cv2.resize is `dst`, so passing the flag positionally silently discarded it
        retval = np.array([cv2.resize(img,  # pylint: disable=no-member
                                      (target_size, target_size),
                                      interpolation=interpn)
                           for img in sample])
        logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
        return retval

    def get_predictions(self, feed_a, feed_b):
        """ Return the sample predictions from the model for every swap combination """
        logger.debug("Getting Predictions")
        preds = dict()
        preds["a_a"] = self.model.predictors["a"].predict(feed_a)
        preds["b_a"] = self.model.predictors["b"].predict(feed_a)
        preds["a_b"] = self.model.predictors["a"].predict(feed_b)
        preds["b_b"] = self.model.predictors["b"].predict(feed_b)
        # Get the returned largest image from predictors that emit multiple items
        if not isinstance(preds["a_a"], np.ndarray):
            for key, val in preds.items():
                preds[key] = val[self.model.largest_face_index]
        logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
        return preds

    def to_full_frame(self, side, samples, predictions):
        """ Patch the faces and their predictions into the full frame backgrounds """
        logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)",
                     side, len(samples), [pred.shape for pred in predictions])
        full, faces = samples[:2]
        images = [faces] + predictions
        full_size = full.shape[1]
        target_size = int(full_size * self.coverage_ratio)
        if target_size != full_size:
            frame = self.frame_overlay(full, target_size, (0, 0, 255))
        if self.use_mask:
            images = self.compile_masked(images, samples[-1])
        images = [self.resize_sample(side, image, target_size) for image in images]
        if target_size != full_size:
            images = [self.overlay_foreground(frame, image) for image in images]
        if self.scaling != 1.0:
            new_size = int(full_size * self.scaling)
            images = [self.resize_sample(side, image, new_size) for image in images]
        return images

    @staticmethod
    def frame_overlay(images, target_size, color):
        """ Add an ROI frame (four corner markers) to a background image batch """
        logger.debug("full_size: %s, target_size: %s, color: %s",
                     images.shape[1], target_size, color)
        new_images = list()
        full_size = images.shape[1]
        padding = (full_size - target_size) // 2
        length = target_size // 4
        t_l, b_r = (padding, full_size - padding)
        for img in images:
            # One marker square at each corner of the coverage region
            cv2.rectangle(img,  # pylint: disable=no-member
                          (t_l, t_l),
                          (t_l + length, t_l + length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (b_r, t_l),
                          (b_r - length, t_l + length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (b_r, b_r),
                          (b_r - length,
                           b_r - length),
                          color,
                          3)
            cv2.rectangle(img,  # pylint: disable=no-member
                          (t_l, b_r),
                          (t_l + length, b_r - length),
                          color,
                          3)
            new_images.append(img)
        retval = np.array(new_images)
        logger.debug("Overlayed background. Shape: %s", retval.shape)
        return retval

    @staticmethod
    def compile_masked(faces, masks):
        """ Add the mask to the faces for masked preview """
        retval = list()
        # Invert and binarize the masks, tiled to 3 channels
        masks3 = np.tile(1 - np.rint(masks), 3)
        for mask in masks3:
            # Color the masked-out area red (BGR ordering)
            mask[np.where((mask == [1., 1., 1.]).all(axis=2))] = [0., 0., 1.]
        for previews in faces:
            images = np.array([cv2.addWeighted(img, 1.0,  # pylint: disable=no-member
                                               masks3[idx], 0.3,
                                               0)
                               for idx, img in enumerate(previews)])
            retval.append(images)
        logger.debug("masked shapes: %s", [faces.shape for faces in retval])
        return retval

    @staticmethod
    def overlay_foreground(backgrounds, foregrounds):
        """ Overlay the training images into the center of the background """
        offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2
        new_images = list()
        for idx, img in enumerate(backgrounds):
            img[offset:offset + foregrounds[idx].shape[0],
                offset:offset + foregrounds[idx].shape[1]] = foregrounds[idx]
            new_images.append(img)
        retval = np.array(new_images)
        logger.debug("Overlayed foreground. Shape: %s", retval.shape)
        return retval

    def get_headers(self, side, width):
        """ Build the column header box ('Original'/'Swap' titles) for one side """
        logger.debug("side: '%s', width: %s",
                     side, width)
        titles = ("Original", "Swap") if side == "a" else ("Swap", "Original")
        side = side.upper()
        height = int(64 * self.scaling)
        total_width = width * 3
        logger.debug("height: %s, total_width: %s", height, total_width)
        font = cv2.FONT_HERSHEY_SIMPLEX  # pylint: disable=no-member
        texts = ["{} ({})".format(titles[0], side),
                 "{0} > {0}".format(titles[0]),
                 "{} > {}".format(titles[0], titles[1])]
        text_sizes = [cv2.getTextSize(texts[idx],  # pylint: disable=no-member
                                      font,
                                      self.scaling * 0.8,
                                      1)[0]
                      for idx in range(len(texts))]
        # Center each title horizontally within its column
        text_y = int((height + text_sizes[0][1]) / 2)
        text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx
                  for idx in range(len(texts))]
        logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s",
                     texts, text_sizes, text_x, text_y)
        header_box = np.ones((height, total_width, 3), np.float32)
        for idx, text in enumerate(texts):
            cv2.putText(header_box,  # pylint: disable=no-member
                        text,
                        (text_x[idx], text_y),
                        font,
                        self.scaling * 0.8,
                        (0, 0, 0),
                        1,
                        lineType=cv2.LINE_AA)  # pylint: disable=no-member
        logger.debug("header_box.shape: %s", header_box.shape)
        return header_box

    @staticmethod
    def duplicate_headers(headers, columns):
        """ Duplicate headers for the number of columns displayed """
        for side, header in headers.items():
            duped = tuple(header for _ in range(columns))
            headers[side] = np.concatenate(duped, axis=1)
            logger.debug("side: %s header.shape: %s", side, header.shape)
        return headers
class Timelapse():
    """ Create a time-lapse preview image on each save iteration """
    def __init__(self, model, use_mask, coverage_ratio, preview_images, batchers):
        logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, "
                     "preview_images: %s, batchers: '%s')", self.__class__.__name__, model,
                     use_mask, coverage_ratio, preview_images, batchers)
        self.preview_images = preview_images
        self.samples = Samples(model, use_mask, coverage_ratio)
        self.model = model
        self.batchers = batchers
        self.output_file = None
        logger.debug("Initialized %s", self.__class__.__name__)

    def get_sample(self, side, timelapse_kwargs):
        """ Compile the time-lapse sample for the given side """
        logger.debug("Getting timelapse samples: '%s'", side)
        if not self.output_file:
            self.setup(**timelapse_kwargs)
        self.samples.images[side] = self.batchers[side].compile_timelapse_sample()
        logger.debug("Got timelapse samples: '%s' - %s", side, len(self.samples.images[side]))

    def setup(self, input_a=None, input_b=None, output=None):
        """ Set the time-lapse output folder and configure the feeds """
        logger.debug("Setting up timelapse")
        if output is None:
            default_dir = os.path.join(str(self.model.model_dir),
                                       "{}_timelapse".format(self.model.name))
            output = str(get_folder(default_dir))
        self.output_file = str(output)
        logger.debug("Timelapse output set to '%s'", self.output_file)
        images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)}
        batchsize = min(len(images["a"]), len(images["b"]), self.preview_images)
        for side, image_files in images.items():
            self.batchers[side].set_timelapse_feed(image_files, batchsize)
        logger.debug("Set up timelapse")

    def output_timelapse(self):
        """ Write the compiled time-lapse image to disk """
        logger.debug("Ouputting timelapse")
        image = self.samples.show_sample()
        if image is None:
            return
        filename = os.path.join(self.output_file, str(int(time.time())) + ".jpg")
        cv2.imwrite(filename, image)  # pylint: disable=no-member
        logger.debug("Created timelapse: '%s'", filename)
class PingPong():
    """ Alternate training between the two sides when ping-pong training is enabled """
    def __init__(self, model, sides):
        logger.debug("Initializing %s: (model: '%s')", self.__class__.__name__, model)
        self.active = model.training_opts.get("pingpong", False)
        self.model = model
        self.sides = sides
        self.side = sorted(sides)[0]  # Start training on the first side alphabetically
        self.loss = {key: [0] for key in sides}
        logger.debug("Initialized %s", self.__class__.__name__)

    def switch(self):
        """ Switch pingpong side """
        if not self.active:
            return
        # There are only two sides, so pick the one that is not currently active
        next_side = next(side for side in self.sides if side != self.side)
        logger.info("Switching training to side %s", next_side.title())
        self.side = next_side
        self.reload_model()

    def reload_model(self):
        """ Load the model for just the current side """
        logger.verbose("Ping-Pong re-loading model")
        self.model.reset_pingpong()
class Landmarks():
    """ Set Landmarks for training into the model's training options"""
    def __init__(self, training_opts):
        logger.debug("Initializing %s: (training_opts: '%s')",
                     self.__class__.__name__, training_opts)
        self.size = training_opts.get("training_size", 256)
        self.paths = training_opts["alignments"]
        self.landmarks = self.get_alignments()
        logger.debug("Initialized %s", self.__class__.__name__)

    def get_alignments(self):
        """ Obtain the landmarks for each faceset """
        landmarks = dict()
        for side, fullpath in self.paths.items():
            folder, fname = os.path.split(fullpath)
            fname, extension = os.path.splitext(fname)
            # The serializer is derived from the alignments file extension
            alignments = Alignments(folder,
                                    filename=fname,
                                    serializer=extension[1:])
            landmarks[side] = self.transform_landmarks(alignments)
        return landmarks

    def transform_landmarks(self, alignments):
        """ For each face transform landmarks and return """
        retval = dict()
        for _, faces, _, _ in alignments.yield_faces():
            for face in faces:
                detected = DetectedFace()
                detected.from_alignment(face)
                detected.load_aligned(None, size=self.size)
                retval[detected.hash] = detected.aligned_landmarks
        return retval
def stack_images(images):
    """ Tile a batch of images into a single grid image """
    logger.debug("Stack images")

    def _transpose_axes(ndims):
        # The axis interleave pattern depends on whether the rank is even or odd;
        # the channel axis (the last one) always stays put
        if ndims % 2 == 0:
            logger.debug("Even number of images to stack")
            vert = list(range(1, ndims - 1, 2))
            horz = list(range(0, ndims - 1, 2))
        else:
            logger.debug("Odd number of images to stack")
            vert = list(range(0, ndims - 1, 2))
            horz = list(range(1, ndims - 1, 2))
        return vert, horz, [ndims - 1]

    shape = np.array(images.shape)
    axes_groups = _transpose_axes(len(shape))
    grid_shape = [np.prod(shape[group]) for group in axes_groups]
    logger.debug("Stacked images")
    return np.transpose(images, axes=np.concatenate(axes_groups)).reshape(grid_shape)