1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00
faceswap/tests/lib/model/nn_blocks_test.py
torzdf aa39234538
Update all Keras Imports to be conditional (#1214)
* Remove custom keras importer

* first round keras imports fix

* launcher.py: Remove KerasFinder references

* 2nd round keras imports update (lib and extract)

* 3rd round keras imports update (train)

* remove KerasFinder from tests

* 4th round keras imports update (tests)
2022-05-03 20:18:39 +01:00

84 lines
3 KiB
Python

#!/usr/bin/env python3
""" Tests for Faceswap Custom Layers.
Adapted from Keras tests.
"""
from itertools import product
import pytest
import numpy as np
from numpy.testing import assert_allclose
from lib.model import nn_blocks
from lib.utils import get_backend
if get_backend() == "amd":
from keras import Input, Model, backend as K
else:
# Ignore linting errors from Tensorflow's thoroughly broken import system
from tensorflow.keras import Input, Model, backend as K # pylint:disable=import-error
def block_test(layer_func, kwargs=None, input_shape=None):
    """Test routine for faceswap neural network blocks.

    Tests are simple and are to ensure that the blocks compile on both tensorflow
    and plaidml backends.

    The block is wired into a functional model, run on random data, then the model is rebuilt
    from its own config (with the original weights copied across) and the two predictions are
    compared.

    Parameters
    ----------
    layer_func: callable
        The block to test. It is called as ``layer_func(input_tensor, **kwargs)``
    kwargs: dict, optional
        Additional keyword arguments to pass to :attr:`layer_func`. Default: ``None`` (no
        additional arguments)
    input_shape: tuple
        The full shape of the input data. The first entry is dropped when creating the model
        input. Any entry that is ``None`` is replaced by a random size from 1 to 3 when the input
        data is generated. Must be provided, despite having a default value

    Returns
    -------
    The prediction of the original (non-recovered) model for the generated input data, for further
    checks in the caller function
    """
    assert input_shape, "input_shape must be provided"
    # A fresh dict per call, rather than a mutable default that is shared between calls
    kwargs = {} if kwargs is None else kwargs

    # generate input data, picking a small random size for any undefined dimension
    input_dtype = K.floatx()
    input_data_shape = [np.random.randint(1, 4) if dim is None else dim for dim in input_shape]
    input_data = (10 * np.random.random(input_data_shape)).astype(input_dtype)
    expected_output_dtype = input_dtype

    # test in functional API
    inp = Input(shape=input_shape[1:], dtype=input_dtype)
    outp = layer_func(inp, **kwargs)
    assert K.dtype(outp) == expected_output_dtype

    # check with the functional API
    model = Model(inp, outp)
    actual_output = model.predict(input_data)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = model.__class__.from_config(model_config)
    if model.weights:
        # The recovered model is freshly initialized, so it needs the original weights before
        # its output can be compared
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        _output = recovered_model.predict(input_data)
        assert_allclose(_output, actual_output, rtol=1e-3)

    # for further checks in the caller function
    return actual_output
# Names of the boolean options that test_blocks is parametrized over
_PARAMS = ["use_icnr_init", "use_convaware_init", "use_reflect_padding"]
# Every True/False combination of the options above
_VALUES = list(product([True, False], repeat=len(_PARAMS)))
# One id per combination: the enabled option names joined by "|", followed by the upper-cased
# backend name in square brackets
_IDS = ["{}[{}]".format("|".join(name for name, enabled in zip(_PARAMS, combo) if enabled),
                        get_backend().upper())
        for combo in _VALUES]
@pytest.mark.parametrize(_PARAMS, _VALUES, ids=_IDS)
def test_blocks(use_icnr_init, use_convaware_init, use_reflect_padding):
    """ Run every block in the NNBlocks module through :func:`block_test` for one combination
    of the ICNR init, Convolution Aware init and reflect padding options.
    """
    nn_blocks.set_config({"icnr_init": use_icnr_init,
                          "conv_aware_init": use_convaware_init,
                          "reflect_padding": use_reflect_padding})

    # (block class, positional args, keyword args, input shape). Each block is only instantiated
    # immediately before it is tested, in the order listed here.
    cases = (
        (nn_blocks.Conv2DOutput, (64, 3), {}, (2, 8, 8, 32)),
        (nn_blocks.Conv2DBlock, (64, ), {}, (2, 8, 8, 32)),
        (nn_blocks.SeparableConv2DBlock, (64, ), {}, (2, 8, 8, 32)),
        (nn_blocks.UpscaleBlock, (64, ), {}, (2, 4, 4, 128)),
        (nn_blocks.Upscale2xBlock, (64, ), {"fast": True}, (2, 4, 4, 128)),
        (nn_blocks.Upscale2xBlock, (64, ), {"fast": False}, (2, 4, 4, 128)),
        (nn_blocks.ResidualBlock, (64, ), {}, (2, 4, 4, 64)),
    )
    for block, args, options, shape in cases:
        block_test(block(*args, **options), input_shape=shape)