mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* Remove tensorflow_probability requirement * setup.py - fix progress bars * requirements.txt: Remove pre python 3.9 packages * update apple requirements.txt * update INSTALL.md * Remove python<3.9 code * setup.py - fix Windows Installer * typing: python3.9 compliant * Update pytest and readthedocs python versions * typing fixes * Python Version updates - Reduce max version to 3.10 - Default to 3.10 in installers - Remove incompatible 3.11 tests * Update dependencies * Downgrade imageio dep for Windows * typing: merge optional unions and fixes * Updates - min python version 3.10 - typing to python 3.10 spec - remove pre-tf2.10 code - Add conda tests * train: re-enable optimizer saving * Update dockerfiles * Update setup.py - Apple Conda deps to setup.py - Better Cuda + dependency handling * bugfix: Patch logging to prevent Autograph errors * Update dockerfiles * Setup.py - Setup.py - stdout to utf-8 * Add more OSes to github Actions * suppress mac-os end to end test
67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
""" UNET DFL face mask plugin
|
|
|
|
Architecture and Pre-Trained Model based on...
|
|
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
|
|
https://arxiv.org/abs/1801.05746
|
|
https://github.com/ternaus/TernausNet
|
|
|
|
Source Implementation and fine-tune training....
|
|
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/TernausNet.py
|
|
|
|
Model file sourced from...
|
|
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
|
|
"""
|
|
import logging
|
|
import typing as T
|
|
|
|
import numpy as np
|
|
from lib.model.session import KSession
|
|
from ._base import BatchType, Masker, MaskerBatch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Mask(Masker):
    """ U-Net (DFL) neural network that turns a face image into a segmentation mask of the face

    Parameters
    ----------
    kwargs: dict
        Keyword arguments that are forwarded, unchanged, to the parent :class:`Masker`
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(git_model_id=6, model_filename="DFL_256_sigmoid_v1.h5", **kwargs)
        # Declared here for type checkers only. The session is created in :func:`init_model`
        self.model: KSession
        self.name = "U-Net"
        self.input_size = 256
        # NOTE(review): the VRAM figures are presumably in MB - confirm against the parent class
        self.vram = 3424
        self.vram_warnings = 256
        self.vram_per_batch = 80
        self.batchsize = self.config["batch-size"]
        # NOTE(review): presumably selects the face centering the mask is stored with - confirm
        self._storage_centering = "legacy"

    def init_model(self) -> None:
        """ Create the :class:`KSession` for the model file, load the model and push one all-zeros
        batch through it. """
        assert self.name is not None and isinstance(self.model_path, str)
        self.model = KSession(self.name,
                              self.model_path,
                              model_kwargs={},
                              allow_growth=self.config["allow_growth"],
                              exclude_gpus=self._exclude_gpus)
        self.model.load_model()

        # One full-sized batch of zeros is predicted immediately after loading
        # NOTE(review): presumably a warm-up so the first real batch is not slowed down - confirm
        warmup_shape = (self.batchsize, self.input_size, self.input_size, 3)
        warmup_batch = np.zeros(warmup_shape, dtype="float32")
        self.model.predict(warmup_batch)

    def process_input(self, batch: BatchType) -> None:
        """ Build the model feed from the batch's feed faces and store it in ``batch.feed``

        Parameters
        ----------
        batch: :class:`MaskerBatch`
            The batch to process. Its ``feed_faces`` are read and its ``feed`` is overwritten
        """
        assert isinstance(batch, MaskerBatch)
        # Only the first 3 channels of each face are kept for the model
        faces = [T.cast(np.ndarray, item.face)[..., :3] for item in batch.feed_faces]
        # NOTE(review): the division assumes the faces hold values in the 0-255 range - confirm
        batch.feed = np.array(faces, dtype="float32") / 255.0
        logger.trace("feed shape: %s", batch.feed.shape)  # type: ignore

    def predict(self, feed: np.ndarray) -> np.ndarray:
        """ Run the feed through the loaded model

        Parameters
        ----------
        feed: :class:`numpy.ndarray`
            The batch of images to pass to the model

        Returns
        -------
        :class:`numpy.ndarray`
            The model's output for the given feed
        """
        prediction = self.model.predict(feed)
        assert isinstance(prediction, np.ndarray)
        return prediction

    def process_output(self, batch: BatchType) -> None:
        """ No-op for this plugin: the batch is left exactly as it was received

        Parameters
        ----------
        batch: :class:`MaskerBatch`
            The batch holding the model's predictions. It is not read or modified here
        """
|