mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-09 04:36:50 -04:00
56 lines
2 KiB
Python
56 lines
2 KiB
Python
#!/usr/bin/env python3
|
|
""" UNET DFL face mask plugin
|
|
|
|
Architecture and Pre-Trained Model based on...
|
|
TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation
|
|
https://arxiv.org/abs/1801.05746
|
|
https://github.com/ternaus/TernausNet
|
|
|
|
Source implementation and fine-tune training...
|
|
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/TernausNet.py
|
|
|
|
Model file sourced from...
|
|
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
|
|
"""
|
|
|
|
import numpy as np
|
|
from lib.model.session import KSession
|
|
from ._base import Masker, logger
|
|
|
|
|
|
class Mask(Masker):
    """ U-Net face masker.

    Wraps the pre-trained ``DFL_256_sigmoid_v1.h5`` network (TernausNet / DeepFaceLab
    lineage, see the module docstring) and turns batches of face images into face
    segmentation predictions.
    """

    def __init__(self, **kwargs):
        """ Register this plugin's model file with the base class and set plugin metadata.

        Parameters
        ----------
        kwargs: dict
            Forwarded unchanged to :class:`Masker`.
        """
        super().__init__(git_model_id=6,
                         model_filename="DFL_256_sigmoid_v1.h5",
                         **kwargs)
        # These are assigned after the base initializer has run, so they take
        # precedence over any values the base class may have set for the same names.
        self.name = "U-Net"
        # Edge length of the square images fed to the network (see the warm-up batch
        # built in init_model).
        self.input_size = 256
        # VRAM budgeting figures for this plugin (units presumably MB — confirm
        # against the Masker base class).
        self.vram = 3424
        self.vram_warnings = 256
        self.vram_per_batch = 80
        # Read from the plugin config, which is presumably populated by the base
        # initializer — hence this must stay after super().__init__().
        self.batchsize = self.config["batch-size"]

    def init_model(self):
        """ Load the network into a :class:`KSession` and run one dummy batch through it.

        The dummy batch is all zeros with shape
        ``(batchsize, input_size, input_size, 3)`` and its output is discarded
        (presumably so setup cost is paid here rather than on the first real batch).
        """
        self.model = KSession(self.name,
                              self.model_path,
                              model_kwargs={},
                              allow_growth=self.config["allow_growth"])
        self.model.load_model()
        warmup_shape = (self.batchsize, self.input_size, self.input_size, 3)
        self.model.predict(np.zeros(warmup_shape, dtype="float32"))

    def process_input(self, batch):
        """ Build the network input from the batch's detected faces.

        The first three channels of every face's ``feed_face`` are stacked into one
        float32 array. The array is divided by 255 (presumably mapping 8-bit pixel
        values into 0-1) and stored under ``batch["feed"]``.

        Returns
        -------
        dict
            The same ``batch`` object, updated in place.
        """
        faces = [face.feed_face[..., :3] for face in batch["detected_faces"]]
        batch["feed"] = np.array(faces, dtype="float32") / 255.0
        logger.trace("feed shape: %s", batch["feed"].shape)
        return batch

    def predict(self, batch):
        """ Run the loaded network on ``batch["feed"]``.

        The raw model output is stored under ``batch["prediction"]`` and the same
        ``batch`` object is returned.
        """
        feed = batch["feed"]
        batch["prediction"] = self.model.predict(feed)
        return batch

    def process_output(self, batch):
        """ Output hook: this plugin performs no post-processing.

        Returns
        -------
        dict
            ``batch`` unchanged.
        """
        return batch
|