1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/plugins/extract/mask/unet_dfl.py
torzdf 6a3b674bef
Rebase code (#1326)
* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* setup.py - stdout to utf-8

* Add more OSes to github Actions

* suppress macOS end-to-end test
2023-06-27 11:27:47 +01:00

67 lines
2.5 KiB
Python

#!/usr/bin/env python3
""" UNET DFL face mask plugin
Architecture and Pre-Trained Model based on...
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
https://arxiv.org/abs/1801.05746
https://github.com/ternaus/TernausNet
Source Implementation and fine-tune training....
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/TernausNet.py
Model file sourced from...
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
"""
import logging
import typing as T
import numpy as np
from lib.model.session import KSession
from ._base import BatchType, Masker, MaskerBatch
# Module-level logger named after this module; the masker's input stage emits
# trace-level diagnostics through it.
logger = logging.getLogger(__name__)
class Mask(Masker):
    """ U-Net (TernausNet-style) face masker using the DFL 256px sigmoid weights.

    Face images are fed to a pre-trained U-Net segmentation network. The network's prediction
    is returned unchanged from :func:`predict`, and :func:`process_output` does no further
    post-processing in this plugin.
    """
    def __init__(self, **kwargs) -> None:
        # Weights are identified to the base class by a model id and file name
        super().__init__(git_model_id=6,
                         model_filename="DFL_256_sigmoid_v1.h5",
                         **kwargs)
        self.model: KSession  # assigned in :func:`init_model`
        self.name = "U-Net"
        # Side length, in pixels, of the square image the network consumes
        self.input_size = 256
        # VRAM budgeting figures (units presumably MB -- TODO confirm against the base class)
        self.vram = 3424
        self.vram_warnings = 256
        self.vram_per_batch = 80
        self.batchsize = self.config["batch-size"]
        # Storage centering string consumed elsewhere; "legacy" semantics live outside this file
        self._storage_centering = "legacy"

    def init_model(self) -> None:
        """ Create the :class:`KSession`, load the model and push one all-zeros batch through it.

        The dummy prediction is presumably there to trigger any lazy graph construction /
        memory allocation up front rather than on the first real batch -- TODO confirm.
        """
        assert self.name is not None and isinstance(self.model_path, str)
        session = KSession(self.name,
                           self.model_path,
                           model_kwargs={},
                           allow_growth=self.config["allow_growth"],
                           exclude_gpus=self._exclude_gpus)
        self.model = session
        session.load_model()

        # (batch, height, width, channels) warm-up tensor
        warmup_shape = (self.batchsize,) + (self.input_size,) * 2 + (3,)
        session.predict(np.zeros(warmup_shape, dtype="float32"))

    def process_input(self, batch: BatchType) -> None:
        """ Stack the batch's face images into the network feed.

        Only the first three channels of each face are kept. Values are divided by 255, which
        assumes 8-bit (0-255) source pixels -- TODO confirm. The result is stored on
        ``batch.feed``.
        """
        assert isinstance(batch, MaskerBatch)
        three_channel = [T.cast(np.ndarray, detected.face)[..., :3]
                         for detected in batch.feed_faces]
        stacked = np.array(three_channel, dtype="float32")
        batch.feed = stacked / 255.0
        logger.trace("feed shape: %s", batch.feed.shape)  # type: ignore

    def predict(self, feed: np.ndarray) -> np.ndarray:
        """ Return the model's raw prediction for ``feed`` (must be a numpy array). """
        prediction = self.model.predict(feed)
        assert isinstance(prediction, np.ndarray)
        return prediction

    def process_output(self, batch: BatchType) -> None:
        """ No-op: this plugin applies no post-processing to the predictions. """
        return