mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
401 lines
18 KiB
Python
#!/usr/bin/env python3
""" Process training data for model training """

import logging

from hashlib import sha1
from random import shuffle

import cv2
import numpy as np
from scipy.interpolate import griddata

from lib.model import masks
from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.umeyama import umeyama

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class TrainingDataGenerator():
    """ Generate training data for models """
    def __init__(self, model_input_size, model_output_size, training_opts):
        logger.debug("Initializing %s: (model_input_size: %s, model_output_size: %s, "
                     "training_opts: %s, landmarks: %s)",
                     self.__class__.__name__, model_input_size, model_output_size,
                     {key: val for key, val in training_opts.items() if key != "landmarks"},
                     bool(training_opts.get("landmarks", None)))
        self.batchsize = 0
        self.model_input_size = model_input_size
        self.training_opts = training_opts
        self.mask_function = self.set_mask_function()
        self.landmarks = self.training_opts.get("landmarks", None)

        self.processing = ImageManipulation(model_input_size,
                                            model_output_size,
                                            training_opts.get("coverage_ratio", 0.625))
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_mask_function(self):
        """ Set the mask function to use if using mask """
        mask_type = self.training_opts.get("mask_type", None)
        if mask_type:
            logger.debug("Mask type: '%s'", mask_type)
            mask_func = getattr(masks, mask_type)
        else:
            mask_func = None
        logger.debug("Mask function: %s", mask_func)
        return mask_func

    def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False):
        """ Start a loader thread that keeps a queue filled to 8x batch size and
            return a generator that yields batches from that queue """
        logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, "
                     "is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse)
        self.batchsize = batchsize
        q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side)
        q_size = batchsize * 8
        # Don't use a multiprocessing queue because the MP Manager sometimes fails on numpy arrays
        queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False)
        load_thread = MultiThread(self.load_batches,
                                  images,
                                  q_name,
                                  side,
                                  is_timelapse,
                                  do_shuffle)
        load_thread.start()
        logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name)
        return self.minibatch(q_name, load_thread)

    def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True):
        """ Continuously load processed faces (sample, warped image, target image and
            optional mask) to the queue """
        logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', "
                     "is_timelapse: %s, do_shuffle: %s)",
                     len(images), q_name, side, is_timelapse, do_shuffle)
        epoch = 0
        queue = queue_manager.get_queue(q_name)
        self.validate_samples(images)
        while True:
            if do_shuffle:
                shuffle(images)
            for img in images:
                logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side)
                queue.put(self.process_face(img, side, is_timelapse))
            epoch += 1
        logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')",
                     epoch, q_name, side)

    def validate_samples(self, data):
        """ Check that the total number of images is at least the batch size """
        length = len(data)
        msg = ("Number of images is lower than batch-size (Note that too few "
               "images may lead to bad training). # images: {}, "
               "batch-size: {}".format(length, self.batchsize))
        assert length >= self.batchsize, msg

    def minibatch(self, q_name, load_thread):
        """ A generator function that yields batches from the load queue. Each batch is a
            list of float32 arrays (sample, warped, target and optional mask), each holding
            batchsize items """
        logger.debug("Launching minibatch generator for queue: '%s'", q_name)
        queue = queue_manager.get_queue(q_name)
        while True:
            if load_thread.has_error:
                logger.debug("Thread error detected")
                break
            batch = list()
            for _ in range(self.batchsize):
                images = queue.get()
                for idx, image in enumerate(images):
                    if len(batch) < idx + 1:
                        batch.append(list())
                    batch[idx].append(image)
            batch = [np.float32(image) for image in batch]
            logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s')",
                         len(batch), [item.shape for item in batch], q_name)
            yield batch
        logger.debug("Finished minibatch generator for queue: '%s'", q_name)
        load_thread.join()

    def process_face(self, filename, side, is_timelapse):
        """ Load an image and perform transformation and warping """
        logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)",
                     filename, side, is_timelapse)
        try:
            image = cv2.imread(filename)  # pylint: disable=no-member
        except TypeError:
            raise Exception("Error while reading image: '{}'".format(filename))
        if image is None:
            # cv2.imread returns None, rather than raising, when a file cannot be read
            raise Exception("Error while reading image: '{}'".format(filename))

        if self.mask_function or self.training_opts["warp_to_landmarks"]:
            src_pts = self.get_landmarks(filename, image, side)
        if self.mask_function:
            image = self.mask_function(src_pts, image, channels=4)

        image = self.processing.color_adjust(image)

        if not is_timelapse:
            image = self.processing.random_transform(image)
            if not self.training_opts["no_flip"]:
                image = self.processing.do_random_flip(image)
        sample = image.copy()[:, :, :3]

        if self.training_opts["warp_to_landmarks"]:
            dst_pts = self.get_closest_match(filename, side, src_pts)
            processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts)
        else:
            processed = self.processing.random_warp(image)

        processed.insert(0, sample)
        logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)",
                     filename, side, [img.shape for img in processed])
        return processed

    def get_landmarks(self, filename, image, side):
        """ Return the landmarks for this face, looked up by the SHA-1 hash of the
            image's pixel data """
        logger.trace("Retrieving landmarks: (filename: '%s', side: '%s')", filename, side)
        lm_key = sha1(image).hexdigest()
        try:
            src_points = self.landmarks[side][lm_key]
        except KeyError:
            raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key,
                                                                                  filename))
        logger.trace("Returning: (src_points: %s)", src_points)
        return src_points

    def get_closest_match(self, filename, side, src_points):
        """ Return a random choice from the 10 landmark sets of the opposite side that are
            closest to src_points by mean squared distance """
        logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s')",
                     filename, src_points)
        dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"]
        dst_points = list(dst_points.values())
        closest = (np.mean(np.square(src_points - dst_points),
                           axis=(1, 2))).argsort()[:10]
        closest = np.random.choice(closest)
        dst_points = dst_points[closest]
        logger.trace("Returning: (dst_points: %s)", dst_points)
        return dst_points
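

# Illustrative sketch only (hypothetical helpers, not called anywhere in this module).
# _example_collate mirrors the regrouping loop in TrainingDataGenerator.minibatch.
# Each queued face is a list [sample, warped, target(, mask)], and the batch is rebuilt
# as one float32 array per slot. It assumes every face yields the same number of
# outputs with matching shapes.
# _example_closest_match mirrors TrainingDataGenerator.get_closest_match. It ranks the
# other side's landmark sets by mean squared distance and picks one of the `top_k`
# nearest at random. `top_k` and `rng` are illustrative parameters, not options of
# this module.
import numpy as np


def _example_collate(faces):
    """ Regroup per-face output lists into one float32 batch array per slot """
    # zip(*faces) yields the first output of every face, then the second, and so on
    return [np.asarray(slot, dtype="float32") for slot in zip(*faces)]


def _example_closest_match(src_points, candidates, top_k=10, rng=None):
    """ Return a random choice from the top_k landmark sets nearest to src_points """
    rng = np.random.default_rng() if rng is None else rng
    candidates = np.asarray(candidates, dtype="float32")  # shape (N, points, 2)
    # Mean squared distance between src_points and each candidate set, shape (N,)
    dists = np.mean(np.square(src_points - candidates), axis=(1, 2))
    nearest = np.argsort(dists)[:top_k]
    return candidates[rng.choice(nearest)]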


class ImageManipulation():
    """ Manipulations to be performed on training images """
    def __init__(self, input_size, output_size, coverage_ratio):
        """ input_size: Size of the face input into the model
            output_size: Size of the face that comes out of the model
            coverage_ratio: Coverage ratio of full image. e.g. 256 * 0.625 = 160
        """
        logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)",
                     self.__class__.__name__, input_size, output_size, coverage_ratio)
        # Transform args
        self.rotation_range = 10  # Maximum random rotation, in degrees, either way
        self.zoom_range = 0.05  # Maximum random zoom, as a fraction of image size
        self.shift_range = 0.05  # Maximum random translation, as a fraction of image size
        self.random_flip = 0.5  # Probability of flipping the image horizontally
        # Transform and Warp args
        self.input_size = input_size
        self.output_size = output_size
        # Warp args
        self.coverage_ratio = coverage_ratio  # Coverage ratio of full image. e.g. 256 * 0.625 = 160
        self.scale = 5  # Standard deviation (pixels) of the random warp grid displacement
        logger.debug("Initialized %s", self.__class__.__name__)

    @staticmethod
    def color_adjust(img):
        """ Convert the image to float32 and scale pixel values from [0, 255] to [0.0, 1.0] """
        logger.trace("Color adjusting image")
        return img.astype('float32') / 255.0

    @staticmethod
    def separate_mask(image):
        """ Return the image and the mask from a 4 channel image """
        mask = None
        if image.shape[2] == 4:
            logger.trace("Image contains mask")
            mask = np.expand_dims(image[:, :, -1], axis=2)
            image = image[:, :, :3]
        else:
            logger.trace("Image has no mask")
        return image, mask

    def get_coverage(self, image):
        """ Return coverage value for given image """
        coverage = int(image.shape[0] * self.coverage_ratio)
        logger.trace("Coverage: %s", coverage)
        return coverage

    def random_transform(self, image):
        """ Randomly rotate, zoom and translate an image """
        logger.trace("Randomly transforming image")
        height, width = image.shape[0:2]

        rotation = np.random.uniform(-self.rotation_range, self.rotation_range)
        scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
        tnx = np.random.uniform(-self.shift_range, self.shift_range) * width
        tny = np.random.uniform(-self.shift_range, self.shift_range) * height

        mat = cv2.getRotationMatrix2D(  # pylint: disable=no-member
            (width // 2, height // 2), rotation, scale)
        mat[:, 2] += (tnx, tny)
        result = cv2.warpAffine(  # pylint: disable=no-member
            image, mat, (width, height),
            borderMode=cv2.BORDER_REPLICATE)  # pylint: disable=no-member

        logger.trace("Randomly transformed image")
        return result

    def do_random_flip(self, image):
        """ Flip the image horizontally with probability self.random_flip """
        logger.trace("Randomly flipping image")
        if np.random.random() < self.random_flip:
            logger.trace("Flip within threshold. Flipping")
            retval = image[:, ::-1]
        else:
            logger.trace("Flip outside threshold. Not Flipping")
            retval = image
        logger.trace("Randomly flipped image")
        return retval

    def random_warp(self, image):
        """ Return a randomly warped image and its aligned target image (plus the target
            mask, if the image has one) from an aligned face image """
        logger.trace("Randomly warping image")
        height, width = image.shape[0:2]
        coverage = self.get_coverage(image)
        assert height == width and height % 2 == 0

        range_ = np.linspace(height // 2 - coverage // 2,
                             height // 2 + coverage // 2,
                             5, dtype='float32')
        mapx = np.broadcast_to(range_, (5, 5)).copy()
        mapy = mapx.T
        # mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast

        pad = int(1.25 * self.input_size)
        slices = slice(pad // 10, -pad // 10)
        dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4))
        interp = np.empty((2, self.input_size, self.input_size), dtype='float32')
        ####

        for i, map_ in enumerate([mapx, mapy]):
            map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale)
            interp[i] = cv2.resize(map_, (pad, pad))[slices, slices]  # pylint: disable=no-member

        warped_image = cv2.remap(  # pylint: disable=no-member
            image, interp[0], interp[1], cv2.INTER_LINEAR)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)

        src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
        dst_points = np.mgrid[dst_slice, dst_slice]
        mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2]
        target_image = cv2.warpAffine(  # pylint: disable=no-member
            image, mat, (self.output_size, self.output_size))
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]

    def random_warp_landmarks(self, image, src_points=None, dst_points=None):
        """ Return the warped image, the target image and (if present) the target mask,
            warping src_points towards dst_points. Ported from the DFaker plugin """
        logger.trace("Randomly warping landmarks")
        size = image.shape[0]
        coverage = self.get_coverage(image)

        p_mx = size - 1
        p_hf = (size // 2) - 1

        edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0),
                        (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]
        grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)]

        source = src_points
        destination = (dst_points.copy().astype('float32') +
                       np.random.normal(size=dst_points.shape, scale=2.0))
        destination = destination.astype('uint8')

        face_core = cv2.convexHull(np.concatenate(  # pylint: disable=no-member
            [source[17:], destination[17:]], axis=0).astype(int))

        source = [(pty, ptx) for ptx, pty in source] + edge_anchors
        destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors

        indices_to_remove = set()
        for fpl in source, destination:
            for idx, (pty, ptx) in enumerate(fpl):
                if idx > 17:
                    break
                elif cv2.pointPolygonTest(face_core,  # pylint: disable=no-member
                                          (pty, ptx),
                                          False) >= 0:
                    indices_to_remove.add(idx)

        for idx in sorted(indices_to_remove, reverse=True):
            source.pop(idx)
            destination.pop(idx)

        grid_z = griddata(destination, source, (grid_x, grid_y), method="linear")
        map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size)
        map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size)
        map_x_32 = map_x.astype('float32')
        map_y_32 = map_y.astype('float32')

        # borderMode must be passed by keyword: the 5th positional argument of remap is dst
        warped_image = cv2.remap(image,  # pylint: disable=no-member
                                 map_x_32,
                                 map_y_32,
                                 cv2.INTER_LINEAR,  # pylint: disable=no-member
                                 borderMode=cv2.BORDER_TRANSPARENT)  # pylint: disable=no-member
        target_image = image

        # TODO Make sure this replacement is correct
        slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2)
        # slices = slice(size // 32, size - size // 32)  # 8px on a 256px image
        # interpolation must be passed by keyword: the 3rd positional argument of resize is dst
        warped_image = cv2.resize(  # pylint: disable=no-member
            warped_image[slices, slices, :], (self.input_size, self.input_size),
            interpolation=cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Warped image shape: %s", warped_image.shape)
        target_image = cv2.resize(  # pylint: disable=no-member
            target_image[slices, slices, :], (self.output_size, self.output_size),
            interpolation=cv2.INTER_AREA)  # pylint: disable=no-member
        logger.trace("Target image shape: %s", target_image.shape)

        warped_image, warped_mask = self.separate_mask(warped_image)
        target_image, target_mask = self.separate_mask(target_image)

        if target_mask is None:
            logger.trace("Randomly warped image")
            return [warped_image, target_image]

        logger.trace("Target mask shape: %s", target_mask.shape)
        logger.trace("Randomly warped image and mask")
        return [warped_image, target_image, target_mask]
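

# Illustrative sketch only (hypothetical helpers, not called anywhere in this module).
# _example_rotation_matrix reproduces in NumPy the 2x3 matrix documented for
# cv2.getRotationMatrix2D(center, angle, scale), where angle is in degrees and
# positive values rotate counter-clockwise:
#     [[ a, b, (1 - a) * cx - b * cy],
#      [-b, a, b * cx + (1 - a) * cy]]   with a = scale * cos(angle), b = scale * sin(angle)
# ImageManipulation.random_transform then adds the random shift to the last column.
# _example_warp_grid builds the 5x5 control grid that random_warp centres on the
# coverage area and jitters with Gaussian noise before upscaling it into a remap.
# `noise_scale` and `rng` are illustrative parameters, not options of this module.
import math

import numpy as np


def _example_rotation_matrix(center, angle, scale):
    """ 2x3 affine matrix: rotate by angle (degrees) and scale about center """
    cx, cy = center
    alpha = scale * math.cos(math.radians(angle))
    beta = scale * math.sin(math.radians(angle))
    return np.array([[alpha, beta, (1 - alpha) * cx - beta * cy],
                     [-beta, alpha, beta * cx + (1 - alpha) * cy]])


def _example_warp_grid(size, coverage_ratio=0.625, noise_scale=5.0, rng=None):
    """ Return the (clean, jittered) 5x5 x-coordinate grids used by random_warp """
    rng = np.random.default_rng() if rng is None else rng
    coverage = int(size * coverage_ratio)
    range_ = np.linspace(size // 2 - coverage // 2, size // 2 + coverage // 2,
                         5, dtype="float32")
    mapx = np.broadcast_to(range_, (5, 5)).copy()  # every row spans the coverage area
    jittered = mapx + rng.normal(size=(5, 5), scale=noise_scale)
    return mapx, jittered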


def stack_images(images):
    """ Tile an array of images into a single image. The last axis is kept as channels,
        e.g. (rows, cols, h, w, c) becomes (rows * h, cols * w, c) and
        (n, h, w, c) becomes (h, n * w, c) """
    logger.debug("Stack images")

    def get_transpose_axes(num):
        if num % 2 == 0:
            logger.debug("Even number of dimensions to stack")
            y_axes = list(range(1, num - 1, 2))
            x_axes = list(range(0, num - 1, 2))
        else:
            logger.debug("Odd number of dimensions to stack")
            y_axes = list(range(0, num - 1, 2))
            x_axes = list(range(1, num - 1, 2))
        return y_axes, x_axes, [num - 1]

    images_shape = np.array(images.shape)
    new_axes = get_transpose_axes(len(images_shape))
    new_shape = [np.prod(images_shape[x]) for x in new_axes]
    logger.debug("Stacked images")
    return np.transpose(
        images,
        axes=np.concatenate(new_axes)
        ).reshape(new_shape)
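

# Illustrative sketch only (hypothetical helper, not called anywhere in this module).
# For a 5-D array laid out as (rows, cols, height, width, channels), stack_images
# transposes to (rows, height, cols, width, channels) and reshapes. This tiles the
# faces into one (rows * height, cols * width, channels) image. _example_tile_grid
# spells out that case directly so the tile positions are easy to check.
import numpy as np


def _example_tile_grid(images):
    """ Tile a (rows, cols, h, w, c) array into a single (rows*h, cols*w, c) image """
    rows, cols, height, width, channels = images.shape
    # Face [r, c] lands at rows r*h:(r+1)*h and columns c*w:(c+1)*w of the output
    return images.transpose(0, 2, 1, 3, 4).reshape(rows * height, cols * width, channels)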