1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00

Bugfixes and minor changes

Fix: Occasional memory leak in convert
Fix: Dynamic config loading in Windows
Change: Better folder naming for snapshots
This commit is contained in:
torzdf 2019-06-11 17:26:43 +00:00
parent 0496bb7806
commit 71bcd863e1
6 changed files with 37 additions and 6 deletions

View file

@ -65,6 +65,24 @@ def get_image_paths(directory):
return dir_contents return dir_contents
def full_path_split(path):
""" Split a given path into all of it's separate components """
allparts = list()
while True:
parts = os.path.split(path)
if parts[0] == path: # sentinel for absolute paths
allparts.insert(0, parts[0])
break
elif parts[1] == path: # sentinel for relative paths
allparts.insert(0, parts[1])
break
else:
path = parts[0]
allparts.insert(0, parts[1])
logger.trace("path: %s, allparts: %s", path, allparts)
return allparts
def cv2_read_img(filename, raise_error=False): def cv2_read_img(filename, raise_error=False):
""" Read an image with cv2 and check that an image was actually loaded. """ Read an image with cv2 and check that an image was actually loaded.
Logs an error if the image returned is None. or an error has occured. Logs an error if the image returned is None. or an error has occured.

View file

@ -8,6 +8,7 @@ import sys
from importlib import import_module from importlib import import_module
from lib.config import FaceswapConfig from lib.config import FaceswapConfig
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -24,7 +25,7 @@ class Config(FaceswapConfig):
if not default_files: if not default_files:
continue continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0])) base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = dirpath.replace(base_path, "").replace("/", ".")[1:] import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1] plugin_type = import_path.split(".")[-1]
for filename in default_files: for filename in default_files:
self.load_module(filename, import_path, plugin_type) self.load_module(filename, import_path, plugin_type)

View file

@ -7,6 +7,7 @@ import sys
from importlib import import_module from importlib import import_module
from lib.config import FaceswapConfig from lib.config import FaceswapConfig
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -23,7 +24,7 @@ class Config(FaceswapConfig):
if not default_files: if not default_files:
continue continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0])) base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = dirpath.replace(base_path, "").replace("/", ".")[1:] import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1] plugin_type = import_path.split(".")[-1]
for filename in default_files: for filename in default_files:
self.load_module(filename, import_path, plugin_type) self.load_module(filename, import_path, plugin_type)

View file

@ -9,6 +9,7 @@ from importlib import import_module
from lib.config import FaceswapConfig from lib.config import FaceswapConfig
from lib.model.masks import get_available_masks from lib.model.masks import get_available_masks
from lib.utils import full_path_split
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -29,7 +30,7 @@ class Config(FaceswapConfig):
if not default_files: if not default_files:
continue continue
base_path = os.path.dirname(os.path.realpath(sys.argv[0])) base_path = os.path.dirname(os.path.realpath(sys.argv[0]))
import_path = dirpath.replace(base_path, "").replace("/", ".")[1:] import_path = ".".join(full_path_split(dirpath.replace(base_path, ""))[1:])
plugin_type = import_path.split(".")[-1] plugin_type = import_path.split(".")[-1]
for filename in default_files: for filename in default_files:
self.load_module(filename, import_path, plugin_type) self.load_module(filename, import_path, plugin_type)

View file

@ -426,7 +426,7 @@ class ModelBase():
""" Take a snapshot of the model at current state and back up """ """ Take a snapshot of the model at current state and back up """
logger.info("Saving snapshot") logger.info("Saving snapshot")
src = self.model_dir src = self.model_dir
dst = get_folder("{}_{}".format(self.model_dir, self.iterations)) dst = get_folder("{}_snapshot_{}_iters".format(self.model_dir, self.iterations))
for filename in os.listdir(src): for filename in os.listdir(src):
if filename.endswith(".bk"): if filename.endswith(".bk"):
continue continue

View file

@ -492,12 +492,19 @@ class Predict():
def predict_faces(self): def predict_faces(self):
""" Get detected faces from images """ """ Get detected faces from images """
faces_seen = 0 faces_seen = 0
consecutive_no_faces = 0
batch = list() batch = list()
while True: while True:
item = self.in_queue.get() item = self.in_queue.get()
if item != "EOF": if item != "EOF":
logger.trace("Got from queue: '%s'", item["filename"]) logger.trace("Got from queue: '%s'", item["filename"])
faces_count = len(item["detected_faces"]) faces_count = len(item["detected_faces"])
# Safety measure. If a large stream of frames appear that do not have faces,
# these will stack up into RAM. Keep a count of consecutive frames with no faces.
# If self.batchsize number of frames appear, force the current batch through
# to clear RAM.
consecutive_no_faces = consecutive_no_faces + 1 if faces_count == 0 else 0
self.faces_count += faces_count self.faces_count += faces_count
if faces_count > 1: if faces_count > 1:
self.verify_output = True self.verify_output = True
@ -509,8 +516,10 @@ class Predict():
faces_seen += faces_count faces_seen += faces_count
batch.append(item) batch.append(item)
if faces_seen < self.batchsize and item != "EOF": if item != "EOF" and (faces_seen < self.batchsize and
logger.trace("Continuing. Current batchsize: %s", faces_seen) consecutive_no_faces < self.batchsize):
logger.trace("Continuing. Current batchsize: %s, consecutive_no_faces: %s",
faces_seen, consecutive_no_faces)
continue continue
if batch: if batch:
@ -526,6 +535,7 @@ class Predict():
self.queue_out_frames(batch, predicted) self.queue_out_frames(batch, predicted)
consecutive_no_faces = 0
faces_seen = 0 faces_seen = 0
batch = list() batch = list()
if item == "EOF": if item == "EOF":