1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-08 20:13:52 -04:00
faceswap/lib/serializer.py
torzdf 70ee125283 Serialize masks to alignments file
- Add new serializers (npy + compressed)
- Remove Serializer option from cli
- Revert get_aligned call in scripts/extract
- Default alignments to compressed
- Size masks to 128px and compress
- Remove mask thresholding/blur from generation code
- Add Mask class to lib/faces_detect
- Revert debug landmarks to aligned face
- Revert non-extraction code to staging version
2019-10-13 22:50:28 +00:00

349 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Library for serializing python objects to and from various different serializer formats
"""
import json
import logging
import os
import pickle
import zlib
from io import BytesIO
import numpy as np
from lib.utils import FaceswapError
try:
import yaml
except ImportError:
yaml = None
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class Serializer():
    """ A convenience class for various serializers.

    This class should not be called directly as it acts as the parent for various serializers.
    All serializers should be called from :func:`get_serializer` or
    :func:`get_serializer_from_filename`

    Example
    -------
    >>> from lib.serializer import get_serializer
    >>> serializer = get_serializer('json')
    >>> json_file = '/path/to/json/file.json'
    >>> data = serializer.load(json_file)
    >>> serializer.save(json_file, data)
    """
    def __init__(self):
        # Subclasses overwrite the extension; files are always handled as raw bytes
        self._file_extension = None
        self._write_option = "wb"
        self._read_option = "rb"

    @property
    def file_extension(self):
        """ str: The file extension of the serializer """
        return self._file_extension

    def save(self, filename, data):
        """ Serialize data and save to a file

        Parameters
        ----------
        filename: str
            The path to where the serialized file should be saved. If the path has no extension,
            the serializer's own extension is appended
        data: varies
            The data that is to be serialized to file

        Raises
        ------
        FaceswapError
            If the data cannot be serialized or the file cannot be written

        Example
        -------
        >>> serializer = get_serializer('json')
        >>> data = ['foo', 'bar']
        >>> json_file = '/path/to/json/file.json'
        >>> serializer.save(json_file, data)
        """
        logger.debug("filename: %s, data type: %s", filename, type(data))
        filename = self._check_extension(filename)
        # Marshal BEFORE opening the file: opening with "wb" truncates the target immediately,
        # so a serialization failure inside the ``with`` block would destroy an existing file.
        serialized = self.marshal(data)
        try:
            with open(filename, self._write_option) as s_file:
                s_file.write(serialized)
        except IOError as err:
            msg = "Error writing to '{}': {}".format(filename, err.strerror)
            raise FaceswapError(msg) from err

    def _check_extension(self, filename):
        """ Check the filename has an extension. If not add the correct one for the serializer

        Parameters
        ----------
        filename: str
            The filename to check

        Returns
        -------
        str
            The original filename if it already has an extension, otherwise the filename with
            the serializer's extension appended
        """
        extension = os.path.splitext(filename)[1]
        retval = filename if extension else "{}.{}".format(filename, self.file_extension)
        logger.debug("Original filename: '%s', final filename: '%s'", filename, retval)
        return retval

    def load(self, filename):
        """ Load data from an existing serialized file

        Parameters
        ----------
        filename: str
            The path to the serialized file

        Returns
        -------
        data: varies
            The data in a python object format

        Raises
        ------
        FaceswapError
            If the file cannot be read or its contents cannot be unserialized

        Example
        -------
        >>> serializer = get_serializer('json')
        >>> json_file = '/path/to/json/file.json'
        >>> data = serializer.load(json_file)
        """
        logger.debug("filename: %s", filename)
        try:
            with open(filename, self._read_option) as s_file:
                data = s_file.read()
                logger.debug("stored data type: %s", type(data))
                retval = self.unmarshal(data)
        except IOError as err:
            msg = "Error reading from '{}': {}".format(filename, err.strerror)
            raise FaceswapError(msg) from err
        logger.debug("data type: %s", type(retval))
        return retval

    def marshal(self, data):
        """ Serialize an object

        Parameters
        ----------
        data: varies
            The data that is to be serialized

        Returns
        -------
        data: varies
            The data in the serialized data format

        Raises
        ------
        FaceswapError
            If the serializer specific marshalling raises any exception

        Example
        -------
        >>> serializer = get_serializer('json')
        >>> data = ['foo', 'bar']
        >>> json_data = serializer.marshal(data)
        """
        logger.debug("data type: %s", type(data))
        try:
            retval = self._marshal(data)
        except Exception as err:
            # Any backend failure is re-raised as a single project level error type
            msg = "Error serializing data for type {}: {}".format(type(data), str(err))
            raise FaceswapError(msg) from err
        logger.debug("returned data type: %s", type(retval))
        return retval

    def unmarshal(self, serialized_data):
        """ Unserialize data to its original object type

        Parameters
        ----------
        serialized_data: varies
            Data in serializer format that is to be unmarshalled to its original object

        Returns
        -------
        data: varies
            The data in a python object format

        Raises
        ------
        FaceswapError
            If the serializer specific unmarshalling raises any exception

        Example
        -------
        >>> serializer = get_serializer('json')
        >>> json_data = <json object>
        >>> data = serializer.unmarshal(json_data)
        """
        logger.debug("data type: %s", type(serialized_data))
        try:
            retval = self._unmarshal(serialized_data)
        except Exception as err:
            # Any backend failure is re-raised as a single project level error type
            msg = "Error unserializing data for type {}: {}".format(type(serialized_data),
                                                                    str(err))
            raise FaceswapError(msg) from err
        logger.debug("returned data type: %s", type(retval))
        return retval

    @classmethod
    def _marshal(cls, data):
        """ Override for serializer specific marshalling """
        raise NotImplementedError()

    @classmethod
    def _unmarshal(cls, data):
        """ Override for serializer specific unmarshalling """
        raise NotImplementedError()
class _YAMLSerializer(Serializer):
    """ YAML Serializer: converts python objects to and from UTF-8 encoded YAML bytes """
    def __init__(self):
        super().__init__()
        self._file_extension = "yml"

    @classmethod
    def _marshal(cls, data):
        """ Dump the object as block style YAML and encode the text to UTF-8 bytes """
        yaml_text = yaml.dump(data, default_flow_style=False)
        return yaml_text.encode("utf-8")

    @classmethod
    def _unmarshal(cls, data):
        """ Decode UTF-8 bytes and parse the YAML text with the ``FullLoader`` """
        yaml_text = data.decode("utf-8")
        return yaml.load(yaml_text, Loader=yaml.FullLoader)
class _JSONSerializer(Serializer):
    """ JSON Serializer: converts python objects to and from UTF-8 encoded JSON bytes """
    def __init__(self):
        super().__init__()
        self._file_extension = "json"

    @classmethod
    def _marshal(cls, data):
        """ Dump the object as JSON indented by 2 spaces and encode the text to UTF-8 bytes """
        json_text = json.dumps(data, indent=2)
        return json_text.encode("utf-8")

    @classmethod
    def _unmarshal(cls, data):
        """ Decode UTF-8 bytes and parse the resulting JSON text """
        json_text = data.decode("utf-8")
        return json.loads(json_text)
class _PickleSerializer(Serializer):
    """ Pickle Serializer: converts python objects to and from pickled bytes """
    def __init__(self):
        super().__init__()
        self._file_extension = "pickle"

    @classmethod
    def _marshal(cls, data):
        """ Pickle the object using the default pickle protocol """
        pickled = pickle.dumps(data)
        return pickled

    @classmethod
    def _unmarshal(cls, data):
        """ Unpickle the given bytes back to a python object """
        # NOTE(review): unpickling can execute arbitrary code; only load trusted files
        restored = pickle.loads(data)
        return restored
class _NPYSerializer(Serializer):  # pylint:disable=abstract-method
    """ NPY Serializer: converts objects to and from numpy's ``.npy`` byte format """
    def __init__(self):
        super().__init__()
        self._file_extension = "npy"
        # Kept for backwards compatibility; not read by any method of this class
        self._bytes = BytesIO()

    def _marshal(self, data):
        """ NPY Marshal to bytesIO so standard bytes writer can write out

        Parameters
        ----------
        data: varies
            The object to be saved with :func:`numpy.save`

        Returns
        -------
        bytes
            The contents of the ``.npy`` formatted buffer
        """
        b_handler = BytesIO()
        np.save(b_handler, data)
        # getvalue returns the whole buffer regardless of the current stream position
        return b_handler.getvalue()

    def _unmarshal(self, data):
        """ NPY Unmarshal to bytesIO so we can use numpy loader

        Parameters
        ----------
        data: bytes
            The ``.npy`` formatted bytes to load

        Returns
        -------
        varies
            The loaded :class:`numpy.ndarray`. Object dtype arrays are indexed with an empty
            tuple, which unwraps a zero dimensional array to the python object it contains
        """
        b_handler = BytesIO(data)
        # ``np.save`` pickles non-array objects by default, but ``np.load`` refuses pickled
        # data unless allow_pickle is enabled (default is False since numpy 1.16.3). Without
        # it the object dtype branch below can never be reached.
        # NOTE(review): this unpickles file contents; only load trusted files
        retval = np.load(b_handler, allow_pickle=True)
        del b_handler
        if retval.dtype == "object":
            retval = retval[()]
        return retval
