1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00
faceswap/lib/plaidml_utils.py
torzdf 582c2ce40c Add Flip Loss Function
- Add Flip for AMD and TF
    - Split Perceptual Loss functions to own modules
    - Fix allowed input shape for models
    - Allow GUI tooltip to display at higher width
2022-07-07 01:02:11 +01:00

59 lines
1.7 KiB
Python

#!/usr/bin/env python3
""" PlaidML helper Utilities """
from typing import Optional
import plaidml
def pad(data: plaidml.tile.Value,
        paddings,
        mode: str = "CONSTANT",
        name: Optional[str] = None,  # pylint:disable=unused-argument
        constant_value: int = 0) -> plaidml.tile.Value:
    """ Pad a PlaidML tensor using a call signature that mirrors :func:`tf.pad`.

    Notes
    -----
    Only ``"REFLECT"`` padding is implemented. The default ``mode`` is ``"CONSTANT"`` (kept for
    signature parity with :func:`tf.pad`), so calling with the default mode raises; callers must
    pass ``mode="REFLECT"`` explicitly.

    Parameters
    ----------
    data: :class:`plaidml.tile.Value`
        The tensor to pad
    paddings
        The padding amounts. Passed through unchanged to :func:`plaidml.op.reflection_padding`
        (presumably per-axis ``(before, after)`` pairs as in :func:`tf.pad` - TODO confirm)
    mode: str, optional
        The padding mode, compared case-insensitively. Only ``"REFLECT"`` is accepted.
        Default: ``"CONSTANT"``
    name: str, optional
        Ignored. Present only for consistency with :func:`tf.pad`. Default: ``None``
    constant_value: int, optional
        The fill value. Must be ``0``, as reflection padding takes no fill value. Default: ``0``

    Returns
    -------
    :class:`plaidml.tile.Value`
        The reflection-padded tensor

    Raises
    ------
    NotImplementedError
        If ``mode`` is anything other than ``"REFLECT"``, or if ``constant_value`` is not ``0``
    """
    # TODO: add the remaining tf.pad modes when they are needed:
    #   CONSTANT  -> maybe SpatialPadding? It can't pad the first/last axis and has no
    #                constant_value support
    #   SYMMETRIC -> would have to be implemented from scratch?
    is_reflect = mode.upper() == "REFLECT"
    if not is_reflect:
        raise NotImplementedError("pad only supports mode == 'REFLECT'")

    # reflection_padding (below) receives no fill value, so reject a non-zero one rather than
    # silently ignoring it
    if constant_value != 0:
        raise NotImplementedError("pad does not support constant_value != 0")

    padded = plaidml.op.reflection_padding(data, paddings)
    return padded
def is_plaidml_error(error: Exception) -> bool:
    """ Report whether an exception was raised by PlaidML.

    Parameters
    ----------
    error: :class:`Exception`
        The exception to inspect

    Returns
    -------
    bool
        ``True`` when ``error`` is an instance of :class:`plaidml.exceptions.PlaidMLError`
        (subclasses included), otherwise ``False``
    """
    plaidml_error_type = plaidml.exceptions.PlaidMLError
    return isinstance(error, plaidml_error_type)