mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-09 04:36:50 -04:00
564 lines
20 KiB
Python
#!/usr/bin/env python3
""" S3FD Face detection plugin
https://arxiv.org/abs/1708.05237

Adapted from S3FD Port in FAN:
https://github.com/1adrianb/face-alignment
"""

from scipy.special import logsumexp
import numpy as np
import keras  # pylint:disable=import-error
import keras.backend as K  # pylint:disable=import-error

from lib.model.session import KSession
from ._base import Detector, logger


class Detect(Detector):
    """ S3FD detector for face detection """
    def __init__(self, **kwargs):
        git_model_id = 11
        model_filename = "s3fd_keras_v1.h5"
        super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
        self.name = "S3FD"
        self.input_size = 640
        self.vram = 4112
        self.vram_warnings = 1024  # Will run at this with warnings
        self.vram_per_batch = 208
        self.batchsize = self.config["batch-size"]

    def init_model(self):
        """ Initialize S3FD Model"""
        confidence = self.config["confidence"] / 100
        model_kwargs = dict(custom_objects=dict(O2K_Add=AddO2K,
                                                O2K_Slice=SliceO2K,
                                                O2K_Sum=SumO2K,
                                                O2K_Sqrt=SqrtO2K,
                                                O2K_Pow=PowO2K,
                                                O2K_ConstantLayer=ConstantLayerO2K,
                                                O2K_Div=DivO2K))
        self.model = S3fd(self.model_path,
                          model_kwargs,
                          self.config["allow_growth"],
                          self._exclude_gpus,
                          confidence)

    def process_input(self, batch):
        """ Compile the detection image(s) for prediction """
        batch["feed"] = self.model.prepare_batch(batch["image"])
        return batch

    def predict(self, batch):
        """ Run model to get predictions """
        predictions = self.model.predict(batch["feed"])
        batch["prediction"] = self.model.finalize_predictions(predictions)
        logger.trace("filename: %s, prediction: %s", batch["filename"], batch["prediction"])
        return batch

    def process_output(self, batch):
        """ Compile found faces for output """
        return batch


################################################################################
# CUSTOM KERAS LAYERS
# generated by onnx2keras
################################################################################
class ElementwiseLayerO2K(keras.layers.Layer):
    """ Custom Keras Element Wise layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Override for layers that inherit from this class.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        raise NotImplementedError()

    def compute_output_shape(self, input_shape):  # pylint:disable=no-self-use
        """Computes the output shape of the layer.

        Assumes that the layer will be built to match the input shape provided.

        Parameters
        ----------
        input_shape: tuple or list of tuples
            Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the
            layer). Shape tuples can include ``None`` for free dimensions, instead of an integer.

        Returns
        -------
        tuple
            An output shape tuple.
        """
        # TODO: do this nicer
        ldims = len(input_shape[0])
        rdims = len(input_shape[1])
        if ldims > rdims:
            return input_shape[0]
        if rdims > ldims:
            return input_shape[1]
        lprod = np.prod(list(filter(bool, input_shape[0])))
        rprod = np.prod(list(filter(bool, input_shape[1])))
        return input_shape[0 if lprod > rprod else 1]
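
The shape heuristic in `ElementwiseLayerO2K.compute_output_shape` can be tried outside Keras. The following is a minimal standalone sketch of that logic, not part of the plugin: the function name `elementwise_output_shape` is hypothetical, and `math.prod` is assumed as a stand-in for `np.prod`.

```python
from math import prod


def elementwise_output_shape(input_shape):
    """ Hypothetical standalone copy of the heuristic: pick the input shape with more
    dimensions or, when the ranks match, the one with the larger product of known sizes. """
    left, right = input_shape
    if len(left) != len(right):
        return left if len(left) > len(right) else right
    # filter(bool, ...) drops ``None`` (free) dimensions, and also any zero-sized ones
    lprod = prod(filter(bool, left))
    rprod = prod(filter(bool, right))
    # A tie resolves to the right-hand shape, as in the method above
    return left if lprod > rprod else right


print(elementwise_output_shape([(None, 2, 4, 4), (None, 1, 1, 1)]))  # (None, 2, 4, 4)
print(elementwise_output_shape([(None, 3), (None, 4, 3)]))  # (None, 4, 3)
```

This is a heuristic rather than full broadcasting: inputs such as `(None, 4, 1)` and `(None, 1, 4)` resolve to one of the two input shapes, not to `(None, 4, 4)`.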


class AddO2K(ElementwiseLayerO2K):
    """ Custom Keras Add layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        return inputs[0] + inputs[1]


class SliceO2K(keras.layers.Layer):
    """ Custom Keras Slice layer generated by onnx2keras. """
    def __init__(self, starts, ends, axes=None, steps=None, **kwargs):
        self._starts = starts
        self._ends = ends
        self._axes = axes
        self._steps = steps
        super().__init__(**kwargs)

    def get_config(self):
        """ Returns the config of the layer.

        A layer config is a Python dictionary (serializable) containing the configuration of a
        layer. The same layer can be re-instantiated later (without its trained weights) from this
        configuration. The config of a layer does not include connectivity information, nor the
        layer class name. These are handled by `Network` (one layer of abstraction above).

        Returns
        -------
        dict
            The configuration for the layer
        """
        config = super().get_config()
        config.update({
            'starts': self._starts, 'ends': self._ends,
            'axes': self._axes, 'steps': self._steps
        })
        return config

    def _get_slices(self, dimensions):
        """ Obtain slices for the given number of dimensions.

        Parameters
        ----------
        dimensions: int
            The number of dimensions to obtain slices for

        Returns
        -------
        list
            The slices for the given number of dimensions
        """
        axes = self._axes
        steps = self._steps
        if axes is None:
            axes = tuple(range(dimensions))
        if steps is None:
            steps = (1,) * len(axes)
        assert len(axes) == len(steps) == len(self._starts) == len(self._ends)
        return list(zip(axes, self._starts, self._ends, steps))

    def compute_output_shape(self, input_shape):
        """Computes the output shape of the layer.

        Assumes that the layer will be built to match the input shape provided.

        Parameters
        ----------
        input_shape: tuple or list of tuples
            Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the
            layer). Shape tuples can include ``None`` for free dimensions, instead of an integer.

        Returns
        -------
        tuple
            An output shape tuple.
        """
        input_shape = list(input_shape)
        for a_x, start, end, steps in self._get_slices(len(input_shape)):
            size = input_shape[a_x]
            if a_x == 0:
                raise AttributeError("Can not slice batch axis.")
            if size is None:
                if start < 0 or end < 0:
                    raise AttributeError("Negative slices not supported on symbolic axes")
                logger.warning("Slicing symbolic axis might lead to problems.")
                input_shape[a_x] = (end - start) // steps
                continue
            # Negative indices count back from the end of the axis
            if start < 0:
                start = size + start
            if end < 0:
                end = size + end
            input_shape[a_x] = (min(size, end) - start) // steps
        return tuple(input_shape)

    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        ax_map = dict((x[0], slice(*x[1:])) for x in self._get_slices(K.ndim(inputs)))
        shape = K.int_shape(inputs)
        slices = [(ax_map[a] if a in ax_map else slice(None)) for a in range(len(shape))]
        retval = inputs[tuple(slices)]
        return retval
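
To illustrate how `SliceO2K.call` turns ONNX-style `starts`/`ends`/`axes`/`steps` into an index, here is a minimal sketch using a NumPy array in place of a Keras tensor. The helper name `build_slices` is hypothetical and exists only for this example.

```python
import numpy as np


def build_slices(ndim, starts, ends, axes=None, steps=None):
    """ Hypothetical helper: one slice object per axis, defaulting to a full slice. """
    axes = tuple(range(ndim)) if axes is None else axes
    steps = (1,) * len(axes) if steps is None else steps
    ax_map = {axis: slice(start, end, step)
              for axis, start, end, step in zip(axes, starts, ends, steps)}
    # Axes that are not mentioned are left untouched with slice(None)
    return tuple(ax_map.get(axis, slice(None)) for axis in range(ndim))


data = np.arange(24).reshape(2, 3, 4)
index = build_slices(data.ndim, starts=[1], ends=[3], axes=[2])
print(data[index].shape)  # (2, 3, 2)
```

Only axis 2 is sliced here, so the result is equivalent to `data[:, :, 1:3]`.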


class ReduceLayerO2K(keras.layers.Layer):
    """ Custom Keras Reduce layer generated by onnx2keras. """
    def __init__(self, axes=None, keepdims=True, **kwargs):
        self._axes = [axes] if isinstance(axes, int) else axes
        self._keepdims = bool(keepdims)
        super().__init__(**kwargs)

    def get_config(self):
        """ Returns the config of the layer.

        A layer config is a Python dictionary (serializable) containing the configuration of a
        layer. The same layer can be re-instantiated later (without its trained weights) from this
        configuration. The config of a layer does not include connectivity information, nor the
        layer class name. These are handled by `Network` (one layer of abstraction above).

        Returns
        -------
        dict
            The configuration for the layer
        """
        config = super().get_config()
        config.update({
            'axes': self._axes,
            'keepdims': self._keepdims
        })
        return config

    def compute_output_shape(self, input_shape):
        """Computes the output shape of the layer.

        Assumes that the layer will be built to match the input shape provided.

        Parameters
        ----------
        input_shape: tuple or list of tuples
            Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the
            layer). Shape tuples can include ``None`` for free dimensions, instead of an integer.

        Returns
        -------
        tuple
            An output shape tuple.
        """
        if self._axes is None:
            return (1,) * len(input_shape) if self._keepdims else tuple()
        ret = list(input_shape)
        for i in sorted(self._axes, reverse=True):
            if self._keepdims:
                ret[i] = 1
            else:
                ret.pop(i)
        return tuple(ret)

    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Override for layers which inherit from this class

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        raise NotImplementedError()


class SumO2K(ReduceLayerO2K):
    """ Custom Keras Sum layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        return K.sum(inputs, self._axes, self._keepdims)
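
The output-shape rule used by the reduce layers can be checked against NumPy. The sketch below restates `ReduceLayerO2K.compute_output_shape` as a free function (the name `reduced_shape` is hypothetical) and compares it with the shape that `np.sum` actually produces. It is only exercised with non-negative axes.

```python
import numpy as np


def reduced_shape(input_shape, axes=None, keepdims=True):
    """ Hypothetical standalone copy of the reduce layer's shape rule. """
    if axes is None:
        return (1,) * len(input_shape) if keepdims else tuple()
    shape = list(input_shape)
    # Work from the highest axis down so that popping does not shift later indices
    for axis in sorted(axes, reverse=True):
        if keepdims:
            shape[axis] = 1
        else:
            shape.pop(axis)
    return tuple(shape)


data = np.ones((2, 3, 4))
for keep in (True, False):
    expected = np.sum(data, axis=(1, 2), keepdims=keep).shape
    assert reduced_shape(data.shape, axes=[1, 2], keepdims=keep) == expected
```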


class SqrtO2K(keras.layers.Layer):  # pylint:disable=too-few-public-methods
    """ Custom Keras Square Root layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument,no-self-use
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        return K.sqrt(inputs)


class PowO2K(keras.layers.Layer):  # pylint:disable=too-few-public-methods
    """ Custom Keras Power layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument,no-self-use
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        return K.pow(*inputs)


class ConstantLayerO2K(keras.layers.Layer):
    """ Custom Keras Constant layer generated by onnx2keras. """
    def __init__(self, constant_obj, dtype, **kwargs):
        self._dtype = np.dtype(dtype).name
        self._constant = np.array(constant_obj, dtype=self._dtype)
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer. Required for parent class but unused
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        data = K.constant(self._constant, dtype=self._dtype)
        return data

    def compute_output_shape(self, input_shape):  # pylint:disable=unused-argument
        """Computes the output shape of the layer.

        Assumes that the layer will be built to match the input shape provided.

        Parameters
        ----------
        input_shape: tuple or list of tuples
            Shape tuple (tuple of integers) or list of shape tuples (one per output tensor of the
            layer). Shape tuples can include ``None`` for free dimensions, instead of an integer.
            This is unused for a constant layer

        Returns
        -------
        tuple
            An output shape tuple.
        """
        return self._constant.shape

    def get_config(self):
        """ Returns the config of the layer.

        A layer config is a Python dictionary (serializable) containing the configuration of a
        layer. The same layer can be re-instantiated later (without its trained weights) from this
        configuration. The config of a layer does not include connectivity information, nor the
        layer class name. These are handled by `Network` (one layer of abstraction above).

        Returns
        -------
        dict
            The configuration for the layer
        """
        config = super().get_config()
        config.update({
            'constant_obj': self._constant,
            'dtype': self._dtype
        })
        return config


class DivO2K(ElementwiseLayerO2K):
    """ Custom Keras Division layer generated by onnx2keras. """
    def call(self, inputs, **kwargs):  # pylint:disable=unused-argument
        """This is where the layer's logic lives.

        Parameters
        ----------
        inputs: Input tensor, or list/tuple of input tensors.
            The input to the layer
        **kwargs: Additional keyword arguments.
            Required for parent class but unused

        Returns
        -------
        A tensor or list/tuple of tensors.
            The layer output
        """
        return inputs[0] / inputs[1]


class S3fd(KSession):
    """ Keras Network """
    def __init__(self, model_path, model_kwargs, allow_growth, exclude_gpus, confidence):
        logger.debug("Initializing: %s: (model_path: '%s', model_kwargs: %s, allow_growth: %s, "
                     "exclude_gpus: %s, confidence: %s)", self.__class__.__name__, model_path,
                     model_kwargs, allow_growth, exclude_gpus, confidence)
        super().__init__("S3FD",
                         model_path,
                         model_kwargs=model_kwargs,
                         allow_growth=allow_growth,
                         exclude_gpus=exclude_gpus)
        self.load_model()
        self.confidence = confidence
        self.average_img = np.array([104.0, 117.0, 123.0])
        logger.debug("Initialized: %s", self.__class__.__name__)

    def prepare_batch(self, batch):
        """ Prepare a batch for prediction """
        batch = batch - self.average_img
        batch = batch.transpose(0, 3, 1, 2)
        return batch

    def finalize_predictions(self, bounding_boxes_scales):
        """ Detect faces """
        ret = list()
        batch_size = range(bounding_boxes_scales[0].shape[0])
        for img in batch_size:
            bboxlist = [scale[img:img+1] for scale in bounding_boxes_scales]
            boxes = self._post_process(bboxlist)
            bboxlist = self._nms(boxes, 0.5)
            ret.append(bboxlist)
        return ret

    def _post_process(self, bboxlist):
        """ Perform post processing on output
        TODO: do this on the batch.
        """
        retval = list()
        for i in range(len(bboxlist) // 2):
            bboxlist[i * 2] = self.softmax(bboxlist[i * 2], axis=1)
        for i in range(len(bboxlist) // 2):
            ocls, oreg = bboxlist[i * 2], bboxlist[i * 2 + 1]
            stride = 2 ** (i + 2)  # 4,8,16,32,64,128
            poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
            for _, hindex, windex in poss:
                axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
                score = ocls[0, 1, hindex, windex]
                if score >= self.confidence:
                    loc = np.ascontiguousarray(oreg[0, :, hindex, windex]).reshape((1, 4))
                    priors = np.array([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
                    box = self.decode(loc, priors)
                    x_1, y_1, x_2, y_2 = box[0] * 1.0
                    retval.append([x_1, y_1, x_2, y_2, score])
        return_numpy = np.array(retval) if len(retval) != 0 else np.zeros((1, 5))
        return return_numpy

    @staticmethod
    def softmax(inp, axis):
        """Compute softmax values for each set of scores in ``inp``."""
        return np.exp(inp - logsumexp(inp, axis=axis, keepdims=True))

    @staticmethod
    def decode(location, priors):
        """Decode locations from predictions using priors to undo the encoding we did for offset
        regression at train time.

        Parameters
        ----------
        location: tensor
            Location predictions for location layers
        priors: tensor
            Prior boxes in center-offset form.

        Returns
        -------
        :class:`numpy.ndarray`
            decoded bounding box predictions
        """
        variances = [0.1, 0.2]
        boxes = np.concatenate((priors[:, :2] + location[:, :2] * variances[0] * priors[:, 2:],
                                priors[:, 2:] * np.exp(location[:, 2:] * variances[1])), axis=1)
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]
        return boxes
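
As a worked example of the decoding step, the sketch below restates `decode` as a free function and applies it to the prior that `_post_process` would build for stride 4 at `hindex=2`, `windex=4`. The function name `decode_boxes` and the sample offsets are assumptions made purely for illustration.

```python
import numpy as np


def decode_boxes(location, priors, variances=(0.1, 0.2)):
    """ Hypothetical restatement of the decode step: center-size priors plus regressed
    offsets in, (left, top, right, bottom) corner boxes out. """
    centers = priors[:, :2] + location[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * np.exp(location[:, 2:] * variances[1])
    top_left = centers - sizes / 2
    return np.concatenate((top_left, top_left + sizes), axis=1)


stride, hindex, windex = 4, 2, 4
prior = np.array([[stride / 2 + windex * stride,   # center x = 18
                   stride / 2 + hindex * stride,   # center y = 10
                   stride * 4, stride * 4]],       # width = height = 16
                 dtype="float64")

# Zero offsets return the prior itself, expressed as corners
print(decode_boxes(np.zeros((1, 4)), prior))  # [[10.  2. 26. 18.]]
# An x offset of 1.0 moves the center by 0.1 * 16 = 1.6 pixels
print(decode_boxes(np.array([[1.0, 0.0, 0.0, 0.0]]), prior))
```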

    @staticmethod
    def _nms(boxes, threshold):
        """ Perform Non-Maximum Suppression """
        retained_box_indices = list()

        areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
        ranked_indices = boxes[:, 4].argsort()[::-1]
        while ranked_indices.size > 0:
            best = ranked_indices[0]
            rest = ranked_indices[1:]

            max_of_xy = np.maximum(boxes[best, :2], boxes[rest, :2])
            min_of_xy = np.minimum(boxes[best, 2:4], boxes[rest, 2:4])
            width_height = np.maximum(0, min_of_xy - max_of_xy + 1)
            intersection_areas = width_height[:, 0] * width_height[:, 1]
            iou = intersection_areas / (areas[best] + areas[rest] - intersection_areas)

            overlapping_boxes = (iou > threshold).nonzero()[0]
            if len(overlapping_boxes) != 0:
                overlap_set = ranked_indices[overlapping_boxes + 1]
                vote = np.average(boxes[overlap_set, :4], axis=0, weights=boxes[overlap_set, 4])
                boxes[best, :4] = vote
            retained_box_indices.append(best)

            non_overlapping_boxes = (iou <= threshold).nonzero()[0]
            ranked_indices = ranked_indices[non_overlapping_boxes + 1]
        return boxes[retained_box_indices]
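
The suppression loop in `_nms` can be exercised on a few hand-made boxes. The following is a minimal standalone sketch of the same idea, not the plugin's code: the name `nms_with_voting` and the sample boxes are hypothetical, and the function works on a copy rather than modifying its input. As in `_nms`, the coordinates of a retained box are replaced by the score-weighted average of the boxes it suppresses.

```python
import numpy as np


def nms_with_voting(boxes, threshold=0.5):
    """ Hypothetical sketch. ``boxes`` is (N, 5): left, top, right, bottom, score. """
    boxes = np.array(boxes, dtype="float64")  # work on a copy
    kept = []
    areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
    order = boxes[:, 4].argsort()[::-1]  # highest score first
    while order.size > 0:
        best, rest = order[0], order[1:]
        top_left = np.maximum(boxes[best, :2], boxes[rest, :2])
        bottom_right = np.minimum(boxes[best, 2:4], boxes[rest, 2:4])
        width_height = np.maximum(0, bottom_right - top_left + 1)
        intersection = width_height[:, 0] * width_height[:, 1]
        iou = intersection / (areas[best] + areas[rest] - intersection)

        overlapping = rest[iou > threshold]
        if overlapping.size:
            # The vote uses only the suppressed boxes, weighted by their scores
            boxes[best, :4] = np.average(boxes[overlapping, :4], axis=0,
                                         weights=boxes[overlapping, 4])
        kept.append(best)
        order = rest[iou <= threshold]
    return boxes[kept]


candidates = [[10, 10, 50, 50, 0.9],      # best box of the first face
              [12, 12, 52, 52, 0.8],      # overlaps the first (IoU of roughly 0.83)
              [100, 100, 140, 140, 0.7]]  # a separate face
result = nms_with_voting(candidates)
print(result)
```

Two boxes survive: the first keeps its score of 0.9 but takes the voted coordinates `[12, 12, 52, 52]`, and the third is returned unchanged.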