class _CompressedSerializer(Serializer):
    """ A compressed pickle serializer for Faceswap """
    def __init__(self):
        super().__init__()
        self._file_extension = "fsc"
        # Pickling is delegated to a child pickle serializer
        self._child = get_serializer("pickle")

    def _marshal(self, data):
        """ Pickle the data with the child serializer, then zlib compress the result """
        pickled = self._child._marshal(data)  # pylint: disable=protected-access
        return zlib.compress(pickled)

    def _unmarshal(self, data):
        """ zlib decompress the data, then unpickle it with the child serializer """
        pickled = zlib.decompress(data)
        return self._child._unmarshal(pickled)  # pylint: disable=protected-access
def get_serializer(serializer):
    """ Obtain a serializer object

    Parameters
    ----------
    serializer: {'json', 'pickle', 'yaml', 'npy', 'compressed'}
        The required serializer format. The lookup is case insensitive

    Returns
    -------
    serializer: :class:`Serializer`
        A serializer object for handling the requested data format. A JSON serializer is
        returned, with a warning logged, when YAML is requested but PyYAML is not installed or
        when the requested format is not recognized

    Example
    -------
    >>> serializer = get_serializer('json')
    """
    name = serializer.lower()
    if name == "npy":
        retval = _NPYSerializer()
    elif name == "compressed":
        retval = _CompressedSerializer()
    elif name == "pickle":
        retval = _PickleSerializer()
    elif name == "yaml" and yaml is not None:
        retval = _YAMLSerializer()
    else:
        # Everything else resolves to JSON. Only warn when JSON was not what was asked for.
        if name == "yaml":
            logger.warning("You must have PyYAML installed to use YAML as the serializer.\n"
                           "Switching to JSON as the serializer.")
        elif name != "json":
            logger.warning("Unrecognized serializer: '%s'. Returning json serializer",
                           serializer)
        # Always return an instance (never the class) so callers can use it directly
        retval = _JSONSerializer()
    logger.debug(retval)
    return retval
