1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 11:53:26 -04:00
faceswap/tests/lib/model/normalization_test.py
torzdf 02336977dc lib.model.normalization:
- refactor
  - add LayerNorm
  - add RMSNorm
2021-02-17 00:21:18 +00:00

66 lines
2.4 KiB
Python

#!/usr/bin/env python3
""" Tests for Faceswap Normalization.
Adapted from Keras tests.
"""
from itertools import product
from keras import regularizers
import pytest
from lib.model import normalization
from lib.utils import get_backend
from tests.lib.model.layers_test import layer_test
@pytest.mark.parametrize('dummy', [None], ids=[get_backend().upper()])
def test_instance_normalization(dummy):  # pylint:disable=unused-argument
    """ Basic test for instance normalization.

    Runs :class:`normalization.InstanceNormalization` through the shared ``layer_test``
    harness for several keyword-argument / input-shape combinations, in order.

    Parameters
    ----------
    dummy: None
        Never read. It exists only so that the pytest id shows the active backend name.
    """
    # (layer kwargs, input_shape) pairs, exercised in the listed order
    cases = [
        ({'epsilon': 0.1,
          'gamma_regularizer': regularizers.l2(0.01),
          'beta_regularizer': regularizers.l2(0.01)},
         (3, 4, 2)),
        ({'epsilon': 0.1,
          'axis': 1},
         (1, 4, 1)),
        ({'gamma_initializer': 'ones',
          'beta_initializer': 'ones'},
         (3, 4, 2, 4)),
        ({'epsilon': 0.1,
          'axis': 1,
          'scale': False,
          'center': False},
         (3, 4, 2, 4)),
    ]
    for layer_kwargs, shape in cases:
        layer_test(normalization.InstanceNormalization,
                   kwargs=layer_kwargs,
                   input_shape=shape)
_LAYER_NORM_PARAMS = ["center", "scale"]
# Every True/False combination of the parameters above
_LAYER_NORM_VALUES = list(product([True, False], repeat=len(_LAYER_NORM_PARAMS)))
# Explicit "name=value" labels so that the all-False case gets a readable id (instead
# of an empty prefix), matching the "name=value|...[BACKEND]" style of the RMS test ids.
_LAYER_NORM_IDS = ["{}[{}]".format("|".join("{}={}".format(name, flag)
                                            for name, flag in zip(_LAYER_NORM_PARAMS,
                                                                  values)),
                                   get_backend().upper())
                   for values in _LAYER_NORM_VALUES]


@pytest.mark.parametrize(_LAYER_NORM_PARAMS, _LAYER_NORM_VALUES, ids=_LAYER_NORM_IDS)
def test_layer_normalization(center, scale):
    """ Basic test for layer normalization.

    Runs :class:`normalization.LayerNormalization` through the shared ``layer_test``
    harness with ``input_shape=(4, 512)``.

    Parameters
    ----------
    center: bool
        Passed through as the layer's ``center`` keyword argument.
    scale: bool
        Passed through as the layer's ``scale`` keyword argument.
    """
    layer_test(normalization.LayerNormalization,
               kwargs={"center": center, "scale": scale},
               input_shape=(4, 512))
_RMS_NORM_PARAMS = ["partial", "bias"]
# Selected (partial, bias) combinations; not the full cartesian product
_RMS_NORM_VALUES = [(0.0, False), (0.25, False), (0.5, True), (0.75, False), (1.0, True)]
_RMS_NORM_IDS = ["partial={}|bias={}[{}]".format(partial, bias, get_backend().upper())
                 for partial, bias in _RMS_NORM_VALUES]


@pytest.mark.parametrize(_RMS_NORM_PARAMS, _RMS_NORM_VALUES, ids=_RMS_NORM_IDS)
def test_rms_normalization(partial, bias):
    """ Basic test for RMS Layer normalization.

    Runs :class:`normalization.RMSNormalization` through the shared ``layer_test``
    harness with ``input_shape=(4, 512)``.

    Parameters
    ----------
    partial: float
        Passed through as the layer's ``partial`` keyword argument.
    bias: bool
        Passed through as the layer's ``bias`` keyword argument.
    """
    layer_test(normalization.RMSNormalization,
               kwargs={"partial": partial, "bias": bias},
               input_shape=(4, 512))