mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00
faceswap/lib/utils.py
torzdf d8557c1970
Faceswap 2.0 (#1045)
* Core Updates
    - Remove lib.utils.keras_backend_quiet and replace with get_backend() where relevant
    - Document lib.gpu_stats and lib.sys_info
    - Remove call to GPUStats.is_plaidml from convert and replace with get_backend()
    - lib.gui.menu - typofix

* Update Dependencies
Bump Tensorflow Version Check

* Port extraction to tf2

* Add custom import finder for loading Keras or tf.keras depending on backend

* Add `tensorflow` to KerasFinder search path

* Basic TF2 training running

* model.initializers - docstring fix

* Fix and pass tests for tf2

* Replace Keras backend tests with faceswap backend tests

* Initial optimizers update

* Monkey patch tf.keras optimizer

* Remove custom Adam Optimizers and Memory Saving Gradients

* Remove multi-gpu option. Add Distribution to cli

* plugins.train.model._base: Add Mirror, Central and Default distribution strategies

* Update tensorboard kwargs for tf2

* Penalized Loss - Fix for TF2 and AMD

* Fix syntax for tf2.1

* requirements typo fix

* Explicit None for clipnorm if using a distribution strategy

* Fix penalized loss for distribution strategies

* Update Dlight

* typo fix

* Pin to TF2.2

* setup.py - Install tensorflow from pip if not available in Conda

* Add reduction options and set default for mirrored distribution strategy

* Explicitly use default strategy rather than nullcontext

* lib.model.backup_restore documentation

* Remove mirrored strategy reduction method and default based on OS

* Initial restructure - training

* Remove PingPong
Start model.base refactor

* Model saving and resuming enabled

* More tidying up of model.base

* Enable backup and snapshotting

* Re-enable state file
Remove loss names from state file
Fix print loss function
Set snapshot iterations correctly

* Revert original model to Keras Model structure rather than custom layer
Output full model and sub model summary
Change NNBlocks to callables rather than custom keras layers

* Apply custom Conv2D layer

* Finalize NNBlock restructure
Update Dfaker blocks

* Fix reloading model under a different distribution strategy

* Pass command line arguments through to trainer

* Remove training_opts from model and reference params directly

* Tidy up model __init__

* Re-enable tensorboard logging
Suppress "Model Not Compiled" warning

* Fix timelapse

* lib.model.nnblocks - Bugfix residual block
Port dfaker
bugfix original

* dfl-h128 ported

* DFL SAE ported

* IAE Ported

* dlight ported

* port lightweight

* realface ported

* unbalanced ported

* villain ported

* lib.cli.args - Update Batchsize + move allow_growth to config

* Remove output shape definition
Get image sizes per side rather than globally

* Strip mask input from encoder

* Fix learn mask and output learned mask to preview

* Trigger Allow Growth prior to setting strategy

* Fix GUI Graphing

* GUI - Display batchsize correctly + fix training graphs

* Fix penalized loss

* Enable mixed precision training

* Update analysis displayed batch to match input

* Penalized Loss - Multi-GPU Fix

* Fix all losses for TF2

* Fix Reflect Padding

* Allow different input size for each side of the model

* Fix conv-aware initialization on reload

* Switch allow_growth order

* Move mixed_precision to cli

* Remove distrubution strategies

* Compile penalized loss sub-function into LossContainer

* Bump default save interval to 250
Generate preview on first iteration but don't save
Fix iterations to start at 1 instead of 0
Remove training deprecation warnings
Bump some scripts.train loglevels

* Add ability to refresh preview on demand on pop-up window

* Enable refresh of training preview from GUI

* Fix Convert
Debug logging in Initializers

* Fix Preview Tool

* Update Legacy TF1 weights to TF2
Catch stats error on loading stats with missing logs

* lib.gui.popup_configure - Make more responsive + document

* Multiple Outputs supported in trainer
Original Model - Mask output bugfix

* Make universal inference model for convert
Remove scaling from penalized mask loss (now handled at input to y_true)

* Fix inference model to work properly with all models

* Fix multi-scale output for convert

* Fix clipnorm issue with distribution strategies
Edit error message on OOM

* Update plaidml losses

* Add missing file

* Disable gmsd loss for plaidnl

* PlaidML - Basic training working

* clipnorm rewriting for mixed-precision

* Inference model creation bugfixes

* Remove debug code

* Bugfix: Default clipnorm to 1.0

* Remove all mask inputs from training code

* Remove mask inputs from convert

* GUI - Analysis Tab - Docstrings

* Fix rate in totals row

* lib.gui - Only update display pages if they have focus

* Save the model on first iteration

* plaidml - Fix SSIM loss with penalized loss

* tools.alignments - Remove manual and fix jobs

* GUI - Remove case formatting on help text

* gui MultiSelect custom widget - Set default values on init

* vgg_face2 - Move to plugins.extract.recognition and use plugins._base base class
cli - Add global GPU Exclude Option
tools.sort - Use global GPU Exlude option for backend
lib.model.session - Exclude all GPUs when running in CPU mode
lib.cli.launcher - Set backend to CPU mode when all GPUs excluded

* Cascade excluded devices to GPU Stats

* Explicit GPU selection for Train and Convert

* Reduce Tensorflow Min GPU Multiprocessor Count to 4

* remove compat.v1 code from extract

* Force TF to skip mixed precision compatibility check if GPUs have been filtered

* Add notes to config for non-working AMD losses

* Rasie error if forcing extract to CPU mode

* Fix loading of legace dfl-sae weights + dfl-sae typo fix

* Remove unused requirements
Update sphinx requirements
Fix broken rst file locations

* docs: lib.gui.display

* clipnorm amd condition check

* documentation - gui.display_analysis

* Documentation - gui.popup_configure

* Documentation - lib.logger

* Documentation - lib.model.initializers

* Documentation - lib.model.layers

* Documentation - lib.model.losses

* Documentation - lib.model.nn_blocks

* Documetation - lib.model.normalization

* Documentation - lib.model.session

* Documentation - lib.plaidml_stats

* Documentation: lib.training_data

* Documentation: lib.utils

* Documentation: plugins.train.model._base

* GUI Stats: prevent stats from using GPU

* Documentation - Original Model

* Documentation: plugins.model.trainer._base

* linting

* unit tests: initializers + losses

* unit tests: nn_blocks

* bugfix - Exclude gpu devices in train, not include

* Enable Exclude-Gpus in Extract

* Enable exclude gpus in tools

* Disallow multiple plugin types in a single model folder

* Automatically add exclude_gpus argument in for cpu backends

* Cpu backend fixes

* Relax optimizer test threshold

* Default Train settings - Set mask to Extended

* Update Extractor cli help text
Update to Python 3.8

* Fix FAN to run on CPU

* lib.plaidml_tools - typofix

* Linux installer - check for curl

* linux installer - typo fix
2020-08-12 10:36:41 +01:00

656 lines
24 KiB
Python

#!/usr/bin/env python3
""" Utilities available across all scripts """

import importlib.abc
import importlib.util
import json
import logging
import os
import sys
import urllib.error
import urllib.request
import warnings
import zipfile

from pathlib import Path
from re import finditer
from multiprocessing import current_process
from socket import timeout as socket_timeout, error as socket_error

from tqdm import tqdm

# Global variables
_image_extensions = [  # pylint:disable=invalid-name
    ".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff"]
_video_extensions = [  # pylint:disable=invalid-name
    ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
    ".ts", ".vob"]


class _Backend():  # pylint:disable=too-few-public-methods
    """ Return the backend from config/.faceswap or from the `FACESWAP_BACKEND` Environment
    Variable.

    If file doesn't exist and a variable hasn't been set, create the config file. """
    def __init__(self):
        self._backends = {"1": "amd", "2": "cpu", "3": "nvidia"}
        self._config_file = self._get_config_file()
        self.backend = self._get_backend()

    @classmethod
    def _get_config_file(cls):
        """ Obtain the location of the main Faceswap configuration file.

        Returns
        -------
        str
            The path to the Faceswap configuration file
        """
        pypath = os.path.dirname(os.path.realpath(sys.argv[0]))
        config_file = os.path.join(pypath, "config", ".faceswap")
        return config_file

    def _get_backend(self):
        """ Return the backend from either the `FACESWAP_BACKEND` Environment Variable or from
        the :file:`config/.faceswap` configuration file. If neither of these exists, prompt the
        user to select a backend.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        # Check if environment variable is set, if so use that
        if "FACESWAP_BACKEND" in os.environ:
            fs_backend = os.environ["FACESWAP_BACKEND"].lower()
            print("Setting Faceswap backend from environment variable to "
                  "{}".format(fs_backend.upper()))
            return fs_backend
        # Intercept for sphinx docs build
        if sys.argv[0].endswith("sphinx-build"):
            return "nvidia"
        if not os.path.isfile(self._config_file):
            self._configure_backend()
        while True:
            try:
                with open(self._config_file, "r") as cnf:
                    config = json.load(cnf)
                break
            except json.decoder.JSONDecodeError:
                # Config file is corrupt: ask the user again and rewrite it
                self._configure_backend()
                continue
        fs_backend = config.get("backend", None)
        if fs_backend is None or fs_backend.lower() not in self._backends.values():
            fs_backend = self._configure_backend()
        if current_process().name == "MainProcess":
            print("Setting Faceswap backend to {}".format(fs_backend.upper()))
        return fs_backend.lower()

    def _configure_backend(self):
        """ Get user input to select the backend that Faceswap should use.

        Returns
        -------
        str
            The backend configuration in use by Faceswap
        """
        print("First time configuration. Please select the required backend")
        while True:
            selection = input("1: AMD, 2: CPU, 3: NVIDIA: ")
            if selection not in ("1", "2", "3"):
                print("'{}' is not a valid selection. Please try again".format(selection))
                continue
            break
        fs_backend = self._backends[selection].lower()
        config = {"backend": fs_backend}
        with open(self._config_file, "w") as cnf:
            json.dump(config, cnf)
        print("Faceswap config written to: {}".format(self._config_file))
        return fs_backend


_FS_BACKEND = _Backend().backend


def get_backend():
    """ Get the backend that Faceswap is currently configured to use.

    Returns
    -------
    str
        The backend configuration in use by Faceswap
    """
    return _FS_BACKEND


def set_backend(backend):
    """ Override the configured backend with the given backend.

    Parameters
    ----------
    backend: ["amd", "cpu", "nvidia"]
        The backend to set faceswap to
    """
    global _FS_BACKEND  # pylint:disable=global-statement
    _FS_BACKEND = backend.lower()


def get_folder(path, make_folder=True):
    """ Return a path to a folder, creating it if it doesn't exist

    Parameters
    ----------
    path: str
        The path to the folder to obtain
    make_folder: bool, optional
        ``True`` if the folder should be created if it does not already exist, ``False`` if the
        folder should not be created

    Returns
    -------
    :class:`pathlib.Path` or `None`
        The path to the requested folder. If `make_folder` is set to ``False`` and the requested
        path does not exist, then ``None`` is returned
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Requested path: '%s'", path)
    output_dir = Path(path)
    if not make_folder and not output_dir.exists():
        logger.debug("%s does not exist", path)
        return None
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.debug("Returning: '%s'", output_dir)
    return output_dir


def get_image_paths(directory):
    """ Obtain a list of full paths to the images that reside within a folder, sorted by
    filename. The folder is created if it does not already exist.

    Parameters
    ----------
    directory: str
        The folder that contains the images to be returned

    Returns
    -------
    list
        The list of full paths to the images contained within the given folder
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    image_extensions = _image_extensions
    dir_contents = list()

    if not os.path.exists(directory):
        logger.debug("Creating folder: '%s'", directory)
        directory = get_folder(directory)

    dir_scanned = sorted(os.scandir(directory), key=lambda x: x.name)
    logger.debug("Scanned Folder contains %s files", len(dir_scanned))
    logger.trace("Scanned Folder Contents: %s", dir_scanned)

    for chkfile in dir_scanned:
        # Match on extension only, ignoring case
        if any(chkfile.name.lower().endswith(ext) for ext in image_extensions):
            logger.trace("Adding '%s' to image list", chkfile.path)
            dir_contents.append(chkfile.path)

    logger.debug("Returning %s images", len(dir_contents))
    return dir_contents


def convert_to_secs(*args):
    """ Convert a time to seconds.

    Parameters
    ----------
    args: tuple
        1, 2 or 3 numbers. If 1 number is supplied, then (`seconds`) is implied. If 2 numbers
        are supplied, then (`minutes`, `seconds`) is implied. If 3 numbers are supplied then
        (`hours`, `minutes`, `seconds`) is implied. Any other number of arguments returns ``0.0``.

    Returns
    -------
    float
        The given time converted to seconds
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("from time: %s", args)
    retval = 0.0
    if len(args) == 1:
        retval = float(args[0])
    elif len(args) == 2:
        retval = 60 * float(args[0]) + float(args[1])
    elif len(args) == 3:
        retval = 3600 * float(args[0]) + 60 * float(args[1]) + float(args[2])
    logger.debug("to secs: %s", retval)
    return retval


def full_path_split(path):
    """ Split a full path to a location into all of its separate components.

    Parameters
    ----------
    path: str
        The full path to be split

    Returns
    -------
    list
        The full path split into a separate item for each part. For an absolute path the first
        item is the root

    Example
    -------
    >>> path = "/foo/baz/bar"
    >>> full_path_split(path)
    ["/", "foo", "baz", "bar"]
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    allparts = list()
    while True:
        parts = os.path.split(path)
        if parts[0] == path:  # sentinel for absolute paths
            allparts.insert(0, parts[0])
            break
        if parts[1] == path:  # sentinel for relative paths
            allparts.insert(0, parts[1])
            break
        path = parts[0]
        allparts.insert(0, parts[1])
    logger.trace("path: %s, allparts: %s", path, allparts)
    return allparts


def set_system_verbosity(log_level):
    """ Set the verbosity level of tensorflow and suppress future, deprecation and user warnings
    from any modules

    Parameters
    ----------
    log_level: str
        The requested Faceswap log level

    References
    ----------
    https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
    Can be set to:
    0: all logs shown. 1: filter out INFO logs. 2: filter out WARNING logs. 3: filter out ERROR
    logs.
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    from lib.logger import get_loglevel  # pylint:disable=import-outside-toplevel
    numeric_level = get_loglevel(log_level)
    # Show all tensorflow logs only when the requested level is 15 or lower
    log_level = "2" if numeric_level > 15 else "0"
    logger.debug("System Verbosity level: %s", log_level)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = log_level
    if log_level != '0':
        for warncat in (FutureWarning, DeprecationWarning, UserWarning):
            warnings.simplefilter(action='ignore', category=warncat)


def deprecation_warning(function, additional_info=None):
    """ Log at warning level that a function will be removed in a future update.

    Parameters
    ----------
    function: str
        The function that will be deprecated.
    additional_info: str, optional
        Any additional information to display with the deprecation message. Default: ``None``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("func_name: %s, additional_info: %s", function, additional_info)
    msg = "{} has been deprecated and will be removed from a future update.".format(function)
    if additional_info is not None:
        msg += " {}".format(additional_info)
    logger.warning(msg)


def camel_case_split(identifier):
    """ Split a camel case name

    Parameters
    ----------
    identifier: str
        The camel case text to be split

    Returns
    -------
    list
        A list of the given identifier split into its constituent parts

    References
    ----------
    https://stackoverflow.com/questions/29916065
    """
    matches = finditer(
        ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)",
        identifier)
    return [m.group(0) for m in matches]


def safe_shutdown(got_error=False):
    """ Close all tracked queues and threads in event of crash or on shut down.

    Parameters
    ----------
    got_error: bool, optional
        ``True`` if this function is being called as the result of a raised error, otherwise
        ``False``. Default: ``False``
    """
    logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
    logger.debug("Safely shutting down")
    from lib.queue_manager import queue_manager  # pylint:disable=import-outside-toplevel
    queue_manager.terminate_queues()
    logger.debug("Cleanup complete. Shutting down queue manager and exiting")
    sys.exit(1 if got_error else 0)


class FaceswapError(Exception):
    """ Faceswap Error for handling specific errors with useful information.

    Raises
    ------
    FaceswapError
        on a captured error
    """
    pass  # pylint:disable=unnecessary-pass


class GetModel():  # pylint:disable=too-few-public-methods
    """ Check for models in their cache path.

    If available, return the path, if not available, get, unzip and install model

    Parameters
    ----------
    model_filename: str or list
        The name of the model to be loaded (see notes below)
    cache_dir: str
        The model cache folder of the current plugin calling this class. IE: The folder that holds
        the model to be loaded.
    git_model_id: int
        The second digit in the github tag that identifies this model. See
        https://github.com/deepfakes-models/faceswap-models for more information

    Notes
    -----
    Models must have a certain naming convention: `<model_name>_v<version_number>.<extension>`
    (eg: `s3fd_v1.pb`).

    Multiple models can exist within the model_filename. They should be passed as a list and follow
    the same naming convention as above. Any differences in filename should occur AFTER the version
    number: `<model_name>_v<version_number><differentiating_information>.<extension>` (eg:
    `["mtcnn_det_v1.1.py", "mtcnn_det_v1.2.py", "mtcnn_det_v1.3.py"]`,
    `["resnet_ssd_v1.caffemodel", "resnet_ssd_v1.prototext"]`)
    """

    def __init__(self, model_filename, cache_dir, git_model_id):
        self.logger = logging.getLogger(__name__)
        if not isinstance(model_filename, list):
            model_filename = [model_filename]
        self._model_filename = model_filename
        self._cache_dir = cache_dir
        self._git_model_id = git_model_id
        self._url_base = "https://github.com/deepfakes-models/faceswap-models/releases/download"
        self._chunk_size = 1024  # Chunk size for downloading and unzipping
        self._retries = 6
        self._get()

    @property
    def _model_full_name(self):
        """ str: The full model name from the filename(s). """
        common_prefix = os.path.commonprefix(self._model_filename)
        retval = os.path.splitext(common_prefix)[0]
        self.logger.trace(retval)
        return retval

    @property
    def _model_name(self):
        """ str: The model name from the model's full name. """
        retval = self._model_full_name[:self._model_full_name.rfind("_")]
        self.logger.trace(retval)
        return retval

    @property
    def _model_version(self):
        """ int: The model's version number from the model full name. """
        # Skip the "_v" that precedes the version number
        retval = int(self._model_full_name[self._model_full_name.rfind("_") + 2:])
        self.logger.trace(retval)
        return retval

    @property
    def model_path(self):
        """ str or list: The model path(s) in the cache folder. A single path is returned as a
        string, multiple paths as a list. """
        retval = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
        retval = retval[0] if len(retval) == 1 else retval
        self.logger.trace(retval)
        return retval

    @property
    def _model_zip_path(self):
        """ str: The full path to downloaded zip file. """
        retval = os.path.join(self._cache_dir, "{}.zip".format(self._model_full_name))
        self.logger.trace(retval)
        return retval

    @property
    def _model_exists(self):
        """ bool: ``True`` if the model exists in the cache folder otherwise ``False``. """
        if isinstance(self.model_path, list):
            retval = all(os.path.exists(pth) for pth in self.model_path)
        else:
            retval = os.path.exists(self.model_path)
        self.logger.trace(retval)
        return retval

    @property
    def _plugin_section(self):
        """ str: The plugin section from the config_dir """
        path = os.path.normpath(self._cache_dir)
        split = path.split(os.sep)
        retval = split[split.index("plugins") + 1]
        self.logger.trace(retval)
        return retval

    @property
    def _url_section(self):
        """ int: The section ID in github for this plugin type. """
        sections = dict(extract=1, train=2, convert=3)
        retval = sections[self._plugin_section]
        self.logger.trace(retval)
        return retval

    @property
    def _url_download(self):
        """ str: Base download URL for models. """
        tag = "v{}.{}.{}".format(self._url_section, self._git_model_id, self._model_version)
        retval = "{}/{}/{}.zip".format(self._url_base, tag, self._model_full_name)
        self.logger.trace("Download url: %s", retval)
        return retval

    @property
    def _url_partial_size(self):
        """ int: How many bytes have already been downloaded. """
        zip_file = self._model_zip_path
        retval = os.path.getsize(zip_file) if os.path.exists(zip_file) else 0
        self.logger.trace(retval)
        return retval

    def _get(self):
        """ Check the model exists, if not, download the model, unzip it and place it in the
        model's cache folder. """
        if self._model_exists:
            self.logger.debug("Model exists: %s", self.model_path)
            return
        self._download_model()
        self._unzip_model()
        os.remove(self._model_zip_path)

    def _download_model(self):
        """ Download the model zip from github to the cache folder. """
        self.logger.info("Downloading model: '%s' from: %s", self._model_name, self._url_download)
        for attempt in range(self._retries):
            try:
                # Resume a partial download by requesting only the remaining bytes
                downloaded_size = self._url_partial_size
                req = urllib.request.Request(self._url_download)
                if downloaded_size != 0:
                    req.add_header("Range", "bytes={}-".format(downloaded_size))
                response = urllib.request.urlopen(req, timeout=10)
                self.logger.debug("header info: {%s}", response.info())
                self.logger.debug("Return Code: %s", response.getcode())
                self._write_zipfile(response, downloaded_size)
                break
            except (socket_error, socket_timeout,
                    urllib.error.HTTPError, urllib.error.URLError) as err:
                if attempt + 1 < self._retries:
                    self.logger.warning("Error downloading model (%s). Retrying %s of %s...",
                                        str(err), attempt + 2, self._retries)
                else:
                    self.logger.error("Failed to download model. Exiting. (Error: '%s', URL: "
                                      "'%s')", str(err), self._url_download)
                    self.logger.info("You can try running again to resume the download.")
                    self.logger.info("Alternatively, you can manually download the model from: %s "
                                     "and unzip the contents to: %s",
                                     self._url_download, self._cache_dir)
                    sys.exit(1)

    def _write_zipfile(self, response, downloaded_size):
        """ Write the model zip file to disk.

        Parameters
        ----------
        response: :class:`http.client.HTTPResponse`
            The response returned by :func:`urllib.request.urlopen` for the model download
        downloaded_size: int
            The number of bytes downloaded so far
        """
        length = int(response.getheader("content-length")) + downloaded_size
        if length == downloaded_size:
            self.logger.info("Zip already exists. Skipping download")
            return
        # Append to the existing file when resuming a partial download
        write_type = "wb" if downloaded_size == 0 else "ab"
        with open(self._model_zip_path, write_type) as out_file:
            pbar = tqdm(desc="Downloading",
                        unit="B",
                        total=length,
                        unit_scale=True,
                        unit_divisor=1024)
            if downloaded_size != 0:
                pbar.update(downloaded_size)
            while True:
                buffer = response.read(self._chunk_size)
                if not buffer:
                    break
                pbar.update(len(buffer))
                out_file.write(buffer)
            pbar.close()

    def _unzip_model(self):
        """ Unzip the model file to the cache folder """
        self.logger.info("Extracting: '%s'", self._model_name)
        try:
            zip_file = zipfile.ZipFile(self._model_zip_path, "r")
            self._write_model(zip_file)
        except Exception as err:  # pylint:disable=broad-except
            self.logger.error("Unable to extract model file: %s", str(err))
            sys.exit(1)

    def _write_model(self, zip_file):
        """ Extract files from zip file and write, with progress bar.

        Parameters
        ----------
        zip_file: :class:`zipfile.ZipFile`
            The opened zip file of the downloaded model. It is closed once extraction completes
        """
        length = sum(f.file_size for f in zip_file.infolist())
        fnames = zip_file.namelist()
        self.logger.debug("Zipfile: Filenames: %s, Total Size: %s", fnames, length)
        pbar = tqdm(desc="Decompressing",
                    unit="B",
                    total=length,
                    unit_scale=True,
                    unit_divisor=1024)
        for fname in fnames:
            out_fname = os.path.join(self._cache_dir, fname)
            self.logger.debug("Extracting from: '%s' to '%s'", self._model_zip_path, out_fname)
            zipped = zip_file.open(fname)
            with open(out_fname, "wb") as out_file:
                while True:
                    buffer = zipped.read(self._chunk_size)
                    if not buffer:
                        break
                    pbar.update(len(buffer))
                    out_file.write(buffer)
        zip_file.close()
        pbar.close()


class KerasFinder(importlib.abc.MetaPathFinder):
    """ Importlib Abstract Base Class for intercepting the import of Keras and returning either
    Keras (AMD backend) or tensorflow.keras (any other backend).

    The Importlib documentation is sparse at best, and real world examples are pretty much
    non-existent. Coupled with this, the import ``tensorflow.keras`` does not resolve so we need
    to split out to the actual location of Keras within ``tensorflow_core``. This method works, but
    it relies on hard coded paths, and is likely to not be the most robust.

    A custom loader is not used, as we can use the standard loader once we have returned the
    correct spec.
    """
    def __init__(self):
        self._logger = logging.getLogger(__name__)
        self._backend = get_backend()
        self._tf_keras_locations = [["tensorflow_core", "python", "keras", "api", "_v2"],
                                    ["tensorflow", "python", "keras", "api", "_v2"]]

    def find_spec(self, fullname, path, target=None):  # pylint:disable=unused-argument
        """ Obtain the spec for either keras or tensorflow.keras depending on the backend in use.

        If keras is not passed in as part of the :attr:`fullname` or the path is not ``None``
        (i.e this is a dependency import) then this returns ``None`` to use the standard import
        library.

        Parameters
        ----------
        fullname: str
            The absolute name of the module to be imported
        path: list or ``None``
            The search path for the module. ``None`` for a top level import
        target: module object, optional
            Inherited from parent but unused

        Returns
        -------
        :class:`importlib.machinery.ModuleSpec` or ``None``
            The spec for the Keras module to be imported, or ``None`` to fall back to the
            standard import machinery
        """
        prefix = fullname.split(".")[0]
        suffix = fullname.split(".")[-1]
        if prefix != "keras" or path is not None:
            return None
        self._logger.debug("Importing '%s' as keras for backend: '%s'",
                           "keras" if self._backend == "amd" else "tf.keras", self._backend)
        path = sys.path if path is None else path
        for entry in path:
            locations = ([os.path.join(entry, *location)
                          for location in self._tf_keras_locations]
                         if self._backend != "amd" else [entry])
            for location in locations:
                self._logger.debug("Scanning: '%s' for '%s'", location, suffix)
                if os.path.isdir(os.path.join(location, suffix)):
                    # Package: point at its __init__.py and expose the folder for submodules
                    filename = os.path.join(location, suffix, "__init__.py")
                    submodule_locations = [os.path.join(location, suffix)]
                else:
                    filename = os.path.join(location, suffix + ".py")
                    submodule_locations = None
                if not os.path.exists(filename):
                    continue
                retval = importlib.util.spec_from_file_location(
                    fullname,
                    filename,
                    submodule_search_locations=submodule_locations)
                self._logger.debug("Found spec: %s", retval)
                return retval
        self._logger.debug("Spec not found for '%s'. Falling back to default import", fullname)
        return None
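

The `KerasFinder` docstring notes that real-world `importlib` finder examples are hard to come by. The following is a minimal, self-contained sketch of the same technique, not Faceswap's implementation. The hypothetical `AliasFinder` redirects the import of a made-up module name, `fs_colorsys_alias`, to the source file of the standard library's `colorsys` module. The class name, the alias and the choice of `colorsys` are all assumptions made for illustration.

```python
import importlib.abc
import importlib.util
import sys


class AliasFinder(importlib.abc.MetaPathFinder):
    """ Hypothetical finder: load ``alias`` from the source file of ``real_name``. """

    def __init__(self, alias, real_name):
        self._alias = alias
        self._real_name = real_name

    def find_spec(self, fullname, path, target=None):  # pylint:disable=unused-argument
        # Only intercept the top level alias. Returning None hands every other
        # import back to the standard finders
        if fullname != self._alias or path is not None:
            return None
        real_spec = importlib.util.find_spec(self._real_name)
        if real_spec is None or not real_spec.has_location:
            return None
        # Build a spec under the alias name. The standard loader executes the file
        return importlib.util.spec_from_file_location(fullname, real_spec.origin)


# Finders are consulted in order, so place this one ahead of the defaults
sys.meta_path.insert(0, AliasFinder("fs_colorsys_alias", "colorsys"))

import fs_colorsys_alias  # noqa pylint:disable=wrong-import-position,import-error

print(fs_colorsys_alias.__name__)                   # fs_colorsys_alias
print(fs_colorsys_alias.rgb_to_hsv(1.0, 0.0, 0.0))  # (0.0, 1.0, 1.0)
```

As with `KerasFinder`, the sketch returns ``None`` for anything it does not handle, so the default import machinery stays in charge of all other modules, and no custom loader is needed once the spec has been returned.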