def get_serializer_from_filename(filename):
    """ Obtain a serializer object from a filename

    Parameters
    ----------
    filename: str
        Filename to determine the serializer type from. The extension is matched case
        insensitively

    Returns
    -------
    serializer: :class:`Serializer`
        A serializer object for handling the requested data format. A JSON serializer is
        returned, with a warning logged, when the extension is YAML but PyYAML is not installed
        or when the extension is not recognized

    Example
    -------
    >>> filename = '/path/to/json/file.json'
    >>> serializer = get_serializer_from_filename(filename)
    """
    logger.debug("filename: '%s'", filename)
    extension = os.path.splitext(filename)[1].lower()
    logger.debug("extension: '%s'", extension)

    if extension == ".json":
        retval = _JSONSerializer()
    elif extension in (".p", ".pickle"):
        # ".pickle" is the extension the pickle serializer itself appends when saving, so it
        # must be recognized here alongside the short ".p" form
        retval = _PickleSerializer()
    elif extension == ".npy":
        retval = _NPYSerializer()
    elif extension == ".fsc":
        retval = _CompressedSerializer()
    elif extension in (".yaml", ".yml") and yaml is not None:
        retval = _YAMLSerializer()
    else:
        # Everything else falls back to JSON with a warning explaining why
        if extension in (".yaml", ".yml"):
            logger.warning("You must have PyYAML installed to use YAML as the serializer.\n"
                           "Switching to JSON as the serializer.")
        else:
            logger.warning("Unrecognized extension: '%s'. Returning json serializer",
                           extension)
        retval = _JSONSerializer()
    logger.debug(retval)
    return retval