1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/plugin_loader.py
torzdf ca63242996
Extraction - Speed improvements (#522) (#523)
* Extraction - Speed improvements (#522)

* Initial Plugin restructure

* Detectors to plugins. Detector speed improvements

* Re-implement dlib aligner, remove models, FAN to TF. Parallel processing

* Update manual, update convert, implement parallel/serial switching

* linting + fix cuda check (setup.py). requirements update keras 2.2.4

* Add extract size option. Fix dlib hog init

* GUI: Increase tooltip width

* Update alignment tool to support new DetectedFace

* Add skip existing faces option

* Fix sort tool to new plugin structure

* remove old align plugin

* fix convert -skip faces bug

* Fix convert skipping no faces frames

* Convert - draw onto transparent layer

* Fix blur threshold bug

* fix skip_faces convert bug

* Fix training
2018-10-27 10:12:08 +01:00

72 lines
2.6 KiB
Python

#!/usr/bin/env python3
""" Plugin loader for extract, training and model tasks """
import os
from importlib import import_module
class PluginLoader():
    """ Plugin loader for extract, training and model tasks.

    All methods are static. Plugins are imported on demand from the
    ``plugins`` package, and the lists of available plugins are built by
    scanning the folders that sit next to this file.
    """

    @staticmethod
    def get_detector(name):
        """ Return requested detector plugin.

        ``name`` is the detector module name. Hyphens are mapped to
        underscores so that the display names produced by
        :func:`get_available_extractors` can be passed straight back in.
        """
        return PluginLoader._import("extract.detect", name.replace("-", "_"))

    @staticmethod
    def get_aligner(name):
        """ Return requested aligner plugin.

        ``name`` is the aligner module name. Hyphens are mapped to
        underscores so that the display names produced by
        :func:`get_available_extractors` can be passed straight back in.
        """
        return PluginLoader._import("extract.align", name.replace("-", "_"))

    @staticmethod
    def get_converter(name):
        """ Return requested converter plugin (``Convert`` from
        ``plugins.convert.Convert_<name>``) """
        return PluginLoader._import("Convert", "Convert_{0}".format(name))

    @staticmethod
    def get_model(name):
        """ Return requested model plugin (``Model`` from
        ``plugins.model.Model_<name>``) """
        return PluginLoader._import("Model", "Model_{0}".format(name))

    @staticmethod
    def get_trainer(name):
        """ Return requested trainer plugin (``Trainer`` from
        ``plugins.model.Model_<name>``) """
        return PluginLoader._import("Trainer", "Model_{0}".format(name))

    @staticmethod
    def _import(attr, name):
        """ Import the plugin's module and return the plugin object from it.

        ``attr`` is the dotted plugin category (e.g. ``"extract.detect"`` or
        ``"Model"``). Its last component, title-cased, is the attribute that
        is fetched from the imported module. ``name`` is the module name
        inside that category.

        Raises whatever :func:`importlib.import_module` raises when the module
        cannot be imported, and ``AttributeError`` when the module does not
        define the expected attribute.
        """
        ttl = attr.split(".")[-1].title()
        print("Loading {} from {} plugin...".format(ttl, name.title()))
        # Trainers are looked up in the model package, but the attribute
        # fetched from the module is still "Trainer".
        attr = "model" if attr == "Trainer" else attr.lower()
        mod = ".".join(("plugins", attr, name))
        module = import_module(mod)
        return getattr(module, ttl)

    @staticmethod
    def get_available_models():
        """ Return a sorted tuple of available model names.

        A model is any direct sub-folder of ``model`` (next to this file)
        whose name starts with ``model_`` (case-insensitive); the prefix is
        stripped from the returned names. An empty tuple is returned when the
        folder is missing or holds no models.
        """
        modelpath = os.path.join(os.path.dirname(__file__), "model")
        # os.walk yields nothing for a missing folder; supply a default so
        # StopIteration does not leak out to the caller.
        subdirs = next(os.walk(modelpath), (modelpath, [], []))[1]
        # Sorted so the result (and the fallback in get_default_model) does
        # not depend on file-system listing order.
        return tuple(sorted(modeldir[6:] for modeldir in subdirs
                            if modeldir[0:6].lower() == "model_"))

    @staticmethod
    def get_available_extractors(extractor_type):
        """ Return a sorted list of available extractors of the given type.

        ``extractor_type`` is the sub-folder of ``extract`` to scan (the
        getters in this class use ``detect`` and ``align``). Every ``.py``
        file is reported except ``manual.py`` and files starting with an
        underscore. The extension is dropped and underscores are shown as
        hyphens.

        Raises ``OSError`` (from :func:`os.scandir`) when the folder does not
        exist.
        """
        extractpath = os.path.join(os.path.dirname(__file__),
                                   "extract",
                                   extractor_type)
        # splitext removes only the trailing extension, whereas
        # str.replace(".py", "") would also eat ".py" inside the name.
        return sorted(os.path.splitext(item.name)[0].replace("_", "-")
                      for item in os.scandir(extractpath)
                      if item.name.endswith(".py")
                      and not item.name.startswith("_")
                      and item.name != "manual.py")

    @staticmethod
    def get_default_model():
        """ Return the default model.

        ``Original`` is preferred when present, otherwise the first available
        model is used. Raises ``IndexError`` when no models are available.
        """
        models = PluginLoader.get_available_models()
        return 'Original' if 'Original' in models else models[0]