1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-09 04:36:50 -04:00
faceswap/plugins/extract/mask/unet_dfl.py

56 lines
2 KiB
Python

#!/usr/bin/env python3
""" UNET DFL face mask plugin
Architecture and Pre-Trained Model based on...
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
https://arxiv.org/abs/1801.05746
https://github.com/ternaus/TernausNet
Source Implementation and fine-tune training....
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/TernausNet.py
Model file sourced from...
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
"""
import numpy as np
from lib.model.session import KSession
from ._base import Masker, logger
class Mask(Masker):
    """ U-Net segmentation network that converts a face image into a mask of the face """

    def __init__(self, **kwargs):
        """ Register the model file with the parent masker and set the plugin parameters.

        Parameters
        ----------
        kwargs: dict
            Forwarded unchanged to the parent ``Masker`` initializer
        """
        super().__init__(git_model_id=6,
                         model_filename="DFL_256_sigmoid_v1.h5",
                         **kwargs)
        self.name = "U-Net"
        # Edge length of the square image fed to the network (see ``init_model``)
        self.input_size = 256
        # NOTE(review): the VRAM figures are presumably in MB and consumed by the parent
        # class - confirm against ``Masker``
        self.vram = 3424
        self.vram_warnings = 256
        self.vram_per_batch = 80
        # ``self.config`` is not assigned in this class; presumably populated by ``Masker``
        self.batchsize = self.config["batch-size"]

    def init_model(self):
        """ Create the Keras session, load the model and run one prediction on an all-zero
        batch.

        The zero batch has shape (batchsize, input_size, input_size, 3). Its result is
        discarded; presumably the call exists to warm the model up before real use.
        """
        self.model = KSession(self.name,
                              self.model_path,
                              model_kwargs={},
                              allow_growth=self.config["allow_growth"])
        self.model.load_model()
        warmup_shape = (self.batchsize, self.input_size, self.input_size, 3)
        self.model.predict(np.zeros(warmup_shape, dtype="float32"))

    def process_input(self, batch):
        """ Build the network feed from the detected faces in the batch.

        The first three channels of each face's ``feed_face`` are gathered into a single
        float32 array which is divided by 255.0 and stored under ``batch["feed"]``.

        Parameters
        ----------
        batch: dict
            Must contain the key ``"detected_faces"``

        Returns
        -------
        dict
            The same batch object with the ``"feed"`` key added
        """
        faces = batch["detected_faces"]
        # Drop any channels beyond the first three before stacking
        stacked = np.array([face.feed_face[..., :3] for face in faces], dtype="float32")
        feed = stacked / 255.0
        batch["feed"] = feed
        logger.trace("feed shape: %s", feed.shape)
        return batch

    def predict(self, batch):
        """ Run the model on ``batch["feed"]`` and store the result under
        ``batch["prediction"]``.

        Parameters
        ----------
        batch: dict
            Must contain the key ``"feed"``

        Returns
        -------
        dict
            The same batch object with the ``"prediction"`` key added
        """
        feed = batch["feed"]
        batch["prediction"] = self.model.predict(feed)
        return batch

    def process_output(self, batch):
        """ Return the batch unchanged; this plugin does no post-processing here.

        Parameters
        ----------
        batch: dict
            The batch returned from ``predict``

        Returns
        -------
        dict
            The batch that was passed in
        """
        return batch