1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/plugin_loader.py
torzdf 7f53911453
Logging (#541)
* Convert prints to logger. Further logging improvements. Tidy up

* Fix system verbosity. Allow SystemExit

* Fix reload extract bug

* Child Traceback handling

* Safer shutdown procedure

* Add shutdown event to queue manager

* landmarks_as_xy > property. GUI notes + linting. Aligner bugfix

* fix FaceFilter. Enable nFilter when no Filter is supplied

* Fix blurry face filter

* Continue on IO error. Better error handling

* Explicitly print stack trace to crash log

* Windows Multiprocessing bugfix

* Add git info and conda version to crash log

* Windows/Anaconda mp bugfix

* Logging fixes for training
2018-12-04 13:31:49 +00:00

75 lines
2.7 KiB
Python

#!/usr/bin/env python3
""" Plugin loader for extract, training and model tasks """
import logging
import os
from importlib import import_module
# Module-level logger used by PluginLoader; the pylint directive suppresses the
# invalid-name check for this lowercase module-level name.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class PluginLoader():
    """ Plugin loader for extract, training and model tasks

    Every loader resolves a plugin name to a module below the ``plugins`` package and
    returns the class in that module named after the title-cased plugin type
    (``Detect``, ``Align``, ``Convert``, ``Model`` or ``Trainer``).
    """
    @staticmethod
    def get_detector(name):
        """ Return requested detector plugin

        Parameters
        ----------
        name: str
            Detector module name (``cv2_dnn``) or the hyphenated form returned by
            :func:`get_available_extractors` (``cv2-dnn``)

        Returns
        -------
        class
            The ``Detect`` class from ``plugins.extract.detect.<name>``
        """
        return PluginLoader._import("extract.detect", name)

    @staticmethod
    def get_aligner(name):
        """ Return requested aligner plugin

        Parameters
        ----------
        name: str
            Aligner module name, or the hyphenated form returned by
            :func:`get_available_extractors`

        Returns
        -------
        class
            The ``Align`` class from ``plugins.extract.align.<name>``
        """
        return PluginLoader._import("extract.align", name)

    @staticmethod
    def get_converter(name):
        """ Return requested converter plugin

        Returns the ``Convert`` class from ``plugins.convert.Convert_<name>``
        """
        return PluginLoader._import("Convert", "Convert_{0}".format(name))

    @staticmethod
    def get_model(name):
        """ Return requested model plugin

        Returns the ``Model`` class from ``plugins.model.Model_<name>``
        """
        return PluginLoader._import("Model", "Model_{0}".format(name))

    @staticmethod
    def get_trainer(name):
        """ Return requested trainer plugin

        Returns the ``Trainer`` class from ``plugins.model.Model_<name>`` (the trainer
        lives in the same module as its model)
        """
        return PluginLoader._import("Trainer", "Model_{0}".format(name))

    @staticmethod
    def _import(attr, name):
        """ Import the plugin's module and return its plugin class

        Parameters
        ----------
        attr: str
            Plugin category, e.g. ``"extract.detect"``, ``"Convert"``, ``"Model"`` or
            ``"Trainer"``. Its last dotted component, title-cased, is the name of the
            class fetched from the module.
        name: str
            Module name within the category. Hyphens are mapped to underscores.

        Returns
        -------
        class
            ``plugins.<category>.<name>.<Title>``
        """
        # Python module names cannot contain "-". get_available_extractors reports
        # plugins with "_" swapped for "-", so map back to allow a direct round trip.
        name = name.replace("-", "_")
        ttl = attr.split(".")[-1].title()
        logger.info("Loading %s from %s plugin...", ttl, name.title())
        # Trainers are looked up in the model package rather than a package of their own
        attr = "model" if attr == "Trainer" else attr.lower()
        mod = ".".join(("plugins", attr, name))
        module = import_module(mod)
        return getattr(module, ttl)

    @staticmethod
    def get_available_models():
        """ Return a tuple of available model names, sorted alphabetically

        A model is any sub-folder of the ``model`` folder next to this file whose name
        starts with ``model_`` (case-insensitive). The prefix is stripped from the
        returned names. An empty tuple is returned if the folder does not exist.
        """
        modelpath = os.path.join(os.path.dirname(__file__), "model")
        # os.walk yields nothing for a missing folder. Default to "no sub-folders"
        # rather than leaking StopIteration out of next().
        _, subfolders, _ = next(os.walk(modelpath), (modelpath, [], []))
        prefix = "model_"
        # Sort so the result (and the fallback in get_default_model) does not depend
        # on file-system listing order
        return tuple(sorted(folder[len(prefix):]
                            for folder in subfolders
                            if folder.lower().startswith(prefix)))

    @staticmethod
    def get_available_extractors(extractor_type):
        """ Return a sorted list of available extractor plugins of the given type

        Parameters
        ----------
        extractor_type: str
            Sub-folder of the ``extract`` folder to scan, e.g. ``"detect"`` or ``"align"``

        Returns
        -------
        list
            Plugin names with the ``.py`` suffix removed and ``_`` replaced by ``-``.
            Private modules (leading ``_``) and ``manual.py`` are excluded.
        """
        extractpath = os.path.join(os.path.dirname(__file__),
                                   "extract",
                                   extractor_type)
        suffix = ".py"
        # Slice off only the trailing suffix instead of every ".py" occurrence
        extractors = sorted(item.name[:-len(suffix)].replace("_", "-")
                            for item in os.scandir(extractpath)
                            if not item.name.startswith("_")
                            and item.name.endswith(suffix)
                            and item.name != "manual.py")
        return extractors

    @staticmethod
    def get_default_model():
        """ Return the default model

        ``Original`` if it is installed, otherwise the alphabetically first available
        model.

        Raises
        ------
        IndexError
            If no model plugins are available
        """
        models = PluginLoader.get_available_models()
        return 'Original' if 'Original' in models else models[0]