1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/tests/lib/model/normalization_test.py
torzdf aa39234538
Update all Keras Imports to be conditional (#1214)
* Remove custom keras importer

* first round keras imports fix

* launcher.py: Remove KerasFinder references

* 2nd round keras imports update (lib and extract)

* 3rd round keras imports update (train)

* remove KerasFinder from tests

* 4th round keras imports update (tests)
2022-05-03 20:18:39 +01:00

116 lines
4.5 KiB
Python

#!/usr/bin/env python3
""" Tests for Faceswap Normalization.
Adapted from Keras tests.
"""
from itertools import product
import numpy as np
import pytest
from lib.model import normalization
from lib.utils import get_backend
from tests.lib.model.layers_test import layer_test
# Import the Keras API from the package matching the active backend: standalone ``keras`` for
# the "amd" backend, ``tensorflow.keras`` otherwise
if get_backend() == "amd":
    from keras import regularizers, models, layers
else:
    # Ignore linting errors from Tensorflow's thoroughly broken import system
    from tensorflow.keras import regularizers, models, layers  # pylint:disable=import-error
@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
def test_instance_normalization(dummy):  # pylint:disable=unused-argument
    """ Basic test for instance normalization.

    Runs :class:`normalization.InstanceNormalization` through :func:`layer_test` once for each
    (layer kwargs, input shape) case below, in order.

    Parameters
    ----------
    dummy: None
        Unused placeholder, parametrized only so the test ID shows the active backend
    """
    cases = [
        # Regularized gamma and beta with a custom epsilon
        ({'epsilon': 0.1,
          'gamma_regularizer': regularizers.l2(0.01),
          'beta_regularizer': regularizers.l2(0.01)},
         (3, 4, 2)),
        # Normalize along axis 1
        ({'epsilon': 0.1, 'axis': 1},
         (1, 4, 1)),
        # Explicit initializers on a 4-element input shape
        ({'gamma_initializer': 'ones', 'beta_initializer': 'ones'},
         (3, 4, 2, 4)),
        # No learned scale or offset
        ({'epsilon': 0.1, 'axis': 1, 'scale': False, 'center': False},
         (3, 4, 2, 4)),
    ]
    for layer_kwargs, shape in cases:
        layer_test(normalization.InstanceNormalization,
                   kwargs=layer_kwargs,
                   input_shape=shape)
@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
def test_group_normalization(dummy):  # pylint:disable=unused-argument
    """ Basic test for group normalization.

    Runs :class:`normalization.GroupNormalization` through :func:`layer_test` with several
    argument combinations and input shapes.

    Parameters
    ----------
    dummy: None
        Unused placeholder, parametrized only so the test ID shows the active backend
    """
    # Regularized gamma and beta with a custom epsilon
    layer_test(normalization.GroupNormalization,
               kwargs={'epsilon': 0.1,
                       'gamma_regularizer': regularizers.l2(0.01),
                       'beta_regularizer': regularizers.l2(0.01)},
               input_shape=(4, 3, 4, 128))
    # Group along axis 1
    layer_test(normalization.GroupNormalization,
               kwargs={'epsilon': 0.1,
                       'axis': 1},
               input_shape=(4, 1, 4, 256))
    # Initializers: note this layer is passed ``gamma_init``/``beta_init`` rather than the
    # ``gamma_initializer``/``beta_initializer`` spelling used for InstanceNormalization
    layer_test(normalization.GroupNormalization,
               kwargs={'gamma_init': 'ones',
                       'beta_init': 'ones'},
               input_shape=(4, 64))
    # Explicit group count, grouping along axis 1
    layer_test(normalization.GroupNormalization,
               kwargs={'epsilon': 0.1,
                       'axis': 1,
                       'group': 16},
               input_shape=(3, 64))
# Every True/False combination of the ``center`` and ``scale`` layer arguments
_PARAMS = ["center", "scale"]
_VALUES = list(product([True, False], repeat=len(_PARAMS)))
# Test IDs: the enabled flag names joined by "|" (empty when none are enabled), followed by
# the upper-cased backend name in square brackets
_IDS = ["|".join(name for name, enabled in zip(_PARAMS, flags) if enabled)
        + f"[{get_backend().upper()}]"
        for flags in _VALUES]
@pytest.mark.parametrize(_PARAMS, _VALUES, ids=_IDS)
def test_adain_normalization(center, scale):
    """ Basic test for Ada Instance Normalization.

    Builds the layer for three inputs, wraps it in a functional model, predicts on random data
    and checks the prediction's shape against the layer's reported output shape.

    Parameters
    ----------
    center: bool
        Passed through to the layer's ``center`` argument
    scale: bool
        Passed through to the layer's ``scale`` argument
    """
    norm = normalization.AdaInstanceNormalization(center=center, scale=scale)
    shapes = [(4, 8, 8, 1280), (4, 1, 1, 1280), (4, 1, 1, 1280)]
    norm.build(shapes)
    expected_output_shape = norm.compute_output_shape(shapes)

    # Input layers take the per-sample shape, so the leading (batch) entry is dropped
    inputs = [layers.Input(shape=shape[1:]) for shape in shapes]
    model = models.Model(inputs, norm(inputs))

    data = [10 * np.random.random(shape) for shape in shapes]
    actual_output_shape = model.predict(data).shape

    # zip() stops at the shorter sequence, so compare ranks explicitly; otherwise a rank
    # mismatch would pass unnoticed
    assert len(expected_output_shape) == len(actual_output_shape)
    for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape):
        if expected_dim is not None:  # unknown (None) dimensions cannot be compared
            assert expected_dim == actual_dim
@pytest.mark.parametrize(_PARAMS, _VALUES, ids=_IDS)
def test_layer_normalization(center, scale):
    """ Basic test for layer normalization.

    Runs :class:`normalization.LayerNormalization` through :func:`layer_test` on a (4, 512)
    input for one combination of the ``center`` and ``scale`` flags.
    """
    layer_kwargs = dict(center=center, scale=scale)
    layer_test(normalization.LayerNormalization, kwargs=layer_kwargs, input_shape=(4, 512))
# Parameter grid for the RMS normalization test. Uses distinct names so the center/scale
# ``_PARAMS``/``_VALUES``/``_IDS`` globals consumed by earlier decorators are not rebound.
_RMS_PARAMS = ["partial", "bias"]
_RMS_VALUES = [(0.0, False), (0.25, False), (0.5, True), (0.75, False), (1.0, True)]
_RMS_IDS = [f"partial={partial}|bias={bias}[{get_backend().upper()}]"
            for partial, bias in _RMS_VALUES]


@pytest.mark.parametrize(_RMS_PARAMS, _RMS_VALUES, ids=_RMS_IDS)
def test_rms_normalization(partial, bias):
    """ Basic test for RMS Layer normalization.

    Parameters
    ----------
    partial: float
        Passed through to the layer's ``partial`` argument
    bias: bool
        Passed through to the layer's ``bias`` argument
    """
    layer_test(normalization.RMSNormalization,
               kwargs={"partial": partial, "bias": bias},
               input_shape=(4, 512))