1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00

Rebase code (#1326)

* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* Setup.py - stdout to utf-8

* Add more OSes to github Actions

* suppress mac-os end to end test
This commit is contained in:
torzdf 2023-06-27 11:27:47 +01:00 committed by GitHub
parent e4ba12ad2a
commit 6a3b674bef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
130 changed files with 3035 additions and 3028 deletions

View file

@ -8,17 +8,85 @@ on:
- "**/README.md"
jobs:
build_linux:
build_conda:
  # Conda based job: installs faceswap through setup.py inside a conda environment for each
  # os/backend combination in the matrix, then lints, type checks and (where possible) tests it.
  name: conda (${{ matrix.os }}, ${{ matrix.backend }})
  runs-on: ${{ matrix.os }}
  defaults:
    run:
      # -e: exit on first error, -l: login shell.
      # NOTE(review): the login shell is presumably what makes the environment activated by
      # setup-miniconda available in every step - confirm against the setup-miniconda docs.
      shell: bash -el {0}
  strategy:
    # Let the remaining matrix combinations finish when one of them fails
    fail-fast: false
    matrix:
      os: ["ubuntu-latest", "macos-latest", "windows-latest"]
      backend: ["nvidia", "cpu"]
      # Extra single-OS combinations added on top of the os x backend product
      include:
        - os: "ubuntu-latest"
          backend: "rocm"
        - os: "windows-latest"
          backend: "directml"
      exclude:
        # pynvx does not currently build for Python3.10 and without CUDA it may not build at all
        - os: "macos-latest"
          backend: "nvidia"
  steps:
    - uses: actions/checkout@v3
    # Export today's date (YYYYMMDD) as DATE so that it can be used in the cache key below
    - name: Set cache date
      run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV
    - name: Cache conda
      uses: actions/cache@v3
      env:
        # Increase this value to manually reset cache
        CACHE_NUMBER: 1
        # Backend specific requirements file, hashed into the cache key
        REQ_FILE: ./requirements/requirements_${{ matrix.backend }}.txt
      with:
        path: ~/conda_pkgs_dir
        # Key changes with: runner OS, backend, CACHE_NUMBER, the date and the requirements hashes
        key: ${{ runner.os }}-${{ matrix.backend }}-conda-${{ env.CACHE_NUMBER }}-${{ env.DATE }}-${{ hashFiles('./requirements/requirements.txt', env.REQ_FILE) }}
    - name: Set up Conda
      uses: conda-incubator/setup-miniconda@v2
      with:
        # Quoted so that YAML does not read the version as the float 3.1
        python-version: "3.10"
        auto-update-conda: true
        activate-environment: faceswap
    - name: Conda info
      run: conda info && conda list
    # Install faceswap for the selected backend, then the lint/type-check/test tooling
    - name: Install
      run: |
        python setup.py --installer --${{ matrix.backend }}
        pip install flake8 pylint mypy pytest pytest-mock wheel pytest-xvfb
        pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --select=E9,F63,F7,F82 --show-source
        # Full lint is informational only: --exit-zero never fails the step
        flake8 . --exit-zero
    - name: MyPy Typing
      # Typing errors are reported but do not fail the job
      continue-on-error: true
      run: |
        mypy .
    - name: SysInfo
      run: python -c "from lib.sysinfo import sysinfo ; print(sysinfo)"
    - name: Simple Tests
      # These backends will fail as GPU drivers not available
      if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml'
      run: |
        FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
    - name: End to End Tests
      # These backends will fail as GPU drivers not available
      # macOS fails on first extract test with 'died with <Signals.SIGSEGV: 11>'
      if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml' && matrix.os != 'macos-latest'
      run: |
        FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;
build_linux:
name: "pip (ubuntu-latest, ${{ matrix.backend }})"
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.10"]
backend: ["cpu"]
include:
- kbackend: "tensorflow"
backend: "cpu"
- backend: "cpu"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@ -33,6 +101,8 @@ jobs:
pip install flake8 pylint mypy pytest pytest-mock pytest-xvfb wheel
pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
- name: List installed packages
run: pip freeze
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@ -45,17 +115,18 @@ jobs:
mypy .
- name: Simple Tests
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" py.test -v tests/;
FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
- name: End to End Tests
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" python tests/simple_tests.py;
FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;
build_windows:
name: "pip (windows-latest, ${{ matrix.backend }})"
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.10"]
backend: ["cpu", "directml"]
include:
- backend: "cpu"
@ -74,6 +145,8 @@ jobs:
pip install flake8 pylint mypy pytest pytest-mock wheel
pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
- name: List installed packages
run: pip freeze
- name: Set Backend EnvVar
run: echo "FACESWAP_BACKEND=${{ matrix.backend }}" | Out-File -FilePath $env:GITHUB_ENV -Append
- name: Lint with flake8

View file

@ -12,7 +12,7 @@ DIR_CONDA="$HOME/miniconda3"
CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda"
CONDA_TO_PATH=false
ENV_NAME="faceswap"
PYENV_VERSION="3.9"
PYENV_VERSION="3.10"
DIR_FACESWAP="$HOME/faceswap"
VERSION="nvidia"
@ -363,7 +363,7 @@ delete_env() {
}
create_env() {
# Create Python 3.8 env for faceswap
# Create Python 3.10 env for faceswap
delete_env
info "Creating Conda Virtual Environment..."
yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -q python="$PYENV_VERSION" -y

View file

@ -22,7 +22,7 @@ InstallDir $PROFILE\faceswap
# Install cli flags
!define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$PROFILE\MiniConda3"
!define flagsRepo "--depth 1 --no-single-branch ${wwwRepo}"
!define flagsEnv "-y python=3.9"
!define flagsEnv "-y python=3.10"
# Folders
Var ProgramData

View file

@ -7,9 +7,9 @@ version: 2
# Set the version of Python and other tools you might need
build:
os: ubuntu-20.04
os: ubuntu-22.04
tools:
python: "3.8"
python: "3.10"
# Build documentation in the docs/ directory with Sphinx
sphinx:

View file

@ -1,19 +1,19 @@
FROM tensorflow/tensorflow:2.8.2
FROM ubuntu:22.04
# To disable tzdata and others from asking for input
ENV DEBIAN_FRONTEND noninteractive
ENV FACESWAP_BACKEND cpu
RUN apt-get update -qq -y \
&& apt-get install -y software-properties-common \
&& add-apt-repository -y ppa:jonathonf/ffmpeg-4 \
&& apt-get update -qq -y \
&& apt-get install -y libsm6 libxrender1 libxext-dev python3-tk ffmpeg git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update -qq -y
RUN apt-get upgrade -y
RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git
COPY ./requirements/_requirements_base.txt /opt/
RUN pip3 install --upgrade pip
RUN pip3 --no-cache-dir install -r /opt/_requirements_base.txt && rm /opt/_requirements_base.txt
RUN ln -s $(which python3) /usr/local/bin/python
RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
WORKDIR "/faceswap"
RUN python -m pip install --upgrade pip
RUN python -m pip --no-cache-dir install -r ./requirements/requirements_cpu.txt
WORKDIR "/srv"
CMD ["/bin/bash"]

View file

@ -1,29 +1,19 @@
FROM nvidia/cuda:11.7.0-runtime-ubuntu18.04
ARG DEBIAN_FRONTEND=noninteractive
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
#install python3.8
RUN apt-get update
RUN apt-get install software-properties-common -y
RUN add-apt-repository ppa:deadsnakes/ppa -y
RUN apt-get update
RUN apt-get install python3.8 -y
RUN apt-get install python3.8-distutils -y
RUN apt-get install python3.8-tk -y
RUN apt-get install curl -y
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
RUN python3.8 get-pip.py
RUN rm get-pip.py
ENV DEBIAN_FRONTEND=noninteractive
ENV FACESWAP_BACKEND nvidia
# install requirements
RUN apt-get install ffmpeg git -y
COPY ./requirements/_requirements_base.txt /opt/
COPY ./requirements/requirements_nvidia.txt /opt/
RUN python3.8 -m pip --no-cache-dir install -r /opt/requirements_nvidia.txt && rm /opt/_requirements_base.txt && rm /opt/requirements_nvidia.txt
RUN apt-get update -qq -y
RUN apt-get upgrade -y
RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git
RUN python3.8 -m pip install jupyter matplotlib tqdm
RUN python3.8 -m pip install jupyter_http_over_ws
RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN alias python=python3.8
RUN echo "alias python=python3.8" >> /root/.bashrc
WORKDIR "/notebooks"
CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"]
RUN ln -s $(which python3) /usr/local/bin/python
RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
WORKDIR "/faceswap"
RUN python -m pip install --upgrade pip
RUN python -m pip install --upgrade pip
RUN python -m pip --no-cache-dir install -r ./requirements/requirements_nvidia.txt
CMD ["/bin/bash"]

View file

@ -39,12 +39,9 @@
- [Setup](#setup-2)
- [About some of the options](#about-some-of-the-options)
- [Docker Install Guide](#docker-install-guide)
- [Docker General](#docker-general)
- [CUDA with Docker in 20 minutes.](#cuda-with-docker-in-20-minutes)
- [CUDA with Docker on Arch Linux](#cuda-with-docker-on-arch-linux)
- [Install docker](#install-docker)
- [A successful setup log, without docker.](#a-successful-setup-log-without-docker)
- [Run the project](#run-the-project)
- [Docker CPU](#docker-cpu)
- [Docker Nvidia](#docker-nvidia)
- [Run the project](#run-the-project)
- [Notes](#notes)
# Prerequisites
@ -115,7 +112,7 @@ Reboot your PC, so that everything you have just installed gets registered.
- Select "Create" at the bottom
- In the pop up:
- Give it the name: faceswap
- **IMPORTANT**: Select python version 3.8
- **IMPORTANT**: Select python version 3.10
- Hit "Create" (NB: This may take a while as it will need to download Python)
![Anaconda virtual env setup](https://i.imgur.com/CLIDDfa.png)
@ -195,7 +192,7 @@ $ source ~/miniforge3/bin/activate
## Setup
### Create and Activate the Environment
```sh
$ conda create --name faceswap python=3.9
$ conda create --name faceswap python=3.10
$ conda activate faceswap
```
@ -225,7 +222,7 @@ Obtain git for your distribution from the [git website](https://git-scm.com/down
The recommended install method is to use a Conda3 Environment as this will handle the installation of Nvidia's CUDA and cuDNN straight into your Conda Environment. This is by far the easiest and most reliable way to setup the project.
- MiniConda3 is recommended: [MiniConda3](https://docs.conda.io/en/latest/miniconda.html)
Alternatively you can install Python (>= 3.7-3.9 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn). for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
Alternatively you can install Python (3.10 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn) for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.10. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
- Python distributions:
- apt/yum install python3 (Linux)
- [Installer](https://www.python.org/downloads/release/python-368/) (Windows)
@ -260,153 +257,83 @@ If setup fails for any reason you can still manually install the packages listed
# Docker Install Guide
## Docker General
<details>
<summary>Click to expand!</summary>
This Faceswap repo contains Docker build scripts for CPU and Nvidia backends. The scripts will set up a Docker container for you and install the latest version of the Faceswap software.
### CUDA with Docker in 20 minutes.
You must first ensure that Docker is installed and running on your system. Follow the guide for downloading and installing Docker from their website:
1. Install Docker
https://www.docker.com/community-edition
- https://www.docker.com/get-started
2. Install Nvidia-Docker & Restart Docker Service
https://github.com/NVIDIA/nvidia-docker
Once Docker is installed and running, follow the relevant steps for your chosen backend
## Docker CPU
To run the CPU version of Faceswap follow these steps:
3. Build Docker Image For faceswap
```bash
docker build -t deepfakes-gpu -f Dockerfile.gpu .
```
4. Mount faceswap volume and Run it
a). without `gui.tools.py` gui not working.
```bash
nvidia-docker run --rm -it -p 8888:8888 \
--hostname faceswap-gpu --name faceswap-gpu \
-v /opt/faceswap:/srv \
deepfakes-gpu
```
b). with gui. tools.py gui working.
Enable local access to X11 server
```bash
xhost +local:
1. Build the Docker image for faceswap:
```
Enable nvidia device if working under bumblebee
```bash
echo ON > /proc/acpi/bbswitch
docker build \
-t faceswap-cpu \
https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.cpu
```
2. Launch and enter the Faceswap container:
Create container
```bash
nvidia-docker run -p 8888:8888 \
--hostname faceswap-gpu --name faceswap-gpu \
-v /opt/faceswap:/srv \
a. For the **headless/command line** version of Faceswap run:
```
docker run --rm -it faceswap-cpu
```
You can then execute faceswap the standard way:
```
python faceswap.py --help
```
b. For the **GUI** version of Faceswap run:
```
xhost +local: && \
docker run --rm -it \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=unix$DISPLAY \
-e AUDIO_GID=`getent group audio | cut -d: -f3` \
-e VIDEO_GID=`getent group video | cut -d: -f3` \
-e GID=`id -g` \
-e UID=`id -u` \
deepfakes-gpu
```
Open a new terminal to interact with the project
```bash
docker exec -it deepfakes-gpu /bin/bash
```
Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)
```bash
python3.8 /srv/faceswap.py gui
```
</details>
## CUDA with Docker on Arch Linux
<details>
<summary>Click to expand!</summary>
### Install docker
```bash
sudo pacman -S docker
```
The steps are same but Arch linux doesn't use nvidia-docker
create container
```bash
docker run -p 8888:8888 --gpus all --privileged -v /dev:/dev \
--hostname faceswap-gpu --name faceswap-gpu \
-v /mnt/hdd2/faceswap:/srv \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=unix$DISPLAY \
-e AUDIO_GID=`getent group audio | cut -d: -f3` \
-e VIDEO_GID=`getent group video | cut -d: -f3` \
-e GID=`id -g` \
-e UID=`id -u` \
deepfakes-gpu
```
Open a new terminal to interact with the project
```bash
docker exec -it deepfakes-gpu /bin/bash
```
Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)
**With `gui.tools.py` gui working.**
Enable local access to X11 server
```bash
xhost +local:
```
```bash
python3.8 /srv/faceswap.py gui
-e DISPLAY=${DISPLAY} \
faceswap-cpu
```
You can then launch the GUI with
```
python faceswap.py gui
```
## Docker Nvidia
To build the NVIDIA GPU version of Faceswap, follow these steps:
</details>
1. Nvidia Docker builds need extra resources to provide the Docker container with access to your GPU.
---
## A successful setup log, without docker.
a. Follow the instructions to install and apply the `Nvidia Container Toolkit` for your distribution from:
- https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
b. If Docker is already running, restart it to pick up the changes made by the Nvidia Container Toolkit.
2. Build the Docker image for faceswap:
```
INFO The tool provides tips for installation
and installs required python packages
INFO Setup in Linux 4.14.39-1-MANJARO
INFO Installed Python: 3.7.5 64bit
INFO Installed PIP: 10.0.1
Enable Docker? [Y/n] n
INFO Docker Disabled
Enable CUDA? [Y/n]
INFO CUDA Enabled
INFO CUDA version: 9.1
INFO cuDNN version: 7
WARNING Tensorflow has no official prebuild for CUDA 9.1 currently.
To continue, You have to build your own tensorflow-gpu.
Help: https://www.tensorflow.org/install/install_sources
Are System Dependencies met? [y/N] y
INFO Installing Missing Python Packages...
INFO Installing tensorflow-gpu
......
INFO Installing tqdm
INFO Installing matplotlib
INFO All python3 dependencies are met.
You are good to go.
docker build \
-t faceswap-gpu \
https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.gpu
```
3. Launch and enter the Faceswap container:
## Run the project
a. For the **headless/command line** version of Faceswap run:
```
docker run --runtime=nvidia --rm -it faceswap-gpu
```
You can then execute faceswap the standard way:
```
python faceswap.py --help
```
b. For the **GUI** version of Faceswap run:
```
xhost +local: && \
docker run --runtime=nvidia --rm -it \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=${DISPLAY} \
faceswap-gpu
```
You can then launch the GUI with
```
python faceswap.py gui
```
# Run the project
Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options.
```bash

View file

@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath('../'))
sys.setrecursionlimit(1500)
MOCK_MODULES = ["plaidml", "pynvx", "ctypes.windll", "comtypes"]
MOCK_MODULES = ["pynvx", "ctypes.windll", "comtypes"]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()

View file

@ -1,25 +1,21 @@
# NB Do not install from this requirements file
# It is for documentation purposes only
sphinx==5.0.2
sphinx_rtd_theme==1.0.0
tqdm==4.64
psutil==5.8.0
numexpr>=2.8.3
numpy>=1.18.0
opencv-python>=4.5.5.0
pillow==8.3.1
scikit-learn>=1.0.2
fastcluster>=1.2.4
matplotlib==3.5.1
numexpr
imageio==2.9.0
imageio-ffmpeg==0.4.7
ffmpy==0.2.3
nvidia-ml-py<11.515
plaidml==0.7.0
sphinx==7.0.1
sphinx_rtd_theme==1.2.2
tqdm==4.65
psutil==5.9.0
numexpr>=2.8.4
numpy>=1.25.0
opencv-python>=4.7.0.0
pillow==9.4.0
scikit-learn>=1.2.2
fastcluster>=1.2.6
matplotlib==3.7.1
imageio==2.31.1
imageio-ffmpeg==0.4.8
ffmpy==0.3.0
nvidia-ml-py==11.525
pytest==7.2.0
pytest-mock==3.10.0
tensorflow>=2.8.0,<2.9.0
tensorflow_probability<0.17
typing-extensions>=4.0.0
tensorflow>=2.10.0,<2.11.0

View file

@ -16,9 +16,8 @@ from lib.config import generate_configs # pylint:disable=wrong-import-position
_LANG = gettext.translation("faceswap", localedir="locales", fallback=True)
_ = _LANG.gettext
if sys.version_info < (3, 7):
raise ValueError("This program requires at least python3.7")
if sys.version_info < (3, 10):
raise ValueError("This program requires at least python 3.10")
_PARSER = cli_args.FullHelpArgumentParser()

View file

@ -3,22 +3,14 @@
from dataclasses import dataclass, field
import logging
import sys
import typing as T
from threading import Lock
from typing import cast, Dict, Optional, Tuple
import cv2
import numpy as np
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
CenteringType = Literal["face", "head", "legacy"]
CenteringType = T.Literal["face", "head", "legacy"]
_MEAN_FACE = np.array([[0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748],
[0.300643, 0.034489], [0.403270, 0.077391], [0.596729, 0.077391],
@ -65,10 +57,10 @@ _MEAN_FACE_3D = np.array([[4.056931, -11.432347, 1.636229], # 8 chin LL
[0.0, -8.601736, 6.097667], # 45 mouth bottom C
[0.589441, -8.443925, 6.109526]]) # 44 mouth bottom L
_EXTRACT_RATIOS = dict(legacy=0.375, face=0.5, head=0.625)
_EXTRACT_RATIOS = {"legacy": 0.375, "face": 0.5, "head": 0.625}
def get_matrix_scaling(matrix: np.ndarray) -> Tuple[int, int]:
def get_matrix_scaling(matrix: np.ndarray) -> tuple[int, int]:
""" Given a matrix, return the cv2 Interpolation method and inverse interpolation method for
applying the matrix on an image.
@ -213,6 +205,149 @@ def get_centered_size(source_centering: CenteringType,
return retval
class PoseEstimate():
    """ Estimates pose from a generic 3D head model for the given 2D face landmarks.

    Parameters
    ----------
    landmarks: :class:`numpy.ndarray`
        The original 68 point landmarks aligned to 0.0 - 1.0 range

    References
    ----------
    Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
    3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
    """
    def __init__(self, landmarks: np.ndarray) -> None:
        self._distortion_coefficients = np.zeros((4, 1))  # Assuming no lens distortion
        self._xyz_2d: np.ndarray | None = None  # Populated on first access of :attr:`xyz_2d`
        # Order matters below: the camera matrix is required to solve PnP, and the solved
        # rotation/translation are required to calculate the offsets
        self._camera_matrix = self._get_camera_matrix()
        self._rotation, self._translation = self._solve_pnp(landmarks)
        self._offset = self._get_offset()
        # (0, 0, 0) doubles as the "not yet calculated" marker for the pitch/yaw/roll properties
        self._pitch_yaw_roll: tuple[float, float, float] = (0, 0, 0)

    @property
    def xyz_2d(self) -> np.ndarray:
        """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
        constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
        if self._xyz_2d is None:
            # Calculated once on first access, then served from the stored value
            xyz = cv2.projectPoints(np.array([[6., 0., -2.3],
                                              [0., 6., -2.3],
                                              [0., 0., 3.7]]).astype("float32"),
                                    self._rotation,
                                    self._translation,
                                    self._camera_matrix,
                                    self._distortion_coefficients)[0].squeeze()
            self._xyz_2d = xyz - self._offset["head"]
        return self._xyz_2d

    @property
    def offset(self) -> dict[CenteringType, np.ndarray]:
        """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix so that
        it is centered on the center of the face (between the eyes) or the center of the head
        (middle of skull) rather than the nose area. Keyed by centering type. """
        return self._offset

    @property
    def pitch(self) -> float:
        """ float: The pitch of the aligned face in Euler angles """
        if not any(self._pitch_yaw_roll):  # Still at the initial (0, 0, 0), so calculate
            self._get_pitch_yaw_roll()
        return self._pitch_yaw_roll[0]

    @property
    def yaw(self) -> float:
        """ float: The yaw of the aligned face in Euler angles """
        if not any(self._pitch_yaw_roll):  # Still at the initial (0, 0, 0), so calculate
            self._get_pitch_yaw_roll()
        return self._pitch_yaw_roll[1]

    @property
    def roll(self) -> float:
        """ float: The roll of the aligned face in Euler angles """
        if not any(self._pitch_yaw_roll):  # Still at the initial (0, 0, 0), so calculate
            self._get_pitch_yaw_roll()
        return self._pitch_yaw_roll[2]

    def _get_pitch_yaw_roll(self) -> None:
        """ Obtain the pitch, yaw and roll from the :attr:`_rotation` in Euler angles and store
        them to :attr:`_pitch_yaw_roll`. """
        # Build a 3x4 projection matrix from the rotation only (last column left at zero)
        proj_matrix = np.zeros((3, 4), dtype="float32")
        proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
        # The Euler angles are the final item returned by the decomposition
        euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
        self._pitch_yaw_roll = T.cast(tuple[float, float, float], tuple(euler.squeeze()))
        logger.trace("yaw_pitch: %s", self._pitch_yaw_roll)  # type: ignore

    @classmethod
    def _get_camera_matrix(cls) -> np.ndarray:
        """ Obtain an estimate of the camera matrix. The estimate is built from fixed values: a
        focal length of 4 and a center point of (0.5, 0.5).

        Returns
        -------
        :class:`numpy.ndarray`
            An estimated camera matrix
        """
        focal_length = 4
        camera_matrix = np.array([[focal_length, 0, 0.5],
                                  [0, focal_length, 0.5],
                                  [0, 0, 1]], dtype="double")
        logger.trace("camera_matrix: %s", camera_matrix)  # type: ignore
        return camera_matrix

    def _solve_pnp(self, landmarks: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """ Solve the Perspective-n-Point for the given landmarks.

        Takes 2D landmarks in world space and estimates the rotation and translation vectors
        in 3D space.

        Parameters
        ----------
        landmarks: :class:`numpy.ndarray`
            The original 68 point landmark co-ordinates relating to the original frame

        Returns
        -------
        rotation: :class:`numpy.ndarray`
            The solved rotation vector
        translation: :class:`numpy.ndarray`
            The solved translation vector
        """
        # Subset of 26 landmark indices used for the solve
        # NOTE(review): presumably one landmark per row of _MEAN_FACE_3D - confirm
        points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
                            35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
        _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
                                                points,
                                                self._camera_matrix,
                                                self._distortion_coefficients,
                                                flags=cv2.SOLVEPNP_ITERATIVE)
        logger.trace("points: %s, rotation: %s, translation: %s",  # type: ignore
                     points, rotation, translation)
        return rotation, translation

    def _get_offset(self) -> dict[CenteringType, np.ndarray]:
        """ Obtain the offset between the original center of the extracted face to the new center
        of the head and of the face in 2D space.

        Returns
        -------
        dict
            The x, y offset of the new center from the old center (0.5, 0.5) for each centering
            type. The "legacy" offset is always (0.0, 0.0)
        """
        offset: dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])}
        # 3D model points that are projected to 2D to locate the "head" and "face" centers
        points: dict[T.Literal["face", "head"], tuple[float, ...]] = {"head": (0.0, 0.0, -2.3),
                                                                      "face": (0.0, -1.5, 4.2)}
        for key, pnts in points.items():
            center = cv2.projectPoints(np.array([pnts]).astype("float32"),
                                       self._rotation,
                                       self._translation,
                                       self._camera_matrix,
                                       self._distortion_coefficients)[0].squeeze()
            logger.trace("center %s: %s", key, center)  # type: ignore
            offset[key] = center - (0.5, 0.5)
        logger.trace("offset: %s", offset)  # type: ignore
        return offset
@dataclass
class _FaceCache: # pylint:disable=too-many-instance-attributes
""" Cache for storing items related to a single aligned face.
@ -251,19 +386,19 @@ class _FaceCache: # pylint:disable=too-many-instance-attributes
cropped_slices: dict, optional
The slices for an input full head image and output cropped image. Default: `{}`
"""
pose: Optional["PoseEstimate"] = None
original_roi: Optional[np.ndarray] = None
landmarks: Optional[np.ndarray] = None
landmarks_normalized: Optional[np.ndarray] = None
pose: PoseEstimate | None = None
original_roi: np.ndarray | None = None
landmarks: np.ndarray | None = None
landmarks_normalized: np.ndarray | None = None
average_distance: float = 0.0
relative_eye_mouth_position: float = 0.0
adjusted_matrix: Optional[np.ndarray] = None
interpolators: Tuple[int, int] = (0, 0)
cropped_roi: Dict[CenteringType, np.ndarray] = field(default_factory=dict)
cropped_slices: Dict[CenteringType, Dict[Literal["in", "out"],
Tuple[slice, slice]]] = field(default_factory=dict)
adjusted_matrix: np.ndarray | None = None
interpolators: tuple[int, int] = (0, 0)
cropped_roi: dict[CenteringType, np.ndarray] = field(default_factory=dict)
cropped_slices: dict[CenteringType, dict[T.Literal["in", "out"],
tuple[slice, slice]]] = field(default_factory=dict)
_locks: Dict[str, Lock] = field(default_factory=dict)
_locks: dict[str, Lock] = field(default_factory=dict)
def __post_init__(self):
""" Initialize the locks for the class parameters """
@ -322,11 +457,11 @@ class AlignedFace():
"""
def __init__(self,
landmarks: np.ndarray,
image: Optional[np.ndarray] = None,
image: np.ndarray | None = None,
centering: CenteringType = "face",
size: int = 64,
coverage_ratio: float = 1.0,
dtype: Optional[str] = None,
dtype: str | None = None,
is_aligned: bool = False,
is_legacy: bool = False) -> None:
logger.trace("Initializing: %s (image shape: %s, centering: '%s', " # type: ignore
@ -340,9 +475,9 @@ class AlignedFace():
self._dtype = dtype
self._is_aligned = is_aligned
self._source_centering: CenteringType = "legacy" if is_legacy and is_aligned else "head"
self._matrices = dict(legacy=_umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
face=np.array([]),
head=np.array([]))
self._matrices = {"legacy": _umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
"face": np.array([]),
"head": np.array([])}
self._padding = self._padding_from_coverage(size, coverage_ratio)
self._cache = _FaceCache()
@ -353,7 +488,7 @@ class AlignedFace():
self._face if self._face is None else self._face.shape)
@property
def centering(self) -> Literal["legacy", "head", "face"]:
def centering(self) -> T.Literal["legacy", "head", "face"]:
""" str: The centering of the Aligned Face. One of `"legacy"`, `"head"`, `"face"`. """
return self._centering
@ -382,7 +517,7 @@ class AlignedFace():
return self._matrices[self._centering]
@property
def pose(self) -> "PoseEstimate":
def pose(self) -> PoseEstimate:
""" :class:`lib.align.PoseEstimate`: The estimated pose in 3D space. """
with self._cache.lock("pose"):
if self._cache.pose is None:
@ -405,7 +540,7 @@ class AlignedFace():
return self._cache.adjusted_matrix
@property
def face(self) -> Optional[np.ndarray]:
def face(self) -> np.ndarray | None:
""" :class:`numpy.ndarray`: The aligned face at the given :attr:`size` at the specified
:attr:`coverage` in the given :attr:`dtype`. If an :attr:`image` has not been provided
then an the attribute will return ``None``. """
@ -450,7 +585,7 @@ class AlignedFace():
return self._cache.landmarks_normalized
@property
def interpolators(self) -> Tuple[int, int]:
def interpolators(self) -> tuple[int, int]:
""" tuple: (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. """
with self._cache.lock("interpolators"):
if not any(self._cache.interpolators):
@ -487,7 +622,7 @@ class AlignedFace():
return self._cache.relative_eye_mouth_position
@classmethod
def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> Dict[CenteringType, int]:
def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> dict[CenteringType, int]:
""" Return the image padding for a face from coverage_ratio set against a
pre-padded training image.
@ -504,7 +639,7 @@ class AlignedFace():
The padding required, in pixels for 'head', 'face' and 'legacy' face types
"""
retval = {_type: round((size * (coverage_ratio - (1 - _EXTRACT_RATIOS[_type]))) / 2)
for _type in get_args(Literal["legacy", "face", "head"])}
for _type in T.get_args(T.Literal["legacy", "face", "head"])}
logger.trace(retval) # type: ignore
return retval
@ -532,7 +667,7 @@ class AlignedFace():
invert, points, retval)
return retval
def extract_face(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]:
def extract_face(self, image: np.ndarray | None) -> np.ndarray | None:
""" Extract the face from a source image and populate :attr:`face`. If an image is not
provided then ``None`` is returned.
@ -605,7 +740,7 @@ class AlignedFace():
def _get_cropped_slices(self,
image_size: int,
target_size: int,
) -> Dict[Literal["in", "out"], Tuple[slice, slice]]:
) -> dict[T.Literal["in", "out"], tuple[slice, slice]]:
""" Obtain the slices to turn a full head extract into an alternatively centered extract.
Parameters
@ -676,149 +811,6 @@ class AlignedFace():
return self._cache.cropped_roi[centering]
class PoseEstimate():
    """ Estimates pose from a generic 3D head model for the given 2D face landmarks.

    The rotation and translation of the head are solved once at construction. The projected
    axes (:attr:`xyz_2d`) and the pitch/yaw/roll angles are computed lazily on first access and
    then cached.

    Parameters
    ----------
    landmarks: :class:`numpy.ndarray`
        The original 68 point landmarks aligned to 0.0 - 1.0 range

    References
    ----------
    Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
    3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
    """
    def __init__(self, landmarks: np.ndarray) -> None:
        self._distortion_coefficients = np.zeros((4, 1))  # Assuming no lens distortion
        self._xyz_2d: Optional[np.ndarray] = None
        self._camera_matrix = self._get_camera_matrix()
        self._rotation, self._translation = self._solve_pnp(landmarks)
        self._offset = self._get_offset()
        # ``None`` means "not yet computed". A zero-filled tuple cannot be used as the sentinel
        # because (0, 0, 0) is itself a valid result and would never be treated as cached.
        self._pitch_yaw_roll: Optional[Tuple[float, float, float]] = None

    @property
    def xyz_2d(self) -> np.ndarray:
        """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
        constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
        if self._xyz_2d is None:
            xyz = self._project([[6., 0., -2.3],
                                 [0., 6., -2.3],
                                 [0., 0., 3.7]])
            self._xyz_2d = xyz - self._offset["head"]
        return self._xyz_2d

    @property
    def offset(self) -> Dict[CenteringType, np.ndarray]:
        """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix so that
        it is centered on the face (between the eyes) or on the head (middle of skull) rather
        than the nose area. Keyed by centering type. """
        return self._offset

    @property
    def pitch(self) -> float:
        """ float: The pitch of the aligned face in Euler angles """
        return self._get_angles()[0]

    @property
    def yaw(self) -> float:
        """ float: The yaw of the aligned face in Euler angles """
        return self._get_angles()[1]

    @property
    def roll(self) -> float:
        """ float: The roll of the aligned face in Euler angles """
        return self._get_angles()[2]

    def _get_angles(self) -> Tuple[float, float, float]:
        """ Obtain the cached (pitch, yaw, roll), calculating it on first request.

        Returns
        -------
        tuple
            The (pitch, yaw, roll) of the aligned face in Euler angles
        """
        if self._pitch_yaw_roll is None:
            self._get_pitch_yaw_roll()
        assert self._pitch_yaw_roll is not None
        return self._pitch_yaw_roll

    def _project(self, points) -> np.ndarray:
        """ Project 3D model space points to 2D using the solved rotation and translation.

        Parameters
        ----------
        points: list
            List of (x, y, z) points to project

        Returns
        -------
        :class:`numpy.ndarray`
            The projected points with singleton dimensions squeezed out
        """
        return cv2.projectPoints(np.array(points).astype("float32"),
                                 self._rotation,
                                 self._translation,
                                 self._camera_matrix,
                                 self._distortion_coefficients)[0].squeeze()

    def _get_pitch_yaw_roll(self) -> None:
        """ Obtain the pitch, yaw and roll from the :attr:`_rotation` in Euler angles and store
        them in :attr:`_pitch_yaw_roll`. """
        # Build a 3x4 projection matrix from the rotation vector. The 4th (translation) column is
        # left as zeros as only the rotation is of interest
        proj_matrix = np.zeros((3, 4), dtype="float32")
        proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
        # The Euler angles are the final item returned from the decomposition
        euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
        self._pitch_yaw_roll = cast(Tuple[float, float, float], tuple(euler.squeeze()))
        logger.trace("pitch_yaw_roll: %s", self._pitch_yaw_roll)  # type: ignore

    @classmethod
    def _get_camera_matrix(cls) -> np.ndarray:
        """ Obtain an estimate of the camera matrix using a fixed focal length and an optical
        center of (0.5, 0.5).

        Returns
        -------
        :class:`numpy.ndarray`
            An estimated camera matrix
        """
        focal_length = 4
        camera_matrix = np.array([[focal_length, 0, 0.5],
                                  [0, focal_length, 0.5],
                                  [0, 0, 1]], dtype="double")
        logger.trace("camera_matrix: %s", camera_matrix)  # type: ignore
        return camera_matrix

    def _solve_pnp(self, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """ Solve the Perspective-n-Point for the given landmarks.

        Takes 2D landmarks in world space and estimates the rotation and translation vectors
        in 3D space.

        Parameters
        ----------
        landmarks: :class:`numpy.ndarray`
            The 68 point landmark co-ordinates that were passed to the constructor

        Returns
        -------
        rotation: :class:`numpy.ndarray`
            The solved rotation vector
        translation: :class:`numpy.ndarray`
            The solved translation vector
        """
        # Subset of the 68 landmarks that is paired with the 3D model points.
        # NOTE(review): assumes _MEAN_FACE_3D holds one point per index listed here, in the same
        # order - confirm against its definition
        points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
                            35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
        _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
                                                points,
                                                self._camera_matrix,
                                                self._distortion_coefficients,
                                                flags=cv2.SOLVEPNP_ITERATIVE)
        logger.trace("points: %s, rotation: %s, translation: %s",  # type: ignore
                     points, rotation, translation)
        return rotation, translation

    def _get_offset(self) -> Dict[CenteringType, np.ndarray]:
        """ Obtain the offset between the original center of the extracted face to the new center
        of the head and of the face in 2D space.

        Returns
        -------
        dict
            The centering type as key with the x, y offset of the new center from the old center
            as value. "legacy" centering always has an offset of (0.0, 0.0)
        """
        offset: Dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])}
        points: Dict[Literal["face", "head"], Tuple[float, ...]] = {"head": (0.0, 0.0, -2.3),
                                                                    "face": (0.0, -1.5, 4.2)}
        for key, pnts in points.items():
            center = self._project([pnts])
            logger.trace("center %s: %s", key, center)  # type: ignore
            # (0.5, 0.5) is the center of the 0.0 - 1.0 aligned space
            offset[key] = center - (0.5, 0.5)
        logger.trace("offset: %s", offset)  # type: ignore
        return offset
def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray:
"""Estimate N-D similarity transformation with or without scaling.
@ -866,24 +858,24 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
if np.linalg.det(A) < 0:
d[dim - 1] = -1
T = np.eye(dim + 1, dtype=np.double)
retval = np.eye(dim + 1, dtype=np.double)
U, S, V = np.linalg.svd(A)
# Eq. (40) and (43).
rank = np.linalg.matrix_rank(A)
if rank == 0:
return np.nan * T
return np.nan * retval
if rank == dim - 1:
if np.linalg.det(U) * np.linalg.det(V) > 0:
T[:dim, :dim] = U @ V
retval[:dim, :dim] = U @ V
else:
s = d[dim - 1]
d[dim - 1] = -1
T[:dim, :dim] = U @ np.diag(d) @ V
retval[:dim, :dim] = U @ np.diag(d) @ V
d[dim - 1] = s
else:
T[:dim, :dim] = U @ np.diag(d) @ V
retval[:dim, :dim] = U @ np.diag(d) @ V
if estimate_scale:
# Eq. (41) and (42).
@ -891,7 +883,7 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
else:
scale = 1.0
T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
T[:dim, :dim] *= scale
retval[:dim, dim] = dst_mean - scale * (retval[:dim, :dim] @ src_mean.T)
retval[:dim, :dim] *= scale
return T
return retval

View file

@ -1,24 +1,19 @@
#!/usr/bin/env python3
""" Alignments file functions for reading, writing and manipulating the data stored in a
serialized alignments file. """
from __future__ import annotations
import logging
import os
import sys
import typing as T
from datetime import datetime
from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np
from lib.serializer import get_serializer, get_serializer_from_filename
from lib.utils import FaceswapError
if sys.version_info < (3, 8):
from typing_extensions import TypedDict
else:
from typing import TypedDict
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from .aligned_face import CenteringType
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -35,49 +30,49 @@ _VERSION = 2.3
# TODO Convert these to Dataclasses
class MaskAlignmentsFileDict(TypedDict):
class MaskAlignmentsFileDict(T.TypedDict):
""" Typed Dictionary for storing Masks. """
mask: bytes
affine_matrix: Union[List[float], np.ndarray]
affine_matrix: list[float] | np.ndarray
interpolator: int
stored_size: int
stored_centering: "CenteringType"
stored_centering: CenteringType
class PNGHeaderAlignmentsDict(TypedDict):
class PNGHeaderAlignmentsDict(T.TypedDict):
""" Base Dictionary for storing a single faces' Alignment Information in Alignments files and
PNG Headers. """
x: int
y: int
w: int
h: int
landmarks_xy: Union[List[float], np.ndarray]
mask: Dict[str, MaskAlignmentsFileDict]
identity: Dict[str, List[float]]
landmarks_xy: list[float] | np.ndarray
mask: dict[str, MaskAlignmentsFileDict]
identity: dict[str, list[float]]
class AlignmentFileDict(PNGHeaderAlignmentsDict):
""" Typed Dictionary for storing a single faces' Alignment Information in alignments files. """
thumb: Optional[np.ndarray]
thumb: np.ndarray | None
class PNGHeaderSourceDict(TypedDict):
class PNGHeaderSourceDict(T.TypedDict):
""" Dictionary for storing additional meta information in PNG headers """
alignments_version: float
original_filename: str
face_index: int
source_filename: str
source_is_video: bool
source_frame_dims: Optional[Tuple[int, int]]
source_frame_dims: tuple[int, int] | None
class AlignmentDict(TypedDict):
class AlignmentDict(T.TypedDict):
""" Dictionary for holding all of the alignment information within a single alignment file """
faces: List[AlignmentFileDict]
video_meta: Dict[str, Union[float, int]]
faces: list[AlignmentFileDict]
video_meta: dict[str, float | int]
class PNGHeaderDict(TypedDict):
class PNGHeaderDict(T.TypedDict):
""" Dictionary for storing all alignment and meta information in PNG Headers """
alignments: PNGHeaderAlignmentsDict
source: PNGHeaderSourceDict
@ -135,7 +130,7 @@ class Alignments():
return self._io.file
@property
def data(self) -> Dict[str, AlignmentDict]:
def data(self) -> dict[str, AlignmentDict]:
""" dict: The loaded alignments :attr:`file` in dictionary form. """
return self._data
@ -146,7 +141,7 @@ class Alignments():
return self._io.have_alignments_file
@property
def hashes_to_frame(self) -> Dict[str, Dict[str, int]]:
def hashes_to_frame(self) -> dict[str, dict[str, int]]:
""" dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
that the hash corresponds to.
@ -158,7 +153,7 @@ class Alignments():
return self._legacy.hashes_to_frame
@property
def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]:
def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
""" dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
corresponds to. The structure of the dictionary is:
@ -170,10 +165,10 @@ class Alignments():
return self._legacy.hashes_to_alignment
@property
def mask_summary(self) -> Dict[str, int]:
def mask_summary(self) -> dict[str, int]:
""" dict: The mask type names stored in the alignments :attr:`data` as key with the number
of faces which possess the mask type as value. """
masks: Dict[str, int] = {}
masks: dict[str, int] = {}
for val in self._data.values():
for face in val["faces"]:
if face.get("mask", None) is None:
@ -183,21 +178,20 @@ class Alignments():
return masks
@property
def video_meta_data(self) -> Dict[str, Optional[Union[List[int], List[float]]]]:
def video_meta_data(self) -> dict[str, list[int] | list[float] | None]:
""" dict: The frame meta data stored in the alignments file. If data does not exist in the
alignments file then ``None`` is returned for each Key """
retval: Dict[str, Optional[Union[List[int],
List[float]]]] = dict(pts_time=None, keyframes=None)
pts_time: List[float] = []
keyframes: List[int] = []
retval: dict[str, list[int] | list[float] | None] = {"pts_time": None, "keyframes": None}
pts_time: list[float] = []
keyframes: list[int] = []
for idx, key in enumerate(sorted(self.data)):
if not self.data[key].get("video_meta", {}):
return retval
meta = self.data[key]["video_meta"]
pts_time.append(cast(float, meta["pts_time"]))
pts_time.append(T.cast(float, meta["pts_time"]))
if meta["keyframe"]:
keyframes.append(idx)
retval = dict(pts_time=pts_time, keyframes=keyframes)
retval = {"pts_time": pts_time, "keyframes": keyframes}
return retval
@property
@ -211,7 +205,7 @@ class Alignments():
""" float: The alignments file version number. """
return self._io.version
def _load(self) -> Dict[str, AlignmentDict]:
def _load(self) -> dict[str, AlignmentDict]:
""" Load the alignments data from the serialized alignments :attr:`file`.
Populates :attr:`_version` with the alignment file's loaded version as well as returning
@ -238,7 +232,7 @@ class Alignments():
"""
return self._io.backup()
def save_video_meta_data(self, pts_time: List[float], keyframes: List[int]) -> None:
def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None:
""" Save video meta data to the alignments file.
If the alignments file does not have an entry for every frame (e.g. if Extract Every N
@ -262,10 +256,10 @@ class Alignments():
logger.info("Saving video meta information to Alignments file")
for idx, pts in enumerate(pts_time):
meta: Dict[str, Union[float, int]] = dict(pts_time=pts, keyframe=idx in keyframes)
meta: dict[str, float | int] = {"pts_time": pts, "keyframe": idx in keyframes}
key = f"{basename}_{idx + 1:06d}.png"
if key not in self.data:
self.data[key] = dict(video_meta=meta, faces=[])
self.data[key] = {"video_meta": meta, "faces": []}
else:
self.data[key]["video_meta"] = meta
@ -285,8 +279,8 @@ class Alignments():
self._io.save()
@classmethod
def _pad_leading_frames(cls, pts_time: List[float], keyframes: List[int]) -> Tuple[List[float],
List[int]]:
def _pad_leading_frames(cls, pts_time: list[float], keyframes: list[int]) -> tuple[list[float],
list[int]]:
""" Calculate the number of frames to pad the video by when the first frame is not
a key frame.
@ -310,7 +304,7 @@ class Alignments():
"""
start_pts = pts_time[0]
logger.debug("Video not cut on keyframe. Start pts: %s", start_pts)
gaps: List[float] = []
gaps: list[float] = []
prev_time = None
for item in pts_time:
if prev_time is not None:
@ -360,7 +354,7 @@ class Alignments():
``True`` if the given frame_name exists within the alignments :attr:`data` and has at
least 1 face associated with it, otherwise ``False``
"""
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
retval = bool(frame_data.get("faces", []))
logger.trace("'%s': %s", frame_name, retval) # type:ignore
return retval
@ -384,7 +378,7 @@ class Alignments():
if not frame_name:
retval = False
else:
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
retval = bool(len(frame_data.get("faces", [])) > 1)
logger.trace("'%s': %s", frame_name, retval) # type:ignore
return retval
@ -414,7 +408,7 @@ class Alignments():
return retval
# << DATA >> #
def get_faces_in_frame(self, frame_name: str) -> List[AlignmentFileDict]:
def get_faces_in_frame(self, frame_name: str) -> list[AlignmentFileDict]:
""" Obtain the faces from :attr:`data` associated with a given frame_name.
Parameters
@ -429,8 +423,8 @@ class Alignments():
The list of face dictionaries that appear within the requested frame_name
"""
logger.trace("Getting faces for frame_name: '%s'", frame_name) # type:ignore
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
return frame_data.get("faces", cast(List[AlignmentFileDict], []))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
return frame_data.get("faces", T.cast(list[AlignmentFileDict], []))
def _count_faces_in_frame(self, frame_name: str) -> int:
""" Return number of faces that appear within :attr:`data` for the given frame_name.
@ -446,7 +440,7 @@ class Alignments():
int
The number of faces that appear in the given frame_name
"""
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
retval = len(frame_data.get("faces", []))
logger.trace(retval) # type:ignore
return retval
@ -497,7 +491,7 @@ class Alignments():
"""
logger.debug("Adding face to frame_name: '%s'", frame_name)
if frame_name not in self._data:
self._data[frame_name] = dict(faces=[], video_meta={})
self._data[frame_name] = {"faces": [], "video_meta": {}}
self._data[frame_name]["faces"].append(face)
retval = self._count_faces_in_frame(frame_name) - 1
logger.debug("Returning new face index: %s", retval)
@ -520,7 +514,7 @@ class Alignments():
logger.debug("Updating face %s for frame_name '%s'", face_index, frame_name)
self._data[frame_name]["faces"][face_index] = face
def filter_faces(self, filter_dict: Dict[str, List[int]], filter_out: bool = False) -> None:
def filter_faces(self, filter_dict: dict[str, list[int]], filter_out: bool = False) -> None:
""" Remove faces from :attr:`data` based on a given filter list.
Parameters
@ -549,7 +543,7 @@ class Alignments():
del frame_data["faces"][face_idx]
# << GENERATORS >> #
def yield_faces(self) -> Generator[Tuple[str, List[AlignmentFileDict], int, str], None, None]:
def yield_faces(self) -> Generator[tuple[str, list[AlignmentFileDict], int, str], None, None]:
""" Generator to obtain all faces with meta information from :attr:`data`. The results
are yielded by frame.
@ -715,7 +709,7 @@ class _IO():
logger.info("Updating alignments file to version %s", self._version)
self.save()
def load(self) -> Dict[str, AlignmentDict]:
def load(self) -> dict[str, AlignmentDict]:
""" Load the alignments data from the serialized alignments :attr:`file`.
Populates :attr:`_version` with the alignment file's loaded version as well as returning
@ -732,7 +726,7 @@ class _IO():
logger.info("Reading alignments from: '%s'", self._file)
data = self._serializer.load(self._file)
meta = data.get("__meta__", dict(version=1.0))
meta = data.get("__meta__", {"version": 1.0})
self._version = meta["version"]
data = data.get("__data__", data)
logger.debug("Loaded alignments")
@ -743,8 +737,8 @@ class _IO():
the location :attr:`file`. """
logger.debug("Saving alignments")
logger.info("Writing alignments to: '%s'", self._file)
data = dict(__meta__=dict(version=self._version),
__data__=self._alignments.data)
data = {"__meta__": {"version": self._version},
"__data__": self._alignments.data}
self._serializer.save(self._file, data)
logger.debug("Saved alignments")
@ -928,7 +922,7 @@ class _FileStructure(_Updater):
for key, val in self._alignments.data.items():
if not isinstance(val, list):
continue
self._alignments.data[key] = dict(faces=val)
self._alignments.data[key] = {"faces": val}
updated += 1
return updated
@ -1078,11 +1072,11 @@ class _Legacy():
"""
def __init__(self, alignments: Alignments) -> None:
self._alignments = alignments
self._hashes_to_frame: Dict[str, Dict[str, int]] = {}
self._hashes_to_alignment: Dict[str, AlignmentFileDict] = {}
self._hashes_to_frame: dict[str, dict[str, int]] = {}
self._hashes_to_alignment: dict[str, AlignmentFileDict] = {}
@property
def hashes_to_frame(self) -> Dict[str, Dict[str, int]]:
def hashes_to_frame(self) -> dict[str, dict[str, int]]:
""" dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
that the hash corresponds to. The structure of the dictionary is:
@ -1105,7 +1099,7 @@ class _Legacy():
return self._hashes_to_frame
@property
def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]:
def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
""" dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
corresponds to. The structure of the dictionary is:

View file

@ -1,12 +1,11 @@
#!/usr/bin python3
""" Face and landmarks detection for faceswap.py """
from __future__ import annotations
import logging
import sys
import os
import typing as T
from hashlib import sha1
from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from zlib import compress, decompress
import cv2
@ -18,14 +17,10 @@ from .alignments import (Alignments, AlignmentFileDict, MaskAlignmentsFileDict,
PNGHeaderAlignmentsDict, PNGHeaderDict, PNGHeaderSourceDict)
from . import AlignedFace, get_adjusted_center, get_centered_size
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Callable
from .aligned_face import CenteringType
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -85,14 +80,14 @@ class DetectedFace():
dict of {**name** (`str`): :class:`Mask`}.
"""
def __init__(self,
image: Optional[np.ndarray] = None,
left: Optional[int] = None,
width: Optional[int] = None,
top: Optional[int] = None,
height: Optional[int] = None,
landmarks_xy: Optional[np.ndarray] = None,
mask: Optional[Dict[str, "Mask"]] = None,
filename: Optional[str] = None) -> None:
image: np.ndarray | None = None,
left: int | None = None,
width: int | None = None,
top: int | None = None,
height: int | None = None,
landmarks_xy: np.ndarray | None = None,
mask: dict[str, "Mask"] | None = None,
filename: str | None = None) -> None:
logger.trace("Initializing %s: (image: %s, left: %s, width: %s, top: %s, " # type: ignore
"height: %s, landmarks_xy: %s, mask: %s, filename: %s)",
self.__class__.__name__,
@ -104,12 +99,12 @@ class DetectedFace():
self.top = top
self.height = height
self._landmarks_xy = landmarks_xy
self._identity: Dict[str, np.ndarray] = {}
self.thumbnail: Optional[np.ndarray] = None
self._identity: dict[str, np.ndarray] = {}
self.thumbnail: np.ndarray | None = None
self.mask = {} if mask is None else mask
self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None
self._training_masks: tuple[bytes, tuple[int, int, int]] | None = None
self._aligned: Optional[AlignedFace] = None
self._aligned: AlignedFace | None = None
logger.trace("Initialized %s", self.__class__.__name__) # type: ignore
@property
@ -137,7 +132,7 @@ class DetectedFace():
return self.top + self.height
@property
def identity(self) -> Dict[str, np.ndarray]:
def identity(self) -> dict[str, np.ndarray]:
""" dict: Identity mechanism as key, identity embedding as value. """
return self._identity
@ -147,7 +142,7 @@ class DetectedFace():
affine_matrix: np.ndarray,
interpolator: int,
storage_size: int = 128,
storage_centering: "CenteringType" = "face") -> None:
storage_centering: CenteringType = "face") -> None:
""" Add a :class:`Mask` to this detected face
The mask should be the original output from :mod:`plugins.extract.mask`
@ -209,7 +204,7 @@ class DetectedFace():
self._identity[name] = embedding
def get_landmark_mask(self,
area: Literal["eye", "face", "mouth"],
area: T.Literal["eye", "face", "mouth"],
blur_kernel: int,
dilation: int) -> np.ndarray:
""" Add a :class:`LandmarksMask` to this detected face
@ -235,7 +230,7 @@ class DetectedFace():
"""
# TODO Face mask generation from landmarks
logger.trace("area: %s, dilation: %s", area, dilation) # type: ignore
areas = dict(mouth=[slice(48, 60)], eye=[slice(36, 42), slice(42, 48)])
areas = {"mouth": [slice(48, 60)], "eye": [slice(36, 42), slice(42, 48)]}
points = [self.aligned.landmarks[zone]
for zone in areas[area]]
@ -250,7 +245,7 @@ class DetectedFace():
return lmmask.mask
def store_training_masks(self,
masks: List[Optional[np.ndarray]],
masks: list[np.ndarray | None],
delete_masks: bool = False) -> None:
""" Concatenate and compress the given training masks and store for retrieval.
@ -273,7 +268,7 @@ class DetectedFace():
combined = np.concatenate(valid, axis=-1)
self._training_masks = (compress(combined), combined.shape)
def get_training_masks(self) -> Optional[np.ndarray]:
def get_training_masks(self) -> np.ndarray | None:
""" Obtain the decompressed combined training masks.
Returns
@ -312,7 +307,7 @@ class DetectedFace():
return alignment
def from_alignment(self, alignment: AlignmentFileDict,
image: Optional[np.ndarray] = None, with_thumb: bool = False) -> None:
image: np.ndarray | None = None, with_thumb: bool = False) -> None:
""" Set the attributes of this class from an alignments file and optionally load the face
into the ``image`` attribute.
@ -342,7 +337,7 @@ class DetectedFace():
landmarks = alignment["landmarks_xy"]
if not isinstance(landmarks, np.ndarray):
landmarks = np.array(landmarks, dtype="float32")
self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32")
self._identity = {T.cast(T.Literal["vggface2"], k): np.array(v, dtype="float32")
for k, v in alignment.get("identity", {}).items()}
self._landmarks_xy = landmarks.copy()
@ -403,7 +398,7 @@ class DetectedFace():
self._identity = {}
for key, val in alignment.get("identity", {}).items():
assert key in ["vggface2"]
self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32")
self._identity[T.cast(T.Literal["vggface2"], key)] = np.array(val, dtype="float32")
logger.trace("Created from png exif header: (left: %s, width: %s, top: %s " # type: ignore
" height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width,
self.top, self.height, self.landmarks_xy, self.mask,
@ -417,10 +412,10 @@ class DetectedFace():
# <<< Aligned Face methods and properties >>> #
def load_aligned(self,
image: Optional[np.ndarray],
image: np.ndarray | None,
size: int = 256,
dtype: Optional[str] = None,
centering: "CenteringType" = "head",
dtype: str | None = None,
centering: CenteringType = "head",
coverage_ratio: float = 1.0,
force: bool = False,
is_aligned: bool = False,
@ -507,22 +502,22 @@ class Mask():
"""
def __init__(self,
storage_size: int = 128,
storage_centering: "CenteringType" = "face") -> None:
storage_centering: CenteringType = "face") -> None:
logger.trace("Initializing: %s (storage_size: %s, storage_centering: %s)", # type: ignore
self.__class__.__name__, storage_size, storage_centering)
self.stored_size = storage_size
self.stored_centering = storage_centering
self._mask: Optional[bytes] = None
self._affine_matrix: Optional[np.ndarray] = None
self._interpolator: Optional[int] = None
self._mask: bytes | None = None
self._affine_matrix: np.ndarray | None = None
self._interpolator: int | None = None
self._blur_type: Optional[Literal["gaussian", "normalized"]] = None
self._blur_type: T.Literal["gaussian", "normalized"] | None = None
self._blur_passes: int = 0
self._blur_kernel: Union[float, int] = 0
self._blur_kernel: float | int = 0
self._threshold = 0.0
self._sub_crop_size = 0
self._sub_crop_slices: Dict[Literal["in", "out"], List[slice]] = {}
self._sub_crop_slices: dict[T.Literal["in", "out"], list[slice]] = {}
self.set_blur_and_threshold()
logger.trace("Initialized: %s", self.__class__.__name__) # type: ignore
@ -648,7 +643,7 @@ class Mask():
def set_blur_and_threshold(self,
blur_kernel: int = 0,
blur_type: Optional[Literal["gaussian", "normalized"]] = "gaussian",
blur_type: T.Literal["gaussian", "normalized"] | None = "gaussian",
blur_passes: int = 1,
threshold: int = 0) -> None:
""" Set the internal blur kernel and threshold amount for returned masks
@ -679,7 +674,7 @@ class Mask():
def set_sub_crop(self,
source_offset: np.ndarray,
target_offset: np.ndarray,
centering: "CenteringType",
centering: CenteringType,
coverage_ratio: float = 1.0) -> None:
""" Set the internal crop area of the mask to be returned.
@ -831,9 +826,9 @@ class LandmarksMask(Mask):
The amount of dilation to apply to the mask. `0` for none. Default: `0`
"""
def __init__(self,
points: List[np.ndarray],
points: list[np.ndarray],
storage_size: int = 128,
storage_centering: "CenteringType" = "face",
storage_centering: CenteringType = "face",
dilation: int = 0) -> None:
super().__init__(storage_size=storage_size, storage_centering=storage_centering)
self._points = points
@ -907,9 +902,9 @@ class BlurMask(): # pylint:disable=too-few-public-methods
(128, 128, 1)
"""
def __init__(self,
blur_type: Literal["gaussian", "normalized"],
blur_type: T.Literal["gaussian", "normalized"],
mask: np.ndarray,
kernel: Union[int, float],
kernel: int | float,
is_ratio: bool = False,
passes: int = 1) -> None:
logger.trace("Initializing %s: (blur_type: '%s', mask_shape: %s, " # type: ignore
@ -943,33 +938,30 @@ class BlurMask(): # pylint:disable=too-few-public-methods
def _multipass_factor(self) -> float:
""" For multiple passes the kernel must be scaled down. This value is
different for box filter and gaussian """
factor = dict(gaussian=0.8, normalized=0.5)
factor = {"gaussian": 0.8, "normalized": 0.5}
return factor[self._blur_type]
@property
def _sigma(self) -> Literal[0]:
def _sigma(self) -> T.Literal[0]:
""" int: The Sigma for Gaussian Blur. Returns 0 to force calculation from kernel size. """
return 0
@property
def _func_mapping(self) -> Dict[Literal["gaussian", "normalized"], Callable]:
def _func_mapping(self) -> dict[T.Literal["gaussian", "normalized"], Callable]:
""" dict: :attr:`_blur_type` mapped to cv2 Function name. """
return dict(gaussian=cv2.GaussianBlur, # pylint: disable = no-member
normalized=cv2.blur) # pylint: disable = no-member
return {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur}
@property
def _kwarg_requirements(self) -> Dict[Literal["gaussian", "normalized"], List[str]]:
def _kwarg_requirements(self) -> dict[T.Literal["gaussian", "normalized"], list[str]]:
""" dict: :attr:`_blur_type` mapped to cv2 Function required keyword arguments. """
return dict(gaussian=["ksize", "sigmaX"],
normalized=["ksize"])
return {"gaussian": ['ksize', 'sigmaX'], "normalized": ['ksize']}
@property
def _kwarg_mapping(self) -> Dict[str, Union[int, Tuple[int, int]]]:
def _kwarg_mapping(self) -> dict[str, int | tuple[int, int]]:
""" dict: cv2 function keyword arguments mapped to their parameters. """
return dict(ksize=self._kernel_size,
sigmaX=self._sigma)
return {"ksize": self._kernel_size, "sigmaX": self._sigma}
def _get_kernel_size(self, kernel: Union[int, float], is_ratio: bool) -> int:
def _get_kernel_size(self, kernel: int | float, is_ratio: bool) -> int:
""" Set the kernel size to absolute value.
If :attr:`is_ratio` is ``True`` then the kernel size is calculated from the given ratio and
@ -999,7 +991,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods
return kernel_size
@staticmethod
def _get_kernel_tuple(kernel_size: int) -> Tuple[int, int]:
def _get_kernel_tuple(kernel_size: int) -> tuple[int, int]:
""" Make sure kernel_size is odd and return it as a tuple.
Parameters
@ -1017,7 +1009,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods
logger.trace(retval) # type: ignore
return retval
def _get_kwargs(self) -> Dict[str, Union[int, Tuple[int, int]]]:
def _get_kwargs(self) -> dict[str, int | tuple[int, int]]:
""" dict: the valid keyword arguments for the requested :attr:`_blur_type` """
retval = {kword: self._kwarg_mapping[kword]
for kword in self._kwarg_requirements[self._blur_type]}
@ -1025,11 +1017,11 @@ class BlurMask(): # pylint:disable=too-few-public-methods
return retval
_HASHES_SEEN: Dict[str, Dict[str, int]] = {}
_HASHES_SEEN: dict[str, dict[str, int]] = {}
def update_legacy_png_header(filename: str, alignments: Alignments
) -> Optional[PNGHeaderDict]:
) -> PNGHeaderDict | None:
""" Update a legacy extracted face from pre v2.1 alignments by placing the alignment data for
the face in the png exif header for the given filename with the given alignment data.

View file

@ -7,7 +7,7 @@ as well as adding a mechanism for indicating to the GUI how specific options sho
import argparse
import os
from typing import Any, List, Optional, Tuple, Union
import typing as T
# << FILE HANDLING >>
@ -69,7 +69,7 @@ class FileFullPaths(_FullPaths):
>>> filetypes="video))"
"""
# pylint: disable=too-few-public-methods
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.filetypes = filetypes
@ -111,7 +111,7 @@ class FilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods
>>> filetypes="image",
>>> nargs="+"))
"""
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
if kwargs.get("nargs", None) is None:
opt = kwargs["option_strings"]
raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}")
@ -250,8 +250,8 @@ class ContextFullPaths(FileFullPaths):
# pylint: disable=too-few-public-methods, too-many-arguments
def __init__(self,
*args,
filetypes: Optional[str] = None,
action_option: Optional[str] = None,
filetypes: str | None = None,
action_option: str | None = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:
@ -263,7 +263,7 @@ class ContextFullPaths(FileFullPaths):
super().__init__(*args, filetypes=filetypes, **kwargs)
self.action_option = action_option
def _get_kwargs(self) -> List[Tuple[str, Any]]:
def _get_kwargs(self) -> list[tuple[str, T.Any]]:
names = ["option_strings",
"dest",
"nargs",
@ -382,8 +382,8 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods
"""
def __init__(self,
*args,
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
rounding: Optional[int] = None,
min_max: tuple[int, int] | tuple[float, float] | None = None,
rounding: int | None = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:
@ -401,7 +401,7 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods
self.min_max = min_max
self.rounding = rounding
def _get_kwargs(self) -> List[Tuple[str, Any]]:
def _get_kwargs(self) -> list[tuple[str, T.Any]]:
names = ["option_strings",
"dest",
"nargs",

View file

@ -8,8 +8,7 @@ import logging
import re
import sys
import textwrap
from typing import Any, Dict, List, NoReturn, Optional
import typing as T
from lib.utils import get_backend
from lib.gpu_stats import GPUStats
@ -30,7 +29,7 @@ _ = _LANG.gettext
class FullHelpArgumentParser(argparse.ArgumentParser):
""" Extends :class:`argparse.ArgumentParser` to output full help on bad arguments. """
def error(self, message: str) -> NoReturn:
def error(self, message: str) -> T.NoReturn:
self.print_help(sys.stderr)
self.exit(2, f"{self.prog}: error: {message}\n")
@ -51,11 +50,11 @@ class SmartFormatter(argparse.HelpFormatter):
prog: str,
indent_increment: int = 2,
max_help_position: int = 24,
width: Optional[int] = None) -> None:
width: int | None = None) -> None:
super().__init__(prog, indent_increment, max_help_position, width)
self._whitespace_matcher_limited = re.compile(r'[ \r\f\v]+', re.ASCII)
def _split_lines(self, text: str, width: int) -> List[str]:
def _split_lines(self, text: str, width: int) -> list[str]:
""" Split the given text by the given display width.
If the text is not prefixed with "R|" then the standard
@ -138,7 +137,7 @@ class FaceSwapArgs():
return ""
@staticmethod
def get_argument_list() -> List[Dict[str, Any]]:
def get_argument_list() -> list[dict[str, T.Any]]:
""" Returns the argument list for the current command.
The argument list should be a list of dictionaries pertaining to each option for a command.
@ -152,11 +151,11 @@ class FaceSwapArgs():
list
The list of command line options for the given command
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
return argument_list
@staticmethod
def get_optional_arguments() -> List[Dict[str, Any]]:
def get_optional_arguments() -> list[dict[str, T.Any]]:
""" Returns the optional argument list for the current command.
The optional arguments list is not always required, but is used when there are shared
@ -167,11 +166,11 @@ class FaceSwapArgs():
list
The list of optional command line options for the given command
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
return argument_list
@staticmethod
def _get_global_arguments() -> List[Dict[str, Any]]:
def _get_global_arguments() -> list[dict[str, T.Any]]:
""" Returns the global Arguments list that are required for ALL commands in Faceswap.
This method should NOT be overridden.
@ -181,7 +180,7 @@ class FaceSwapArgs():
list
The list of global command line options for all Faceswap commands.
"""
global_args: List[Dict[str, Any]] = []
global_args: list[dict[str, T.Any]] = []
if _GPUS:
global_args.append(dict(
opts=("-X", "--exclude-gpus"),
@ -302,7 +301,7 @@ class ExtractConvertArgs(FaceSwapArgs):
"""
@staticmethod
def get_argument_list() -> List[Dict[str, Any]]:
def get_argument_list() -> list[dict[str, T.Any]]:
""" Returns the argument list for shared Extract and Convert arguments.
Returns
@ -310,7 +309,7 @@ class ExtractConvertArgs(FaceSwapArgs):
list
The list of command line options shared by the Extract and Convert commands
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
argument_list.append(dict(
opts=("-i", "--input-dir"),
action=DirOrFileFullPaths,
@ -362,7 +361,7 @@ class ExtractArgs(ExtractConvertArgs):
"Extraction plugins can be configured in the 'Settings' Menu")
@staticmethod
def get_optional_arguments() -> List[Dict[str, Any]]:
def get_optional_arguments() -> list[dict[str, T.Any]]:
""" Returns the argument list unique to the Extract command.
Returns
@ -377,7 +376,7 @@ class ExtractArgs(ExtractConvertArgs):
default_detector = "s3fd"
default_aligner = "fan"
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
argument_list.append(dict(
opts=("-b", "--batch-mode"),
action="store_true",
@ -658,7 +657,7 @@ class ConvertArgs(ExtractConvertArgs):
"Conversion plugins can be configured in the 'Settings' Menu")
@staticmethod
def get_optional_arguments() -> List[Dict[str, Any]]:
def get_optional_arguments() -> list[dict[str, T.Any]]:
""" Returns the argument list unique to the Convert command.
Returns
@ -667,7 +666,7 @@ class ConvertArgs(ExtractConvertArgs):
The list of optional command line options for the Convert command
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
argument_list.append(dict(
opts=("-ref", "--reference-video"),
action=FileFullPaths,
@ -915,7 +914,7 @@ class TrainArgs(FaceSwapArgs):
"Model plugins can be configured in the 'Settings' Menu")
@staticmethod
def get_argument_list() -> List[Dict[str, Any]]:
def get_argument_list() -> list[dict[str, T.Any]]:
""" Returns the argument list for Train arguments.
Returns
@ -923,7 +922,7 @@ class TrainArgs(FaceSwapArgs):
list
The list of command line options for training
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
argument_list.append(dict(
opts=("-A", "--input-A"),
action=DirFullPaths,
@ -1180,7 +1179,7 @@ class GuiArgs(FaceSwapArgs):
""" Creates the command line arguments for the GUI. """
@staticmethod
def get_argument_list() -> List[Dict[str, Any]]:
def get_argument_list() -> list[dict[str, T.Any]]:
""" Returns the argument list for GUI arguments.
Returns
@ -1188,7 +1187,7 @@ class GuiArgs(FaceSwapArgs):
list
The list of command line options for the GUI
"""
argument_list: List[Dict[str, Any]] = []
argument_list: list[dict[str, T.Any]] = []
argument_list.append(dict(
opts=("-d", "--debug"),
action="store_true",

View file

@ -1,20 +1,22 @@
#!/usr/bin/env python3
""" Launches the correct script with the given Command Line Arguments """
from __future__ import annotations
import logging
import os
import platform
import sys
import typing as T
from importlib import import_module
from typing import Callable, TYPE_CHECKING
from lib.gpu_stats import set_exclude_devices, GPUStats
from lib.logger import crash_log, log_setup
from lib.utils import (FaceswapError, get_backend, get_tf_version,
safe_shutdown, set_backend, set_system_verbosity)
if TYPE_CHECKING:
if T.TYPE_CHECKING:
import argparse
from collections.abc import Callable
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -99,8 +101,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
FaceswapError
If Tensorflow is not found, or is not the supported version (2.10)
"""
directml_ver = rocm_ver = (2, 10)
min_ver = (2, 7)
min_ver = (2, 10)
max_ver = (2, 10)
try:
import tensorflow as tf # noqa pylint:disable=import-outside-toplevel,unused-import
@ -120,7 +121,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
self._handle_import_error(msg)
tf_ver = get_tf_version()
backend = get_backend()
if tf_ver < min_ver:
msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
f"{tf_ver} installed. Please upgrade Tensorflow.")
@ -129,14 +129,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
f"{tf_ver} installed. Please downgrade Tensorflow.")
self._handle_import_error(msg)
if backend == "directml" and tf_ver != directml_ver:
msg = (f"The supported Tensorflow version for DirectML cards is {directml_ver} but "
f"you have version {tf_ver} installed. Please install the correct version.")
self._handle_import_error(msg)
if backend == "rocm" and tf_ver != rocm_ver:
msg = (f"The supported Tensorflow version for ROCm cards is {rocm_ver} but "
f"you have version {tf_ver} installed. Please install the correct version.")
self._handle_import_error(msg)
logger.debug("Installed Tensorflow Version: %s", tf_ver)
@classmethod
@ -209,7 +201,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
"See https://support.apple.com/en-gb/HT201341")
raise FaceswapError("No display detected. GUI mode has been disabled.")
def execute_script(self, arguments: "argparse.Namespace") -> None:
def execute_script(self, arguments: argparse.Namespace) -> None:
""" Performs final set up and launches the requested :attr:`_command` with the given
command line arguments.
@ -250,7 +242,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
finally:
safe_shutdown(got_error=not success)
def _configure_backend(self, arguments: "argparse.Namespace") -> None:
def _configure_backend(self, arguments: argparse.Namespace) -> None:
""" Configure the backend.
Exclude any GPUs for use by Faceswap when requested.

View file

@ -13,7 +13,6 @@ from collections import OrderedDict
from configparser import ConfigParser
from dataclasses import dataclass
from importlib import import_module
from typing import Dict, List, Optional, Tuple, Union
from lib.utils import full_path_split
@ -21,16 +20,11 @@ from lib.utils import full_path_split
_LANG = gettext.translation("lib.config", localedir="locales", fallback=True)
_ = _LANG.gettext
# Can't type OrderedDict fully on Python 3.8 or lower
if sys.version_info < (3, 9):
OrderedDictSectionType = OrderedDict
OrderedDictItemType = OrderedDict
else:
OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
OrderedDictItemType = OrderedDict[str, "ConfigItem"]
OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
OrderedDictItemType = OrderedDict[str, "ConfigItem"]
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
ConfigValueType = Union[bool, int, float, List[str], str, None]
ConfigValueType = bool | int | float | list[str] | str | None
@dataclass
@ -60,11 +54,11 @@ class ConfigItem:
helptext: str
datatype: type
rounding: int
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]]
choices: Union[str, List[str]]
min_max: tuple[int, int] | tuple[float, float] | None
choices: str | list[str]
gui_radio: bool
fixed: bool
group: Optional[str]
group: str | None
@dataclass
@ -84,7 +78,7 @@ class ConfigSection:
class FaceswapConfig():
""" Config Items """
def __init__(self, section: Optional[str], configfile: Optional[str] = None) -> None:
def __init__(self, section: str | None, configfile: str | None = None) -> None:
""" Init Configuration
Parameters
@ -106,11 +100,11 @@ class FaceswapConfig():
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def changeable_items(self) -> Dict[str, ConfigValueType]:
def changeable_items(self) -> dict[str, ConfigValueType]:
""" Training only.
Return a dict of config items with their set values for items
that can be altered after the model has been created """
retval: Dict[str, ConfigValueType] = {}
retval: dict[str, ConfigValueType] = {}
sections = [sect for sect in self.config.sections() if sect.startswith("global")]
all_sections = sections if self.section is None else sections + [self.section]
for sect in all_sections:
@ -189,10 +183,10 @@ class FaceswapConfig():
logger.debug("Added defaults: %s", section)
@property
def config_dict(self) -> Dict[str, ConfigValueType]:
def config_dict(self) -> dict[str, ConfigValueType]:
""" dict: Collate global options and requested section into a dictionary with the correct
data types """
conf: Dict[str, ConfigValueType] = {}
conf: dict[str, ConfigValueType] = {}
sections = [sect for sect in self.config.sections() if sect.startswith("global")]
if self.section is not None:
sections.append(self.section)
@ -240,7 +234,7 @@ class FaceswapConfig():
logger.debug("Returning item: (type: %s, value: %s)", datatype, retval)
return retval
def _parse_list(self, section: str, option: str) -> List[str]:
def _parse_list(self, section: str, option: str) -> list[str]:
""" Parse options that are stored as lists in the config file. These can be space or
comma-separated items in the config file. They will be returned as a list of strings,
regardless of what the final data type should be, so conversion from strings to other
@ -268,7 +262,7 @@ class FaceswapConfig():
raw_option, retval, section, option)
return retval
def _get_config_file(self, configfile: Optional[str]) -> str:
def _get_config_file(self, configfile: str | None) -> str:
""" Return the config file from the calling folder or the provided file
Parameters
@ -309,17 +303,17 @@ class FaceswapConfig():
self.defaults[title] = ConfigSection(helptext=info, items=OrderedDict())
def add_item(self,
section: Optional[str] = None,
title: Optional[str] = None,
section: str | None = None,
title: str | None = None,
datatype: type = str,
default: ConfigValueType = None,
info: Optional[str] = None,
rounding: Optional[int] = None,
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
choices: Optional[Union[str, List[str]]] = None,
info: str | None = None,
rounding: int | None = None,
min_max: tuple[int, int] | tuple[float, float] | None = None,
choices: str | list[str] | None = None,
gui_radio: bool = False,
fixed: bool = True,
group: Optional[str] = None) -> None:
group: str | None = None) -> None:
""" Add a default item to a config section
For int or float values, rounding and min_max must be set
@ -382,10 +376,10 @@ class FaceswapConfig():
@classmethod
def _expand_helptext(cls,
helptext: str,
choices: Union[str, List[str]],
choices: str | list[str],
default: ConfigValueType,
datatype: type,
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]],
min_max: tuple[int, int] | tuple[float, float] | None,
fixed: bool) -> str:
""" Add extra helptext info from parameters """
helptext += "\n"
@ -437,7 +431,7 @@ class FaceswapConfig():
def insert_config_section(self,
section: str,
helptext: str,
config: Optional[ConfigParser] = None) -> None:
config: ConfigParser | None = None) -> None:
""" Insert a section into the config
Parameters
@ -464,7 +458,7 @@ class FaceswapConfig():
item: str,
default: ConfigValueType,
option: ConfigItem,
config: Optional[ConfigParser] = None) -> None:
config: ConfigParser | None = None) -> None:
""" Insert an item into a config section
Parameters

View file

@ -1,23 +1,18 @@
#!/usr/bin/env python3
""" Converter for Faceswap """
from __future__ import annotations
import logging
import sys
import typing as T
from dataclasses import dataclass
from typing import Callable, cast, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2
import numpy as np
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Callable
from lib.align.aligned_face import AlignedFace, CenteringType
from lib.align.detected_face import DetectedFace
from lib.config import FaceswapConfig
@ -46,10 +41,10 @@ class Adjustments:
sharpening: :class:`~plugins.scaling._base.Adjustment`, Optional
The selected sharpening plugin. Default: `None`
"""
color: Optional["ColorAdjust"] = None
mask: Optional["MaskAdjust"] = None
seamless: Optional["SeamlessAdjust"] = None
sharpening: Optional["ScalingAdjust"] = None
color: ColorAdjust | None = None
mask: MaskAdjust | None = None
seamless: SeamlessAdjust | None = None
sharpening: ScalingAdjust | None = None
class Converter():
@ -81,11 +76,11 @@ class Converter():
def __init__(self,
output_size: int,
coverage_ratio: float,
centering: "CenteringType",
centering: CenteringType,
draw_transparent: bool,
pre_encode: Optional[Callable[[np.ndarray], List[bytes]]],
arguments: "Namespace",
configfile: Optional[str] = None) -> None:
pre_encode: Callable[[np.ndarray], list[bytes]] | None,
arguments: Namespace,
configfile: str | None = None) -> None:
logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, "
"draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)",
self.__class__.__name__, output_size, coverage_ratio, centering,
@ -105,12 +100,12 @@ class Converter():
logger.debug("Initialized %s", self.__class__.__name__)
@property
def cli_arguments(self) -> "Namespace":
def cli_arguments(self) -> Namespace:
""":class:`argparse.Namespace`: The command line arguments passed to the convert
process """
return self._args
def reinitialize(self, config: "FaceswapConfig") -> None:
def reinitialize(self, config: FaceswapConfig) -> None:
""" Reinitialize this :class:`Converter`.
Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the
@ -127,7 +122,7 @@ class Converter():
logger.debug("Reinitialized converter")
def _load_plugins(self,
config: Optional["FaceswapConfig"] = None,
config: FaceswapConfig | None = None,
disable_logging: bool = False) -> None:
""" Load the requested adjustment plugins.
@ -169,7 +164,7 @@ class Converter():
self._adjustments.sharpening = sharpening
logger.debug("Loaded plugins: %s", self._adjustments)
def process(self, in_queue: "EventQueue", out_queue: "EventQueue"):
def process(self, in_queue: EventQueue, out_queue: EventQueue):
""" Main convert process.
Takes items from the in queue, runs the relevant adjustments, patches faces to final frame
@ -188,7 +183,7 @@ class Converter():
in_queue, out_queue)
log_once = False
while True:
inbound: Union[Literal["EOF"], "ConvertItem", List["ConvertItem"]] = in_queue.get()
inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get()
if inbound == "EOF":
logger.debug("EOF Received")
logger.debug("Patch queue finished")
@ -218,7 +213,7 @@ class Converter():
out_queue.put((item.inbound.filename, image))
logger.debug("Completed convert process")
def _patch_image(self, predicted: "ConvertItem") -> Union[np.ndarray, List[bytes]]:
def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]:
""" Patch a swapped face onto a frame.
Run selected adjustments and swap the faces in a frame.
@ -246,15 +241,15 @@ class Converter():
out=np.empty(patched_face.shape, dtype="uint8"),
casting='unsafe')
if self._writer_pre_encode is None:
retval: Union[np.ndarray, List[bytes]] = patched_face
retval: np.ndarray | list[bytes] = patched_face
else:
retval = self._writer_pre_encode(patched_face)
logger.trace("Patched image: '%s'", predicted.inbound.filename) # type: ignore
return retval
def _get_new_image(self,
predicted: "ConvertItem",
frame_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
predicted: ConvertItem,
frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]:
""" Get the new face from the predictor and apply pre-warp manipulations.
Applies any requested adjustments to the raw output of the Faceswap model
@ -308,9 +303,9 @@ class Converter():
def _pre_warp_adjustments(self,
new_face: np.ndarray,
detected_face: "DetectedFace",
reference_face: "AlignedFace",
predicted_mask: Optional[np.ndarray]) -> np.ndarray:
detected_face: DetectedFace,
reference_face: AlignedFace,
predicted_mask: np.ndarray | None) -> np.ndarray:
""" Run any requested adjustments that can be performed on the raw output from the Faceswap
model.
@ -337,7 +332,7 @@ class Converter():
"""
logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore
new_face.shape, predicted_mask.shape if predicted_mask is not None else None)
old_face = cast(np.ndarray, reference_face.face)[..., :3] / 255.0
old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0
new_face, raw_mask = self._get_image_mask(new_face,
detected_face,
predicted_mask,
@ -351,9 +346,9 @@ class Converter():
def _get_image_mask(self,
new_face: np.ndarray,
detected_face: "DetectedFace",
predicted_mask: Optional[np.ndarray],
reference_face: "AlignedFace") -> Tuple[np.ndarray, np.ndarray]:
detected_face: DetectedFace,
predicted_mask: np.ndarray | None,
reference_face: AlignedFace) -> tuple[np.ndarray, np.ndarray]:
""" Return any selected image mask
Places the requested mask into the new face's Alpha channel.

View file

@ -5,11 +5,10 @@ from the :class:`_GPUStats` class contained here. """
import logging
from dataclasses import dataclass
from typing import List, Optional
from lib.utils import get_backend
_EXCLUDE_DEVICES: List[int] = []
_EXCLUDE_DEVICES: list[int] = []
@dataclass
@ -29,11 +28,11 @@ class GPUInfo():
devices_active: list[int]
List of integers representing the indices of the active GPU devices.
"""
vram: List[int]
vram_free: List[int]
vram: list[int]
vram_free: list[int]
driver: str
devices: List[str]
devices_active: List[int]
devices: list[str]
devices_active: list[int]
@dataclass
@ -57,7 +56,7 @@ class BiggestGPUInfo():
total: float
def set_exclude_devices(devices: List[int]) -> None:
def set_exclude_devices(devices: list[int]) -> None:
""" Add any explicitly selected GPU devices to the global list of devices to be excluded
from use by Faceswap.
@ -89,19 +88,19 @@ class _GPUStats():
def __init__(self, log: bool = True) -> None:
# Logger is held internally, as we don't want to log when obtaining system stats on crash
# or when querying the backend for command line options
self._logger: Optional[logging.Logger] = logging.getLogger(__name__) if log else None
self._logger: logging.Logger | None = logging.getLogger(__name__) if log else None
self._log("debug", f"Initializing {self.__class__.__name__}")
self._is_initialized = False
self._initialize()
self._device_count: int = self._get_device_count()
self._active_devices: List[int] = self._get_active_devices()
self._active_devices: list[int] = self._get_active_devices()
self._handles: list = self._get_handles()
self._driver: str = self._get_driver()
self._device_names: List[str] = self._get_device_names()
self._vram: List[int] = self._get_vram()
self._vram_free: List[int] = self._get_free_vram()
self._device_names: list[str] = self._get_device_names()
self._vram: list[int] = self._get_vram()
self._vram_free: list[int] = self._get_free_vram()
if get_backend() != "cpu" and not self._active_devices:
self._log("warning", "No GPU detected")
@ -115,7 +114,7 @@ class _GPUStats():
return self._device_count
@property
def cli_devices(self) -> List[str]:
def cli_devices(self) -> list[str]:
""" list[str]: Formatted index: name text string for each GPU """
return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)]
@ -167,7 +166,7 @@ class _GPUStats():
"""
raise NotImplementedError()
def _get_active_devices(self) -> List[int]:
def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded in
the command line arguments).
@ -204,7 +203,7 @@ class _GPUStats():
"""
raise NotImplementedError()
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Override to obtain the names of all connected GPUs. The quality of this information
depends on the backend and OS being used, but it should be sufficient for identifying
cards.
@ -217,7 +216,7 @@ class _GPUStats():
"""
raise NotImplementedError()
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Override to obtain the total VRAM in Megabytes for each connected GPU.
Returns
@ -228,7 +227,7 @@ class _GPUStats():
"""
raise NotImplementedError()
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Override to obtain the amount of VRAM that is available, in Megabytes, for each
connected GPU.

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """
from typing import Any, List
import typing as T
import os
import psutil
@ -35,7 +35,7 @@ class AppleSiliconStats(_GPUStats):
"""
def __init__(self, log: bool = True) -> None:
# Following attribute set in :func:``_initialize``
self._tf_devices: List[Any] = []
self._tf_devices: list[T.Any] = []
super().__init__(log=log)
@ -142,7 +142,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of available Apple Silicon SoC(s) as identified in
:attr:`_handles`.
@ -155,7 +155,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in
:attr:`_handles`.
@ -175,7 +175,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"SoC RAM: {vram}")
return vram
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each available Apple
Silicon SoC.

View file

@ -1,9 +1,5 @@
#!/usr/bin/env python3
""" Dummy functions for running faceswap on CPU. """
from typing import List
from ._base import _GPUStats
@ -65,7 +61,7 @@ class CPUStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns
@ -73,11 +69,11 @@ class CPUStats(_GPUStats):
list
An empty list for CPU backends
"""
names: List[str] = []
names: list[str] = []
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the RAM in Megabytes for the running system.
Returns
@ -85,11 +81,11 @@ class CPUStats(_GPUStats):
list
An empty list for CPU backends
"""
vram: List[int] = []
vram: list[int] = []
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of RAM that is available, in Megabytes, for the running system.
Returns
@ -97,6 +93,6 @@ class CPUStats(_GPUStats):
list
An empty list for CPU backends
"""
vram: List[int] = []
vram: list[int] = []
self._log("debug", f"GPU VRAM free: {vram}")
return vram

View file

@ -1,19 +1,23 @@
#!/usr/bin/env python3
""" Collects and returns Information on DirectX 12 hardware devices for DirectML. """
from __future__ import annotations
import os
import sys
import typing as T
assert sys.platform == "win32"
import ctypes
from ctypes import POINTER, Structure, windll
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Any, Callable, cast, List
from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error
from ._base import _GPUStats
if T.TYPE_CHECKING:
from collections.abc import Callable
# Monkey patch default ctypes.c_uint32 value to Enum ctypes property for easier tracking of types
# We can't just subclass as the attribute will be assumed to be part of the Enumeration, so we
# attach it directly and suck up the typing errors.
@ -314,7 +318,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._adapters = self._get_adapters()
self._devices = self._process_adapters()
self._valid_adaptors: List[Device] = []
self._valid_adaptors: list[Device] = []
self._log("debug", f"Initialized {self.__class__.__name__}")
def _get_factory(self) -> ctypes._Pointer:
@ -334,12 +338,12 @@ class Adapters(): # pylint:disable=too-few-public-methods
factory_func.restype = HRESULT
handle = ctypes.c_void_p(0)
factory_func(IDXGIFactory6._iid_, ctypes.byref(handle)) # pylint:disable=protected-access
retval = ctypes.POINTER(IDXGIFactory6)(cast(IDXGIFactory6, handle.value))
retval = ctypes.POINTER(IDXGIFactory6)(T.cast(IDXGIFactory6, handle.value))
self._log("debug", f"factory: {retval}")
return retval
@property
def valid_adapters(self) -> List[Device]:
def valid_adapters(self) -> list[Device]:
""" list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """
if self._valid_adaptors:
return self._valid_adaptors
@ -354,7 +358,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._log("debug", f"valid_adaptors: {self._valid_adaptors}")
return self._valid_adaptors
def _get_adapters(self) -> List[ctypes._Pointer]:
def _get_adapters(self) -> list[ctypes._Pointer]:
""" Obtain DirectX 12 supporting hardware adapter objects and add a Device class for
obtaining details
@ -376,7 +380,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
if success != 0:
raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: "
f"{hex(ctypes.c_ulong(success).value)}")
adapter = POINTER(IDXGIAdapter3)(cast(IDXGIAdapter3, handle.value))
adapter = POINTER(IDXGIAdapter3)(T.cast(IDXGIAdapter3, handle.value))
self._log("debug", f"found adapter: {adapter}")
retval.append(adapter)
except COMError as err:
@ -392,7 +396,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._log("debug", f"adapters: {retval}")
return retval
def _query_adapter(self, func: Callable[[Any], Any], *args: Any) -> None:
def _query_adapter(self, func: Callable[[T.Any], T.Any], *args: T.Any) -> None:
""" Query an adapter function, logging if the HRESULT is not a success
Parameters
@ -430,7 +434,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
LookupGUID.ID3D12Device)
return success in (0, 1)
def _process_adapters(self) -> List[Device]:
def _process_adapters(self) -> list[Device]:
""" Process the adapters to add discovered information.
Returns
@ -485,21 +489,21 @@ class DirectML(_GPUStats):
Default: ``True``
"""
def __init__(self, log: bool = True) -> None:
self._devices: List[Device] = []
self._devices: list[Device] = []
super().__init__(log=log)
@property
def _all_vram(self) -> List[int]:
def _all_vram(self) -> list[int]:
""" list: The VRAM of each GPU device that the DX API has discovered. """
return [int(device.description.DedicatedVideoMemory / (1024 * 1024))
for device in self._devices]
@property
def names(self) -> List[str]:
def names(self) -> list[str]:
""" list: The name of each GPU device that the DX API has discovered. """
return [device.description.Description for device in self._devices]
def _get_active_devices(self) -> List[int]:
def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments).
@ -517,7 +521,7 @@ class DirectML(_GPUStats):
self._log("debug", f"Active GPU Devices: {devices}")
return devices
def _get_devices(self) -> List[Device]:
def _get_devices(self) -> list[Device]:
""" Obtain all detected DX API devices.
Returns
@ -582,7 +586,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU Drivers: {drivers}")
return drivers
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns
@ -594,7 +598,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in
:attr:`_handles`.
@ -607,7 +611,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX
12 supporting GPU.

View file

@ -1,7 +1,6 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs. """
import os
from typing import List
import pynvml
@ -83,7 +82,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Device count: {retval}")
return retval
def _get_active_devices(self) -> List[int]:
def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments).
@ -130,7 +129,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.
Returns
@ -143,7 +142,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.
@ -157,7 +156,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

View file

@ -1,7 +1,5 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs connected to Apple Macs. """
from typing import List
import pynvx
from lib.utils import FaceswapError
@ -92,7 +90,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.
Returns
@ -105,7 +103,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.
@ -120,7 +118,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

View file

@ -10,7 +10,6 @@ It is a good starting point but may need to be refined over time
import os
import re
from subprocess import run
from typing import List
from ._base import _GPUStats
@ -221,7 +220,7 @@ class ROCm(_GPUStats):
"""
def __init__(self, log: bool = True) -> None:
self._vendor_id = "0x1002" # AMD VendorID
self._sysfs_paths: List[str] = []
self._sysfs_paths: list[str] = []
super().__init__(log=log)
def _from_sysfs_file(self, path: str) -> str:
@ -249,7 +248,7 @@ class ROCm(_GPUStats):
val = ""
return val
def _get_sysfs_paths(self) -> List[str]:
def _get_sysfs_paths(self) -> list[str]:
""" Obtain a list of sysfs paths to AMD branded GPUs connected to the system
Returns
@ -259,7 +258,7 @@ class ROCm(_GPUStats):
"""
base_dir = "/sys/class/drm/"
retval: List[str] = []
retval: list[str] = []
if not os.path.exists(base_dir):
self._log("warning", f"sysfs not found at '{base_dir}'")
return retval
@ -347,7 +346,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU Drivers: {retval}")
return retval
def _get_device_names(self) -> List[str]:
def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns
@ -383,7 +382,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Device names: {retval}")
return retval
def _get_active_devices(self) -> List[int]:
def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments).
@ -401,7 +400,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Active GPU Devices: {devices}")
return devices
def _get_vram(self) -> List[int]:
def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
:attr:`_handles`.
@ -423,7 +422,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU VRAM: {retval}")
return retval
def _get_free_vram(self) -> List[int]:
def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
GPU.

View file

@ -1,13 +1,12 @@
#!/usr/bin/env python3
""" Handles the loading and collation of events from Tensorflow event log files. """
from __future__ import annotations
import logging
import os
import sys
import typing as T
import zlib
from dataclasses import dataclass, field
from typing import Any, cast, Dict, Iterator, Generator, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@ -17,11 +16,8 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
from lib.serializer import get_serializer
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from collections.abc import Generator, Iterator
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -38,7 +34,7 @@ class EventData:
The loss values collected for A and B sides for the event step
"""
timestamp: float = 0.0
loss: List[float] = field(default_factory=list)
loss: list[float] = field(default_factory=list)
class _LogFiles():
@ -56,11 +52,11 @@ class _LogFiles():
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def session_ids(self) -> List[int]:
def session_ids(self) -> list[int]:
""" list[int]: Sorted list of `ints` of available session ids. """
return list(sorted(self._filenames))
def _get_log_filenames(self) -> Dict[int, str]:
def _get_log_filenames(self) -> dict[int, str]:
""" Get the Tensorflow event filenames for all existing sessions.
Returns
@ -69,7 +65,7 @@ class _LogFiles():
The full path of each log file for each training session id that has been run
"""
logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
retval: Dict[int, str] = {}
retval: dict[int, str] = {}
for dirpath, _, filenames in os.walk(self._logs_folder):
if not any(filename.startswith("events.out.tfevents") for filename in filenames):
continue
@ -82,7 +78,7 @@ class _LogFiles():
return retval
@classmethod
def _get_session_id(cls, folder: str) -> Optional[int]:
def _get_session_id(cls, folder: str) -> int | None:
""" Obtain the session id for the given folder.
Parameters
@ -103,7 +99,7 @@ class _LogFiles():
return retval
@classmethod
def _get_log_filename(cls, folder: str, filenames: List[str]) -> str:
def _get_log_filename(cls, folder: str, filenames: list[str]) -> str:
""" Obtain the session log file for the given folder. If multiple log files exist for the
given folder, then the most recent log file is used, as earlier files are assumed to be
obsolete.
@ -161,10 +157,10 @@ class _CacheData():
loss: :class:`np.ndarray`
The loss values collected for A and B sides for the session
"""
def __init__(self, labels: List[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
self.labels = labels
self._loss = zlib.compress(cast(bytes, loss))
self._timestamps = zlib.compress(cast(bytes, timestamps))
self._loss = zlib.compress(T.cast(bytes, loss))
self._timestamps = zlib.compress(T.cast(bytes, timestamps))
self._timestamps_shape = timestamps.shape
self._loss_shape = loss.shape
@ -192,8 +188,8 @@ class _CacheData():
timestamps: :class:`numpy.ndarray`
The latest timestamps to add to the cache
"""
new_buffer: List[bytes] = []
new_shapes: List[Tuple[int, ...]] = []
new_buffer: list[bytes] = []
new_shapes: list[tuple[int, ...]] = []
for data, buffer, dtype, shape in zip([timestamps, loss],
[self._timestamps, self._loss],
["float64", "float32"],
@ -220,9 +216,9 @@ class _Cache():
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
self._data: Dict[int, _CacheData] = {}
self._carry_over: Dict[int, EventData] = {}
self._loss_labels: List[str] = []
self._data: dict[int, _CacheData] = {}
self._carry_over: dict[int, EventData] = {}
self._loss_labels: list[str] = []
logger.debug("Initialized: %s", self.__class__.__name__)
def is_cached(self, session_id: int) -> bool:
@ -242,8 +238,8 @@ class _Cache():
def cache_data(self,
session_id: int,
data: Dict[int, EventData],
labels: List[str],
data: dict[int, EventData],
labels: list[str],
is_live: bool = False) -> None:
""" Add a full session's worth of event data to :attr:`_data`.
@ -278,8 +274,8 @@ class _Cache():
self._add_latest_live(session_id, loss, timestamps)
def _to_numpy(self,
data: Dict[int, EventData],
is_live: bool) -> Tuple[np.ndarray, np.ndarray]:
data: dict[int, EventData],
is_live: bool) -> tuple[np.ndarray, np.ndarray]:
""" Extract each individual step data into separate numpy arrays for loss and timestamps.
Timestamps are stored float64 as the extra accuracy is needed for correct timings. Arrays
@ -333,7 +329,7 @@ class _Cache():
return n_times, n_loss
def _collect_carry_over(self, data: Dict[int, EventData]) -> None:
def _collect_carry_over(self, data: dict[int, EventData]) -> None:
""" For live data, collect carried over data from the previous update and merge into the
current data dictionary.
@ -357,8 +353,8 @@ class _Cache():
logger.debug("Merged carry over data: %s", update)
def _process_data(self,
data: Dict[int, EventData],
is_live: bool) -> Tuple[List[float], List[List[float]]]:
data: dict[int, EventData],
is_live: bool) -> tuple[list[float], list[list[float]]]:
""" Process live update data.
Live data requires different processing as often we will only have partial data for the
@ -383,8 +379,8 @@ class _Cache():
timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss)
for idx in sorted(data)])
l_loss: List[List[float]] = list(loss)
l_timestamps: List[float] = list(timestamps)
l_loss: list[list[float]] = list(loss)
l_timestamps: list[float] = list(timestamps)
if len(l_loss[-1]) != len(self._loss_labels):
logger.debug("Truncated loss found. loss count: %s", len(l_loss))
@ -418,8 +414,8 @@ class _Cache():
self._data[session_id].add_live_data(timestamps, loss)
def get_data(self, session_id: int, metric: Literal["loss", "timestamps"]
) -> Optional[Dict[int, Dict[str, Union[np.ndarray, List[str]]]]]:
def get_data(self, session_id: int, metric: T.Literal["loss", "timestamps"]
) -> dict[int, dict[str, np.ndarray | list[str]]] | None:
""" Retrieve the decompressed cached data from the cache for the given session id.
Parameters
@ -445,10 +441,10 @@ class _Cache():
return None
raw = {session_id: data}
retval: Dict[int, Dict[str, Union[np.ndarray, List[str]]]] = {}
retval: dict[int, dict[str, np.ndarray | list[str]]] = {}
for idx, data in raw.items():
array = data.loss if metric == "loss" else data.timestamps
val: Dict[str, Union[np.ndarray, List[str]]] = {str(metric): array}
val: dict[str, np.ndarray | list[str]] = {str(metric): array}
if metric == "loss":
val["labels"] = data.labels
retval[idx] = val
@ -488,7 +484,7 @@ class TensorBoardLogs():
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def session_ids(self) -> List[int]:
def session_ids(self) -> list[int]:
""" list[int]: Sorted list of integers of available session ids. """
return self._log_files.session_ids
@ -539,7 +535,7 @@ class TensorBoardLogs():
parser = _EventParser(iterator, self._cache, live_data)
parser.cache_events(session_id)
def _check_cache(self, session_id: Optional[int] = None) -> None:
def _check_cache(self, session_id: int | None = None) -> None:
""" Check if the given session_id has been cached and if not, cache it.
Parameters
@ -557,7 +553,7 @@ class TensorBoardLogs():
if not self._cache.is_cached(idx):
self._cache_data(idx)
def get_loss(self, session_id: Optional[int] = None) -> Dict[int, Dict[str, np.ndarray]]:
def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]:
""" Read the loss from the TensorBoard event logs
Parameters
@ -573,7 +569,7 @@ class TensorBoardLogs():
and list of loss values for each step
"""
logger.debug("Getting loss: (session_id: %s)", session_id)
retval: Dict[int, Dict[str, np.ndarray]] = {}
retval: dict[int, dict[str, np.ndarray]] = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
full_data = self._cache.get_data(idx, "loss")
@ -588,7 +584,7 @@ class TensorBoardLogs():
for key, val in retval.items()})
return retval
def get_timestamps(self, session_id: Optional[int] = None) -> Dict[int, np.ndarray]:
def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]:
""" Read the timestamps from the TensorBoard logs.
As loss timestamps are slightly different for each loss, we collect the timestamp from the
@ -608,7 +604,7 @@ class TensorBoardLogs():
logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
session_id, self._is_training)
retval: Dict[int, np.ndarray] = {}
retval: dict[int, np.ndarray] = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
data = self._cache.get_data(idx, "timestamps")
@ -640,7 +636,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
self._live_data = live_data
self._cache = cache
self._iterator = self._get_latest_live(iterator) if live_data else iterator
self._loss_labels: List[str] = []
self._loss_labels: list[str] = []
logger.debug("Initialized: %s", self.__class__.__name__)
@classmethod
@ -683,7 +679,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
The session id that the data is being cached for
"""
assert self._iterator is not None
data: Dict[int, EventData] = {}
data: dict[int, EventData] = {}
try:
for record in self._iterator:
event = event_pb2.Event.FromString(record) # pylint:disable=no-member
@ -743,7 +739,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
logger.debug("Collated loss labels: %s", self._loss_labels)
@classmethod
def _get_outputs(cls, model_config: Dict[str, Any]) -> np.ndarray:
def _get_outputs(cls, model_config: dict[str, T.Any]) -> np.ndarray:
""" Obtain the output names, instance index and output index for the given model.
If there is only a single output, the shape of the array is expanded to remain consistent

View file

@ -5,15 +5,15 @@ Holds the globally loaded training session. This will either be a user selected
the analysis tab) or the currently training session.
"""
from __future__ import annotations
import logging
import time
import os
import time
import typing as T
import warnings
from math import ceil
from threading import Event
from typing import Any, cast, Dict, List, Optional, overload, Tuple, Union
import numpy as np
@ -31,12 +31,12 @@ class GlobalSession():
"""
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._state: Dict[str, Any] = {}
self._state: dict[str, T.Any] = {}
self._model_dir = ""
self._model_name = ""
self._tb_logs: Optional[TensorBoardLogs] = None
self._summary: Optional[SessionsSummary] = None
self._tb_logs: TensorBoardLogs | None = None
self._summary: SessionsSummary | None = None
self._is_training = False
self._is_querying = Event()
@ -60,7 +60,7 @@ class GlobalSession():
return os.path.join(self._model_dir, self._model_name)
@property
def batch_sizes(self) -> Dict[int, int]:
def batch_sizes(self) -> dict[int, int]:
""" dict: The batch sizes for each session_id for the model. """
if not self._state:
return {}
@ -68,7 +68,7 @@ class GlobalSession():
for sess_id, sess in self._state.get("sessions", {}).items()}
@property
def full_summary(self) -> List[dict]:
def full_summary(self) -> list[dict]:
""" list: List of dictionaries containing summary statistics for each session id. """
assert self._summary is not None
return self._summary.get_summary_stats()
@ -83,7 +83,7 @@ class GlobalSession():
return self._state["sessions"][max_id]["no_logs"]
@property
def session_ids(self) -> List[int]:
def session_ids(self) -> list[int]:
""" list: The sorted list of all existing session ids in the state file """
if self._tb_logs is None:
return []
@ -164,7 +164,7 @@ class GlobalSession():
self._is_training = False
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]:
def get_loss(self, session_id: int | None) -> dict[str, np.ndarray]:
""" Obtain the loss values for the given session_id.
Parameters
@ -186,11 +186,11 @@ class GlobalSession():
assert self._tb_logs is not None
loss_dict = self._tb_logs.get_loss(session_id=session_id)
if session_id is None:
all_loss: Dict[str, List[float]] = {}
all_loss: dict[str, list[float]] = {}
for key in sorted(loss_dict):
for loss_key, loss in loss_dict[key].items():
all_loss.setdefault(loss_key, []).extend(loss)
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
retval: dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
for key, val in all_loss.items()}
else:
retval = loss_dict.get(session_id, {})
@ -199,11 +199,11 @@ class GlobalSession():
self._is_querying.clear()
return retval
@overload
def get_timestamps(self, session_id: None) -> Dict[int, np.ndarray]:
@T.overload
def get_timestamps(self, session_id: None) -> dict[int, np.ndarray]:
...
@overload
@T.overload
def get_timestamps(self, session_id: int) -> np.ndarray:
...
@ -247,7 +247,7 @@ class GlobalSession():
continue
break
def get_loss_keys(self, session_id: Optional[int]) -> List[str]:
def get_loss_keys(self, session_id: int | None) -> list[str]:
""" Obtain the loss keys for the given session_id.
Parameters
@ -268,7 +268,7 @@ class GlobalSession():
in self._tb_logs.get_loss(session_id=session_id).items()}
if session_id is None:
retval: List[str] = list(set(loss_key
retval: list[str] = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
else:
@ -293,11 +293,11 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
self._session = session
self._state = session._state
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {}
self._per_session_stats: List[Dict[str, Any]] = []
self._time_stats: dict[int, dict[str, float | int]] = {}
self._per_session_stats: list[dict[str, T.Any]] = []
logger.debug("Initialized %s", self.__class__.__name__)
def get_summary_stats(self) -> List[dict]:
def get_summary_stats(self) -> list[dict]:
""" Compile the individual session statistics and calculate the total.
Format the stats for display
@ -336,14 +336,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0,
"end_time": np.max(timestamps) if np.any(timestamps) else 0,
"iterations": timestamps.shape[0] if np.any(timestamps) else 0}
for sess_id, timestamps in cast(Dict[int, np.ndarray],
for sess_id, timestamps in T.cast(dict[int, np.ndarray],
self._session.get_timestamps(None)).items()}
elif _SESSION.is_training:
logger.debug("Updating summary time stamps for training session")
session_id = _SESSION.session_ids[-1]
latest = cast(np.ndarray, self._session.get_timestamps(session_id))
latest = T.cast(np.ndarray, self._session.get_timestamps(session_id))
self._time_stats[session_id] = {
"start_time": np.min(latest) if np.any(latest) else 0,
@ -392,7 +392,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
/ stats["elapsed"] if stats["elapsed"] > 0 else 0)
logger.debug("per_session_stats: %s", self._per_session_stats)
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]:
def _collate_stats(self, session_id: int) -> dict[str, int | float]:
""" Collate the session summary statistics for the given session ID.
Parameters
@ -422,7 +422,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
logger.debug(retval)
return retval
def _total_stats(self) -> Dict[str, Union[str, int, float]]:
def _total_stats(self) -> dict[str, str | int | float]:
""" Compile the Totals stats.
Totals are fully calculated each time as they will change on the basis of the training
session.
@ -459,7 +459,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
logger.debug(totals)
return totals
def _format_stats(self, compiled_stats: List[dict]) -> List[dict]:
def _format_stats(self, compiled_stats: list[dict]) -> list[dict]:
""" Format for the incoming list of statistics for display.
Parameters
@ -489,7 +489,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
return retval
@classmethod
def _convert_time(cls, timestamp: float) -> Tuple[str, str, str]:
def _convert_time(cls, timestamp: float) -> tuple[str, str, str]:
""" Convert time stamp to total hours, minutes and seconds.
Parameters
@ -534,8 +534,8 @@ class Calculations():
"""
def __init__(self, session_id,
display: str = "loss",
loss_keys: Union[List[str], str] = "loss",
selections: Union[List[str], str] = "raw",
loss_keys: list[str] | str = "loss",
selections: list[str] | str = "raw",
avg_samples: int = 500,
smooth_amount: float = 0.90,
flatten_outliers: bool = False) -> None:
@ -552,13 +552,13 @@ class Calculations():
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
self._selections = selections if isinstance(selections, list) else [selections]
self._is_totals = session_id is None
self._args: Dict[str, Union[int, float]] = {"avg_samples": avg_samples,
self._args: dict[str, int | float] = {"avg_samples": avg_samples,
"smooth_amount": smooth_amount,
"flatten_outliers": flatten_outliers}
self._iterations = 0
self._limit = 0
self._start_iteration = 0
self._stats: Dict[str, np.ndarray] = {}
self._stats: dict[str, np.ndarray] = {}
self.refresh()
logger.debug("Initialized %s", self.__class__.__name__)
@ -573,11 +573,11 @@ class Calculations():
return self._start_iteration
@property
def stats(self) -> Dict[str, np.ndarray]:
def stats(self) -> dict[str, np.ndarray]:
""" dict: The final calculated statistics """
return self._stats
def refresh(self) -> Optional["Calculations"]:
def refresh(self) -> Calculations | None:
""" Refresh the stats """
logger.debug("Refreshing")
if not _SESSION.is_loaded:
@ -736,7 +736,8 @@ class Calculations():
"""
logger.debug("Calculating rate")
batch_size = _SESSION.batch_sizes[self._session_id] * 2
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id)))
retval = batch_size / np.diff(T.cast(np.ndarray,
_SESSION.get_timestamps(self._session_id)))
logger.debug("Calculated rate: Item_count: %s", len(retval))
return retval
@ -757,7 +758,7 @@ class Calculations():
logger.debug("Calculating totals rate")
batchsizes = _SESSION.batch_sizes
total_timestamps = _SESSION.get_timestamps(None)
rate: List[float] = []
rate: list[float] = []
for sess_id in sorted(total_timestamps.keys()):
batchsize = batchsizes[sess_id]
timestamps = total_timestamps[sess_id]
@ -797,7 +798,7 @@ class Calculations():
The moving average for the given data
"""
logger.debug("Calculating Average. Data points: %s", len(data))
window = cast(int, self._args["avg_samples"])
window = T.cast(int, self._args["avg_samples"])
pad = ceil(window / 2)
datapoints = data.shape[0]
@ -953,7 +954,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
def _ewma_vectorized(self,
data: np.ndarray,
out: np.ndarray,
offset: Optional[float] = None) -> None:
offset: float | None = None) -> None:
""" Calculates the exponential moving average over a vector. Will fail for large inputs.
The result is processed in place into the array passed to the `out` parameter

View file

@ -5,10 +5,10 @@ import logging
import re
import tkinter as tk
import typing as T
from tkinter import colorchooser, ttk
from itertools import zip_longest
from functools import partial
from typing import Any, Dict
from _tkinter import Tcl_Obj, TclError
@ -24,7 +24,9 @@ _ = _LANG.gettext
# We store Tooltips, ContextMenus and Commands globally when they are created
# Because we need to add them back to newly cloned widgets (they are not easily accessible from
# original config or are prone to getting destroyed when the original widget is destroyed)
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={})
_RECREATE_OBJECTS: dict[str, dict[str, T.Any]] = {"tooltips": {},
"commands": {},
"contextmenus": {}}
def _get_tooltip(widget, text=None, text_variable=None):
@ -154,17 +156,17 @@ class ControlPanelOption():
self.dtype = dtype
self.sysbrowser = sysbrowser
self._command = command
self._options = dict(title=title,
subgroup=subgroup,
group=group,
default=default,
initial_value=initial_value,
choices=choices,
is_radio=is_radio,
is_multi_option=is_multi_option,
rounding=rounding,
min_max=min_max,
helptext=helptext)
self._options = {"title": title,
"subgroup": subgroup,
"group": group,
"default": default,
"initial_value": initial_value,
"choices": choices,
"is_radio": is_radio,
"is_multi_option": is_multi_option,
"rounding": rounding,
"min_max": min_max,
"helptext": helptext}
self.control = self.get_control()
self.tk_var = self.get_tk_var(initial_value, track_modified)
logger.debug("Initialized %s", self.__class__.__name__)
@ -421,7 +423,7 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
self.group_frames = {}
self._sub_group_frames = {}
canvas_kwargs = dict(bd=0, highlightthickness=0, bg=self._theme["panel_background"])
canvas_kwargs = {"bd": 0, "highlightthickness": 0, "bg": self._theme["panel_background"]}
self._canvas = tk.Canvas(self, **canvas_kwargs)
self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
@ -525,8 +527,8 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
group_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5, anchor=tk.NW)
self.group_frames[group] = dict(frame=retval,
chkbtns=self.checkbuttons_frame(retval))
self.group_frames[group] = {"frame": retval,
"chkbtns": self.checkbuttons_frame(retval)}
group_frame = self.group_frames[group]
return group_frame
@ -720,12 +722,12 @@ class AutoFillContainer():
"""
retval = {}
if widget.__class__.__name__ == "MultiOption":
retval = dict(value=widget._value, # pylint:disable=protected-access
variable=widget._master_variable) # pylint:disable=protected-access
retval = {"value": widget._value, # pylint:disable=protected-access
"variable": widget._master_variable} # pylint:disable=protected-access
elif widget.__class__.__name__ == "ToggledFrame":
# Toggled Frames need to have their variable tracked
retval = dict(text=widget._text, # pylint:disable=protected-access
toggle_var=widget._toggle_var) # pylint:disable=protected-access
retval = {"text": widget._text, # pylint:disable=protected-access
"toggle_var": widget._toggle_var} # pylint:disable=protected-access
return retval
def get_all_children_config(self, widget, child_list):
@ -988,7 +990,7 @@ class ControlBuilder():
if self.option.control != ttk.Checkbutton:
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
if self.option.helptext is not None and not self.helpset:
tooltip_kwargs = dict(text=self.option.helptext)
tooltip_kwargs = {"text": self.option.helptext}
if self.option.sysbrowser is not None:
tooltip_kwargs["text_variable"] = self.option.tk_var
_get_tooltip(ctl, **tooltip_kwargs)
@ -1071,7 +1073,7 @@ class ControlBuilder():
"rounding: %s, min_max: %s)", self.option.name, self.option.dtype,
self.option.rounding, self.option.min_max)
validate = self.slider_check_int if self.option.dtype == int else self.slider_check_float
vcmd = (self.frame.register(validate))
vcmd = self.frame.register(validate)
tbox = tk.Entry(self.frame,
width=8,
textvariable=self.option.tk_var,
@ -1246,15 +1248,15 @@ class FileBrowser():
@property
def helptext(self):
""" Dict containing tooltip text for buttons """
retval = dict(folder=_("Select a folder..."),
load=_("Select a file..."),
load2=_("Select a file..."),
picture=_("Select a folder of images..."),
video=_("Select a video..."),
model=_("Select a model folder..."),
multi_load=_("Select one or more files..."),
context=_("Select a file or folder..."),
save_as=_("Select a save location..."))
retval = {"folder": _("Select a folder..."),
"load": _("Select a file..."),
"load2": _("Select a file..."),
"picture": _("Select a folder of images..."),
"video": _("Select a video..."),
"model": _("Select a model folder..."),
"multi_load": _("Select one or more files..."),
"context": _("Select a file or folder..."),
"save_as": _("Select a save location...")}
return retval
@staticmethod

View file

@ -4,11 +4,10 @@ import datetime
import gettext
import logging
import os
import sys
import tkinter as tk
import typing as T
from tkinter import ttk
from typing import Dict, Optional, Tuple
from lib.training.preview_tk import PreviewTk
@ -19,11 +18,6 @@ from .analysis import Calculations, Session
from .control_helper import set_slider_rounding
from .utils import FileHandler, get_config, get_images, preview_trigger
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# LOCALES
@ -92,7 +86,7 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
logger.debug("Initializing %s (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
self._preview = get_images().preview_train
self._display: Optional[PreviewTk] = None
self._display: PreviewTk | None = None
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
@ -177,9 +171,9 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
tab_name: str,
helptext: str,
wait_time: int,
command: Optional[str] = None) -> None:
self._trace_vars: Dict[Literal["smoothgraph", "display_iterations"],
Tuple[tk.BooleanVar, str]] = {}
command: str | None = None) -> None:
self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
tuple[tk.BooleanVar, str]] = {}
super().__init__(parent, tab_name, helptext, wait_time, command)
def set_vars(self) -> None:
@ -446,7 +440,7 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
def _add_trace_variables(self) -> None:
""" Add tracing for when the option sliders are updated, for updating the graph. """
for name, action in zip(get_args(Literal["smoothgraph", "display_iterations"]),
for name, action in zip(T.get_args(T.Literal["smoothgraph", "display_iterations"]),
(self._smooth_amount_callback, self._iteration_limit_callback)):
var = self.vars[name]
if name not in self._trace_vars:

View file

@ -1,12 +1,13 @@
#!/usr/bin python3
""" Graph functions for Display Frame area of the Faceswap GUI """
from __future__ import annotations
import datetime
import logging
import os
import tkinter as tk
import typing as T
from tkinter import ttk
from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
from math import ceil, floor
import numpy as np
@ -20,7 +21,7 @@ from matplotlib.backend_bases import NavigationToolbar2
from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from matplotlib.lines import Line2D
matplotlib.use("TkAgg")
@ -49,8 +50,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._ylabel = ylabel
self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
"summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
self._lines: List["Line2D"] = []
self._toolbar: Optional["NavigationToolbar"] = None
self._lines: list[Line2D] = []
self._toolbar: "NavigationToolbar" | None = None
self._fig = Figure(figsize=(4, 4), dpi=75)
self._ax1 = self._fig.add_subplot(1, 1, 1)
@ -129,7 +130,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._ax1.set_ylim(0.00, 100.0)
self._ax1.set_xlim(0, 1)
def _axes_limits_set(self, data: List[float]) -> None:
def _axes_limits_set(self, data: list[float]) -> None:
""" Set the axes limits.
Parameters
@ -154,7 +155,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._axes_limits_set_default()
@staticmethod
def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
def _axes_data_get_min_max(data: list[float]) -> tuple[float, float]:
""" Obtain the minimum and maximum values for the y-axis from the given data points.
Parameters
@ -188,7 +189,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
logger.debug("yscale: '%s'", scale)
self._ax1.set_yscale(scale)
def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float]]]]:
def _lines_sort(self, keys: list[str]) -> list[list[str | int | tuple[float]]]:
""" Sort the data keys into consistent order and set line color map and line width.
Parameters
@ -202,8 +203,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
A list of loss keys with their corresponding line formatting and color information
"""
logger.trace("Sorting lines") # type:ignore[attr-defined]
raw_lines: List[List[str]] = []
sorted_lines: List[List[str]] = []
raw_lines: list[list[str]] = []
sorted_lines: list[list[str]] = []
for key in sorted(keys):
title = key.replace("_", " ").title()
if key.startswith("raw"):
@ -217,7 +218,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return lines
@staticmethod
def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
def _lines_groupsize(raw_lines: list[list[str]], sorted_lines: list[list[str]]) -> int:
""" Get the number of items in each group.
If raw data isn't selected, then check the length of remaining groups until something is
@ -246,8 +247,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return groupsize
def _lines_style(self,
lines: List[List[str]],
groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
lines: list[list[str]],
groupsize: int) -> list[list[str | int | tuple[float]]]:
""" Obtain the color map and line width for each group.
Parameters
@ -266,13 +267,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
groups = int(len(lines) / groupsize)
colours = self._lines_create_colors(groupsize, groups)
widths = list(range(1, groups + 1))
retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
retval = T.cast(list[list[str | int | tuple[float]]], lines)
for idx, item in enumerate(retval):
linewidth = widths[idx // groupsize]
item.extend((linewidth, colours[idx]))
return retval
def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
def _lines_create_colors(self, groupsize: int, groups: int) -> list[tuple[float]]:
""" Create the color maps.
Parameters
@ -336,8 +337,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
super().__init__(parent, data, ylabel)
self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask
self._displayed_keys: List[str] = []
self._thread: LongRunningTask | None = None # Thread for LongRunningTask
self._displayed_keys: list[str] = []
self._add_callback()
def _add_callback(self) -> None:
@ -352,7 +353,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def refresh(self, *args) -> None: # pylint: disable=unused-argument
""" Read the latest loss data and apply to current graph """
refresh_var = cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
refresh_var = T.cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
if not refresh_var.get() and self._thread is None:
return
@ -533,7 +534,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
text: str,
image_file: str,
toggle: bool,
command) -> Union[ttk.Button, ttk.Checkbutton]:
command) -> ttk.Button | ttk.Checkbutton:
""" Override the default button method to use our icons and ttk widgets for
consistent GUI layout.
@ -563,7 +564,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
img = get_images().icons[icon]
if not toggle:
btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame,
btn: ttk.Button | ttk.Checkbutton = ttk.Button(frame,
text=text,
image=img,
command=command)

View file

@ -1,6 +1,6 @@
#!/usr/bin python3
""" The Menu Bars for faceswap GUI """
from __future__ import annotations
import gettext
import locale
import logging
@ -33,7 +33,7 @@ _ = _LANG.gettext
_WORKING_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
_RESOURCES: T.List[T.Tuple[str, str]] = [
_RESOURCES: list[tuple[str, str]] = [
(_("faceswap.dev - Guides and Forum"), "https://www.faceswap.dev"),
(_("Patreon - Support this project"), "https://www.patreon.com/faceswap"),
(_("Discord - The FaceSwap Discord server"), "https://discord.gg/VasFUAy"),
@ -48,7 +48,7 @@ class MainMenuBar(tk.Menu): # pylint:disable=too-many-ancestors
master: :class:`tkinter.Tk`
The root tkinter object
"""
def __init__(self, master: "FaceswapGui") -> None:
def __init__(self, master: FaceswapGui) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(master)
self.root = master
@ -431,7 +431,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
return True
@classmethod
def _get_branches(cls) -> T.Optional[str]:
def _get_branches(cls) -> str | None:
""" Get the available github branches
Returns
@ -453,7 +453,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
return stdout.decode(locale.getpreferredencoding(), errors="replace")
@classmethod
def _filter_branches(cls, stdout: str) -> T.List[str]:
def _filter_branches(cls, stdout: str) -> list[str]:
""" Filter the branches, remove duplicates and the current branch and return a sorted
list.
@ -548,7 +548,7 @@ class TaskBar(ttk.Frame): # pylint: disable=too-many-ancestors
self._section_separator()
@classmethod
def _loader_and_kwargs(cls, btntype: str) -> T.Tuple[str, T.Dict[str, bool]]:
def _loader_and_kwargs(cls, btntype: str) -> tuple[str, dict[str, bool]]:
""" Get the loader name and key word arguments for the given button type
Parameters

View file

@ -1,6 +1,6 @@
#!/usr/bin python3
""" The pop-up window of the Faceswap GUI for the setting of configuration options. """
from __future__ import annotations
from collections import OrderedDict
from configparser import ConfigParser
import gettext
@ -9,7 +9,8 @@ import os
import sys
import tkinter as tk
from tkinter import ttk
from typing import Dict, TYPE_CHECKING
import typing as T
from importlib import import_module
from lib.serializer import get_serializer
@ -18,7 +19,7 @@ from .control_helper import ControlPanel, ControlPanelOption
from .custom_widgets import Tooltip
from .utils import FileHandler, get_config, get_images, PATHCACHE
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.config import FaceswapConfig
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -124,7 +125,7 @@ class _ConfigurePlugins(tk.Toplevel):
super().__init__()
self._root = get_config().root
self._set_geometry()
self._tk_vars = dict(header=tk.StringVar())
self._tk_vars = {"header": tk.StringVar()}
theme = {**get_config().user_theme["group_panel"],
**get_config().user_theme["group_settings"]}
@ -402,7 +403,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
"""
def __init__(self, top_level, parent, configurations, tree, theme):
super().__init__(parent)
self._configs: Dict[str, "FaceswapConfig"] = configurations
self._configs: dict[str, FaceswapConfig] = configurations
self._theme = theme
self._tree = tree
self._vars = {}
@ -443,7 +444,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
sect = section.split(".")[-1]
# Elevate global to root
key = plugin if sect == "global" else f"{plugin}|{category}|{sect}"
retval[key] = dict(helptext=None, options=OrderedDict())
retval[key] = {"helptext": None, "options": OrderedDict()}
retval[key]["helptext"] = conf.defaults[section].helptext
for option, params in conf.defaults[section].items.items():
@ -632,7 +633,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
def _get_new_config(self,
page_only: bool,
config: "FaceswapConfig",
config: FaceswapConfig,
category: str,
lookup: str) -> ConfigParser:
""" Obtain a new configuration file for saving
@ -812,9 +813,9 @@ class _Presets():
return None
args = ("save_filename", "json") if action == "save" else ("filename", "json")
kwargs = dict(title=f"{action.title()} Preset...",
initial_folder=self._preset_path,
parent=self._parent)
kwargs = {"title": f"{action.title()} Preset...",
"initial_folder": self._preset_path,
"parent": self._parent}
if action == "save":
kwargs["initial_file"] = self._get_initial_filename()

View file

@ -8,7 +8,6 @@ import tkinter as tk
from dataclasses import dataclass, field
from tkinter import ttk
from typing import Dict, List, Optional, Tuple, Type, Union
from .control_helper import ControlBuilder, ControlPanelOption
from .custom_widgets import Tooltip
@ -66,7 +65,7 @@ class SessionTKVars:
outliers: tk.BooleanVar
avgiterations: tk.IntVar
smoothamount: tk.DoubleVar
loss_keys: Dict[str, tk.BooleanVar] = field(default_factory=dict)
loss_keys: dict[str, tk.BooleanVar] = field(default_factory=dict)
class SessionPopUp(tk.Toplevel):
@ -82,13 +81,13 @@ class SessionPopUp(tk.Toplevel):
logger.debug("Initializing: %s: (session_id: %s, data_points: %s)",
self.__class__.__name__, session_id, data_points)
super().__init__()
self._thread: Optional[LongRunningTask] = None # Thread for loading data in background
self._thread: LongRunningTask | None = None # Thread for loading data in background
self._default_view = "avg" if data_points > 1000 else "smoothed"
self._session_id = None if session_id == "Total" else int(session_id)
self._graph_frame = ttk.Frame(self)
self._graph: Optional[SessionGraph] = None
self._display_data: Optional[Calculations] = None
self._graph: SessionGraph | None = None
self._display_data: Calculations | None = None
self._vars = self._set_vars()
@ -172,7 +171,7 @@ class SessionPopUp(tk.Toplevel):
The frame that the options reside in
"""
logger.debug("Building Combo boxes")
choices = dict(Display=("Loss", "Rate"), Scale=("Linear", "Log"))
choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")}
for item in ["Display", "Scale"]:
var: tk.StringVar = getattr(self._vars, item.lower())
@ -273,11 +272,11 @@ class SessionPopUp(tk.Toplevel):
logger.debug("Building Slider Controls")
for item in ("avgiterations", "smoothamount"):
if item == "avgiterations":
dtype: Union[Type[int], Type[float]] = int
dtype: type[int] | type[float] = int
text = "Iterations to Average:"
default: Union[int, float] = 500
default: int | float = 500
rounding = 25
min_max: Tuple[int, Union[int, float]] = (25, 2500)
min_max: tuple[int, int | float] = (25, 2500)
elif item == "smoothamount":
dtype = float
text = "Smoothing Amount:"
@ -404,20 +403,20 @@ class SessionPopUp(tk.Toplevel):
str
The help text for the given action
"""
lookup = dict(
reload=_("Refresh graph"),
save=_("Save display data to csv"),
avgiterations=_("Number of data points to sample for rolling average"),
smoothamount=_("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
lookup = {
"reload": _("Refresh graph"),
"save": _("Save display data to csv"),
"avgiterations": _("Number of data points to sample for rolling average"),
"smoothamount": _("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
"smoothing"),
outliers=_("Flatten data points that fall more than 1 standard deviation from the "
"outliers": _("Flatten data points that fall more than 1 standard deviation from the "
"mean to the mean value."),
avg=_("Display rolling average of the data"),
smoothed=_("Smooth the data"),
raw=_("Display raw data"),
trend=_("Display polynormal data trend"),
display=_("Set the data to display"),
scale=_("Change y-axis scale"))
"avg": _("Display rolling average of the data"),
"smoothed": _("Smooth the data"),
"raw": _("Display raw data"),
"trend": _("Display polynormal data trend"),
"display": _("Set the data to display"),
"scale": _("Change y-axis scale")}
return lookup.get(action.lower(), "")
def _compile_display_data(self) -> bool:
@ -446,13 +445,13 @@ class SessionPopUp(tk.Toplevel):
self._lbl_loading.pack(fill=tk.BOTH, expand=True)
self.update_idletasks()
kwargs = dict(session_id=self._session_id,
display=self._vars.display.get(),
loss_keys=loss_keys,
selections=selections,
avg_samples=self._vars.avgiterations.get(),
smooth_amount=self._vars.smoothamount.get(),
flatten_outliers=self._vars.outliers.get())
kwargs = {"session_id": self._session_id,
"display": self._vars.display.get(),
"loss_keys": loss_keys,
"selections": selections,
"avg_samples": self._vars.avgiterations.get(),
"smooth_amount": self._vars.smoothamount.get(),
"flatten_outliers": self._vars.outliers.get()}
self._thread = LongRunningTask(target=self._get_display_data,
kwargs=kwargs,
widget=self)
@ -491,7 +490,7 @@ class SessionPopUp(tk.Toplevel):
"""
return Calculations(**kwargs)
def _check_valid_selection(self, loss_keys: List[str], selections: List[str]) -> bool:
def _check_valid_selection(self, loss_keys: list[str], selections: list[str]) -> bool:
""" Check that there will be data to display.
Parameters
@ -530,7 +529,7 @@ class SessionPopUp(tk.Toplevel):
return False
return True
def _selections_to_list(self) -> List[str]:
def _selections_to_list(self) -> list[str]:
""" Compile checkbox selections to a list.
Returns

View file

@ -1,19 +1,20 @@
#!/usr/bin python3
""" Global configuration options for the Faceswap GUI """
from __future__ import annotations
import logging
import os
import sys
import tkinter as tk
import typing as T
from dataclasses import dataclass, field
from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING
from lib.gui._config import Config as UserConfig
from lib.gui.project import Project, Tasks
from lib.gui.theme import Style
from .file_handler import FileHandler
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.gui.options import CliOptions
from lib.gui.custom_widgets import StatusBar
from lib.gui.command import CommandNotebook
@ -22,12 +23,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PATHCACHE = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache")
_CONFIG: Optional["Config"] = None
_CONFIG: Config | None = None
def initialize_config(root: tk.Tk,
cli_opts: Optional["CliOptions"],
statusbar: Optional["StatusBar"]) -> Optional["Config"]:
cli_opts: CliOptions | None,
statusbar: StatusBar | None) -> Config | None:
""" Initialize the GUI Master :class:`Config` and add to global constant.
This should only be called once on first GUI startup. Future access to :class:`Config`
@ -145,13 +146,13 @@ class GlobalVariables():
@dataclass
class _GuiObjects:
""" Data class for commonly accessed GUI Objects """
cli_opts: Optional["CliOptions"]
cli_opts: CliOptions | None
tk_vars: GlobalVariables
project: Project
tasks: Tasks
status_bar: Optional["StatusBar"]
default_options: Dict[str, Dict[str, Any]] = field(default_factory=dict)
command_notebook: Optional["CommandNotebook"] = None
status_bar: StatusBar | None
default_options: dict[str, dict[str, T.Any]] = field(default_factory=dict)
command_notebook: CommandNotebook | None = None
class Config():
@ -172,15 +173,15 @@ class Config():
"""
def __init__(self,
root: tk.Tk,
cli_opts: Optional["CliOptions"],
statusbar: Optional["StatusBar"]) -> None:
cli_opts: CliOptions | None,
statusbar: StatusBar | None) -> None:
logger.debug("Initializing %s: (root %s, cli_opts: %s, statusbar: %s)",
self.__class__.__name__, root, cli_opts, statusbar)
self._default_font = cast(dict, tk.font.nametofont("TkDefaultFont").configure())["family"]
self._constants = dict(
root=root,
scaling_factor=self._get_scaling(root),
default_font=self._default_font)
self._default_font = T.cast(dict,
tk.font.nametofont("TkDefaultFont").configure())["family"]
self._constants = {"root": root,
"scaling_factor": self._get_scaling(root),
"default_font": self._default_font}
self._gui_objects = _GuiObjects(
cli_opts=cli_opts,
tk_vars=GlobalVariables(),
@ -211,7 +212,7 @@ class Config():
# GUI Objects
@property
def cli_opts(self) -> "CliOptions":
def cli_opts(self) -> CliOptions:
""" :class:`lib.gui.options.CliOptions`: The command line options for this GUI Session. """
# This should only be None when a separate tool (not main GUI) is used, at which point
# cli_opts do not exist
@ -234,12 +235,12 @@ class Config():
return self._gui_objects.tasks
@property
def default_options(self) -> Dict[str, Dict[str, Any]]:
def default_options(self) -> dict[str, dict[str, T.Any]]:
""" dict: The default options for all tabs """
return self._gui_objects.default_options
@property
def statusbar(self) -> "StatusBar":
def statusbar(self) -> StatusBar:
""" :class:`lib.gui.custom_widgets.StatusBar`: The GUI StatusBar
:class:`tkinter.ttk.Frame`. """
# This should only be None when a separate tool (not main GUI) is used, at which point
@ -248,31 +249,31 @@ class Config():
return self._gui_objects.status_bar
@property
def command_notebook(self) -> Optional["CommandNotebook"]:
def command_notebook(self) -> CommandNotebook | None:
""" :class:`lib.gui.command.CommandNotebook`: The main Faceswap Command Notebook. """
return self._gui_objects.command_notebook
# Convenience GUI Objects
@property
def tools_notebook(self) -> "ToolsNotebook":
def tools_notebook(self) -> ToolsNotebook:
""" :class:`lib.gui.command.ToolsNotebook`: The Faceswap Tools sub-Notebook. """
assert self.command_notebook is not None
return self.command_notebook.tools_notebook
@property
def modified_vars(self) -> Dict[str, "tk.BooleanVar"]:
def modified_vars(self) -> dict[str, tk.BooleanVar]:
""" dict: The command notebook modified tkinter variables. """
assert self.command_notebook is not None
return self.command_notebook.modified_vars
@property
def _command_tabs(self) -> Dict[str, int]:
def _command_tabs(self) -> dict[str, int]:
""" dict: Command tab titles with their IDs. """
assert self.command_notebook is not None
return self.command_notebook.tab_names
@property
def _tools_tabs(self) -> Dict[str, int]:
def _tools_tabs(self) -> dict[str, int]:
""" dict: Tools command tab titles with their IDs. """
assert self.command_notebook is not None
return self.command_notebook.tools_tab_names
@ -284,17 +285,17 @@ class Config():
return self._user_config
@property
def user_config_dict(self) -> Dict[str, Any]: # TODO Dataclass
def user_config_dict(self) -> dict[str, T.Any]: # TODO Dataclass
""" dict: The GUI config in dict form. """
return self._user_config.config_dict
@property
def user_theme(self) -> Dict[str, Any]: # TODO Dataclass
def user_theme(self) -> dict[str, T.Any]: # TODO Dataclass
""" dict: The GUI theme selection options. """
return self._user_theme
@property
def default_font(self) -> Tuple[str, int]:
def default_font(self) -> tuple[str, int]:
""" tuple: The selected font as configured in user settings. First item is the font (`str`)
second item the font size (`int`). """
font = self.user_config_dict["font"]
@ -328,7 +329,7 @@ class Config():
self._gui_objects.default_options = default
self.project.set_default_options()
def set_command_notebook(self, notebook: "CommandNotebook") -> None:
def set_command_notebook(self, notebook: CommandNotebook) -> None:
""" Set the command notebook to the :attr:`command_notebook` attribute
and enable the modified callback for :attr:`project`.
@ -385,7 +386,7 @@ class Config():
""" Reload the user config from file. """
self._user_config = UserConfig(None)
def set_cursor_busy(self, widget: Optional[tk.Widget] = None) -> None:
def set_cursor_busy(self, widget: tk.Widget | None = None) -> None:
""" Set the root or widget cursor to busy.
Parameters
@ -399,7 +400,7 @@ class Config():
component.config(cursor="watch") # type: ignore
component.update_idletasks()
def set_cursor_default(self, widget: Optional[tk.Widget] = None) -> None:
def set_cursor_default(self, widget: tk.Widget | None = None) -> None:
""" Set the root or widget cursor to default.
Parameters
@ -413,7 +414,7 @@ class Config():
component.config(cursor="") # type: ignore
component.update_idletasks()
def set_root_title(self, text: Optional[str] = None) -> None:
def set_root_title(self, text: str | None = None) -> None:
""" Set the main title text for Faceswap.
The title will always begin with 'Faceswap.py'. Additional text can be appended.

View file

@ -2,23 +2,14 @@
""" File browser utility functions for the Faceswap GUI. """
import logging
import platform
import sys
import tkinter as tk
from tkinter import filedialog
from typing import cast, Dict, IO, List, Optional, Tuple, Union
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
import typing as T
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_FILETYPE = Literal["default", "alignments", "config_project", "config_task",
_FILETYPE = T.Literal["default", "alignments", "config_project", "config_task",
"config_all", "csv", "image", "ini", "state", "log", "video"]
_HANDLETYPE = Literal["open", "save", "filename", "filename_multi", "save_filename",
_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename",
"context", "dir"]
@ -72,14 +63,14 @@ class FileHandler(): # pylint:disable=too-few-public-methods
def __init__(self,
handle_type: _HANDLETYPE,
file_type: Optional[_FILETYPE],
title: Optional[str] = None,
initial_folder: Optional[str] = None,
initial_file: Optional[str] = None,
command: Optional[str] = None,
action: Optional[str] = None,
variable: Optional[str] = None,
parent: Optional[tk.Frame] = None) -> None:
file_type: _FILETYPE | None,
title: str | None = None,
initial_folder: str | None = None,
initial_file: str | None = None,
command: str | None = None,
action: str | None = None,
variable: str | None = None,
parent: tk.Frame | None = None) -> None:
logger.debug("Initializing %s: (handle_type: '%s', file_type: '%s', title: '%s', "
"initial_folder: '%s', initial_file: '%s', command: '%s', action: '%s', "
"variable: %s, parent: %s)", self.__class__.__name__, handle_type, file_type,
@ -101,27 +92,27 @@ class FileHandler(): # pylint:disable=too-few-public-methods
logger.debug("Initialized %s", self.__class__.__name__)
@property
def _filetypes(self) -> Dict[str, List[Tuple[str, str]]]:
def _filetypes(self) -> dict[str, list[tuple[str, str]]]:
""" dict: The accepted extensions for each file type for opening/saving """
all_files = ("All files", "*.*")
filetypes = dict(
default=[all_files],
alignments=[("Faceswap Alignments", "*.fsa"), all_files],
config_project=[("Faceswap Project files", "*.fsw"), all_files],
config_task=[("Faceswap Task files", "*.fst"), all_files],
config_all=[("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
csv=[("Comma separated values", "*.csv"), all_files],
image=[("Bitmap", "*.bmp"),
filetypes = {
"default": [all_files],
"alignments": [("Faceswap Alignments", "*.fsa"), all_files],
"config_project": [("Faceswap Project files", "*.fsw"), all_files],
"config_task": [("Faceswap Task files", "*.fst"), all_files],
"config_all": [("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
"csv": [("Comma separated values", "*.csv"), all_files],
"image": [("Bitmap", "*.bmp"),
("JPG", "*.jpeg *.jpg"),
("PNG", "*.png"),
("TIFF", "*.tif *.tiff"),
all_files],
ini=[("Faceswap config files", "*.ini"), all_files],
json=[("JSON file", "*.json"), all_files],
model=[("Keras model files", "*.h5"), all_files],
state=[("State files", "*.json"), all_files],
log=[("Log files", "*.log"), all_files],
video=[("Audio Video Interleave", "*.avi"),
"ini": [("Faceswap config files", "*.ini"), all_files],
"json": [("JSON file", "*.json"), all_files],
"model": [("Keras model files", "*.h5"), all_files],
"state": [("State files", "*.json"), all_files],
"log": [("Log files", "*.log"), all_files],
"video": [("Audio Video Interleave", "*.avi"),
("Flash Video", "*.flv"),
("Matroska", "*.mkv"),
("MOV", "*.mov"),
@ -129,7 +120,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
("MPEG", "*.mpeg *.mpg *.ts *.vob"),
("WebM", "*.webm"),
("Windows Media Video", "*.wmv"),
all_files])
all_files]}
# Add in multi-select options and upper case extensions for Linux
for key in filetypes:
@ -142,14 +133,14 @@ class FileHandler(): # pylint:disable=too-few-public-methods
multi = [f"{key.title()} Files"]
multi.append(" ".join([ftype[1]
for ftype in filetypes[key] if ftype[0] != "All files"]))
filetypes[key].insert(0, cast(Tuple[str, str], tuple(multi)))
filetypes[key].insert(0, T.cast(tuple[str, str], tuple(multi)))
return filetypes
@property
def _contexts(self) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
def _contexts(self) -> dict[str, dict[str, str | dict[str, str]]]:
"""dict: Mapping of commands, actions and their corresponding file dialog for context
handle types. """
return dict(effmpeg=dict(input={"extract": "filename",
return {"effmpeg": {"input": {"extract": "filename",
"gen-vid": "dir",
"get-fps": "filename",
"get-info": "filename",
@ -157,17 +148,17 @@ class FileHandler(): # pylint:disable=too-few-public-methods
"rescale": "filename",
"rotate": "filename",
"slice": "filename"},
output={"extract": "dir",
"output": {"extract": "dir",
"gen-vid": "save_filename",
"get-fps": "nothing",
"get-info": "nothing",
"mux-audio": "save_filename",
"rescale": "save_filename",
"rotate": "save_filename",
"slice": "save_filename"}))
"slice": "save_filename"}}}
@classmethod
def _set_dummy_master(cls) -> Optional[tk.Frame]:
def _set_dummy_master(cls) -> tk.Frame | None:
""" Add an option to force black font on Linux file dialogs (KDE issue that displays light
font on white background).
@ -183,7 +174,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
if platform.system().lower() == "linux":
frame = tk.Frame()
frame.option_add("*foreground", "black")
retval: Optional[tk.Frame] = frame
retval: tk.Frame | None = frame
else:
retval = None
return retval
@ -196,7 +187,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
del self._dummy_master
self._dummy_master = None
def _set_defaults(self) -> Dict[str, Optional[str]]:
def _set_defaults(self) -> dict[str, str | None]:
""" Set the default file type for the file dialog. Generally the first found file type
will be used, but this is overridden if it is not appropriate.
@ -205,7 +196,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
dict:
The default file extension for each file type
"""
defaults: Dict[str, Optional[str]] = {
defaults: dict[str, str | None] = {
key: next(ext for ext in val[0][1].split(" ")).replace("*", "")
for key, val in self._filetypes.items()}
defaults["default"] = None
@ -215,15 +206,15 @@ class FileHandler(): # pylint:disable=too-few-public-methods
return defaults
def _set_kwargs(self,
title: Optional[str],
initial_folder: Optional[str],
initial_file: Optional[str],
file_type: Optional[_FILETYPE],
command: Optional[str],
action: Optional[str],
variable: Optional[str],
parent: Optional[tk.Frame]
) -> Dict[str, Union[None, tk.Frame, str, List[Tuple[str, str]]]]:
title: str | None,
initial_folder: str | None,
initial_file: str | None,
file_type: _FILETYPE | None,
command: str | None,
action: str | None,
variable: str | None,
parent: tk.Frame | None
) -> dict[str, None | tk.Frame | str | list[tuple[str, str]]]:
""" Generate the required kwargs for the requested file dialog browser.
Parameters
@ -259,8 +250,8 @@ class FileHandler(): # pylint:disable=too-few-public-methods
title, initial_folder, initial_file, file_type, command, action, variable,
parent)
kwargs: Dict[str, Union[None, tk.Frame, str,
List[Tuple[str, str]]]] = dict(master=self._dummy_master)
kwargs: dict[str, None | tk.Frame | str | list[tuple[str, str]]] = {
"master": self._dummy_master}
if self._handletype.lower() == "context":
assert command is not None and action is not None and variable is not None
@ -304,20 +295,20 @@ class FileHandler(): # pylint:disable=too-few-public-methods
The variable associated with this file dialog
"""
if self._contexts[command].get(variable, None) is not None:
handletype = cast(Dict[str, Dict[str, Dict[str, str]]],
handletype = T.cast(dict[str, dict[str, dict[str, str]]],
self._contexts)[command][variable][action]
else:
handletype = cast(Dict[str, Dict[str, str]],
handletype = T.cast(dict[str, dict[str, str]],
self._contexts)[command][action]
logger.debug(handletype)
self._handletype = cast(_HANDLETYPE, handletype)
self._handletype = T.cast(_HANDLETYPE, handletype)
def _open(self) -> Optional[IO]:
def _open(self) -> T.IO | None:
""" Open a file. """
logger.debug("Popping Open browser")
return filedialog.askopenfile(**self._kwargs) # type: ignore
def _save(self) -> Optional[IO]:
def _save(self) -> T.IO | None:
""" Save a file. """
logger.debug("Popping Save browser")
return filedialog.asksaveasfile(**self._kwargs) # type: ignore
@ -337,7 +328,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
logger.debug("Popping Filename browser")
return filedialog.askopenfilename(**self._kwargs) # type: ignore
def _filename_multi(self) -> Tuple[str, ...]:
def _filename_multi(self) -> tuple[str, ...]:
""" Get multiple existing file locations. """
logger.debug("Popping Filename browser")
return filedialog.askopenfilenames(**self._kwargs) # type: ignore

View file

@ -1,10 +1,9 @@
#!/usr/bin python3
""" Utilities for handling images in the Faceswap GUI """
from __future__ import annotations
import logging
import os
import sys
from typing import cast, Dict, List, Optional, Sequence, Tuple
import typing as T
import cv2
import numpy as np
@ -14,15 +13,12 @@ from lib.training.preview_cv import PreviewBuffer
from .config import get_config, PATHCACHE
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_IMAGES: Optional["Images"] = None
_PREVIEW_TRIGGER: Optional["PreviewTrigger"] = None
_IMAGES: "Images" | None = None
_PREVIEW_TRIGGER: "PreviewTrigger" | None = None
TRAININGPREVIEW = ".gui_training_preview.png"
@ -51,7 +47,7 @@ def get_images() -> "Images":
return _IMAGES
def _get_previews(image_path: str) -> List[str]:
def _get_previews(image_path: str) -> list[str]:
""" Get the images stored within the given directory.
Parameters
@ -164,12 +160,12 @@ class PreviewExtract():
self._output_path = ""
self._modified: float = 0.0
self._filenames: List[str] = []
self._images: Optional[np.ndarray] = None
self._placeholder: Optional[np.ndarray] = None
self._filenames: list[str] = []
self._images: np.ndarray | None = None
self._placeholder: np.ndarray | None = None
self._preview_image: Optional[Image.Image] = None
self._preview_image_tk: Optional[ImageTk.PhotoImage] = None
self._preview_image: Image.Image | None = None
self._preview_image_tk: ImageTk.PhotoImage | None = None
logger.debug("Initialized %s", self.__class__.__name__)
@ -228,7 +224,7 @@ class PreviewExtract():
logger.debug("sorted folders: %s, return value: %s", folders, retval)
return retval
def _get_newest_filenames(self, image_files: List[str]) -> List[str]:
def _get_newest_filenames(self, image_files: list[str]) -> list[str]:
""" Return image filenames that have been modified since the last check.
Parameters
@ -281,8 +277,8 @@ class PreviewExtract():
return retval
def _process_samples(self,
samples: List[np.ndarray],
filenames: List[str],
samples: list[np.ndarray],
filenames: list[str],
num_images: int) -> bool:
""" Process the latest sample images into a displayable image.
@ -321,8 +317,8 @@ class PreviewExtract():
return True
def _load_images_to_cache(self,
image_files: List[str],
frame_dims: Tuple[int, int],
image_files: list[str],
frame_dims: tuple[int, int],
thumbnail_size: int) -> bool:
""" Load preview images to the image cache.
@ -349,7 +345,7 @@ class PreviewExtract():
logger.debug("num_images: %s", num_images)
if num_images == 0:
return False
samples: List[np.ndarray] = []
samples: list[np.ndarray] = []
start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
dropped_files = []
@ -405,7 +401,7 @@ class PreviewExtract():
self._placeholder = placeholder
logger.debug("Created placeholder. shape: %s", placeholder.shape)
def _place_previews(self, frame_dims: Tuple[int, int]) -> Image.Image:
def _place_previews(self, frame_dims: tuple[int, int]) -> Image.Image:
""" Format the preview thumbnails stored in the cache into a grid fitting the display
panel.
@ -441,12 +437,12 @@ class PreviewExtract():
placeholder = np.concatenate([np.expand_dims(self._placeholder, 0)] * remainder)
samples = np.concatenate((samples, placeholder))
display = np.vstack([np.hstack(cast(Sequence, samples[row * cols: (row + 1) * cols]))
display = np.vstack([np.hstack(T.cast("Sequence", samples[row * cols: (row + 1) * cols]))
for row in range(rows)])
logger.debug("display shape: %s", display.shape)
return Image.fromarray(display)
def load_latest_preview(self, thumbnail_size: int, frame_dims: Tuple[int, int]) -> bool:
def load_latest_preview(self, thumbnail_size: int, frame_dims: tuple[int, int]) -> bool:
""" Load the latest preview image for extract and convert.
Retrieves the latest preview images from the faceswap output folder, resizes to thumbnails
@ -524,7 +520,7 @@ class Images():
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._pathpreview = os.path.join(PATHCACHE, "preview")
self._pathoutput: Optional[str] = None
self._pathoutput: str | None = None
self._batch_mode = False
self._preview_train = PreviewTrain(self._pathpreview)
self._preview_extract = PreviewExtract(self._pathpreview)
@ -542,7 +538,7 @@ class Images():
return self._preview_extract
@property
def icons(self) -> Dict[str, ImageTk.PhotoImage]:
def icons(self) -> dict[str, ImageTk.PhotoImage]:
""" dict: The faceswap icons for all parts of the GUI. The dictionary key is the icon
name (`str`) the value is the icon sized and formatted for display
(:class:`PIL.ImageTK.PhotoImage`).
@ -557,7 +553,7 @@ class Images():
return self._icons
@staticmethod
def _load_icons() -> Dict[str, ImageTk.PhotoImage]:
def _load_icons() -> dict[str, ImageTk.PhotoImage]:
""" Scan the icons cache folder and load the icons into :attr:`icons` for retrieval
throughout the GUI.
@ -569,7 +565,7 @@ class Images():
"""
size = get_config().user_config_dict.get("icon_size", 16)
size = int(round(size * get_config().scaling_factor))
icons: Dict[str, ImageTk.PhotoImage] = {}
icons: dict[str, ImageTk.PhotoImage] = {}
pathicons = os.path.join(PATHCACHE, "icons")
for fname in os.listdir(pathicons):
name, ext = os.path.splitext(fname)
@ -609,12 +605,12 @@ class PreviewTrigger():
"""
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
self._trigger_files = dict(update=os.path.join(PATHCACHE, ".preview_trigger"),
mask_toggle=os.path.join(PATHCACHE, ".preview_mask_toggle"))
self._trigger_files = {"update": os.path.join(PATHCACHE, ".preview_trigger"),
"mask_toggle": os.path.join(PATHCACHE, ".preview_mask_toggle")}
logger.debug("Initialized: %s (trigger_files: %s)",
self.__class__.__name__, self._trigger_files)
def set(self, trigger_type: Literal["update", "mask_toggle"]):
def set(self, trigger_type: T.Literal["update", "mask_toggle"]):
""" Place the trigger file into the cache folder
Parameters
@ -629,7 +625,7 @@ class PreviewTrigger():
pass
logger.debug("Set preview trigger: %s", trigger)
def clear(self, trigger_type: Optional[Literal["update", "mask_toggle"]] = None) -> None:
def clear(self, trigger_type: T.Literal["update", "mask_toggle"] | None = None) -> None:
""" Remove the trigger file from the cache folder.
Parameters

View file

@ -1,15 +1,17 @@
#!/usr/bin/env python3
""" Miscellaneous Utility functions for the GUI. Includes LongRunningTask object """
from __future__ import annotations
import logging
import sys
import typing as T
from threading import Event, Thread
from typing import (Any, Callable, cast, Dict, Optional, Tuple, Type, TYPE_CHECKING)
from queue import Queue
from .config import get_config
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType
from lib.multithreading import _ErrorType
@ -31,15 +33,15 @@ class LongRunningTask(Thread):
cursor in the correct location. Default: ``None``.
"""
_target: Callable
_args: Tuple
_kwargs: Dict[str, Any]
_args: tuple
_kwargs: dict[str, T.Any]
_name: str
def __init__(self,
target: Optional[Callable] = None,
name: Optional[str] = None,
args: Tuple = (),
kwargs: Optional[Dict[str, Any]] = None,
target: Callable | None = None,
name: str | None = None,
args: tuple = (),
kwargs: dict[str, T.Any] | None = None,
*,
daemon: bool = True,
widget=None):
@ -48,7 +50,7 @@ class LongRunningTask(Thread):
daemon)
super().__init__(target=target, name=name, args=args, kwargs=kwargs,
daemon=daemon)
self.err: "_ErrorType" = None
self.err: _ErrorType = None
self._widget = widget
self._config = get_config()
self._config.set_cursor_busy(widget=self._widget)
@ -70,7 +72,7 @@ class LongRunningTask(Thread):
retval = self._target(*self._args, **self._kwargs)
self._queue.put(retval)
except Exception: # pylint: disable=broad-except
self.err = cast(Tuple[Type[BaseException], BaseException, "TracebackType"],
self.err = T.cast(tuple[type[BaseException], BaseException, "TracebackType"],
sys.exc_info())
assert self.err is not None
logger.debug("Error in thread (%s): %s", self._name,
@ -81,7 +83,7 @@ class LongRunningTask(Thread):
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs
def get_result(self) -> Any:
def get_result(self) -> T.Any:
""" Return the result from the given task.
Returns

View file

@ -1,17 +1,17 @@
#!/usr/bin python3
""" Utilities for working with images and videos """
from __future__ import annotations
import logging
import re
import subprocess
import os
import struct
import sys
import typing as T
from ast import literal_eval
from bisect import bisect
from concurrent import futures
from typing import Optional, TYPE_CHECKING, Union
from zlib import crc32
import cv2
@ -24,7 +24,7 @@ from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import convert_to_secs, FaceswapError, _video_extensions, get_image_paths
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderDict
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
@ -558,7 +558,7 @@ def update_existing_metadata(filename, metadata):
def encode_image(image: np.ndarray,
extension: str,
metadata: Optional["PNGHeaderDict"] = None) -> bytes:
metadata: PNGHeaderDict | None = None) -> bytes:
""" Encode an image.
Parameters
@ -1433,8 +1433,8 @@ class ImagesSaver(ImageIO):
def _save(self,
filename: str,
image: Union[bytes, np.ndarray],
sub_folder: Optional[str]) -> None:
image: bytes | np.ndarray,
sub_folder: str | None) -> None:
""" Save a single image inside a ThreadPoolExecutor
Parameters
@ -1468,8 +1468,8 @@ class ImagesSaver(ImageIO):
def save(self,
filename: str,
image: Union[bytes, np.ndarray],
sub_folder: Optional[str] = None) -> None:
image: bytes | np.ndarray,
sub_folder: str | None = None) -> None:
""" Save the given image in the background thread
Ensure that :func:`close` is called once all save operations are complete.

View file

@ -117,7 +117,7 @@ class ColorSpaceConvert(): # pylint:disable=too-few-public-methods
self._xyz_multipliers = K.constant([116, 500, 200], dtype="float32")
@classmethod
def _get_rgb_xyz_map(cls) -> T.Tuple[Tensor, Tensor]:
def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]:
""" Obtain the mapping and inverse mapping for rgb to xyz color space conversion.
Returns

View file

@ -11,7 +11,17 @@ import time
import traceback
from datetime import datetime
from typing import Union
# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
def _patched_format(self, record):
""" Autograph tf-2.10 has a bug with the 3.10 version of logging.PercentStyle._format(). It is
non-critical but spits out warnings. This is the Python 3.9 version of the function and should
be removed once fixed """
return self._fmt % record.__dict__ # pylint:disable=protected-access
setattr(logging.PercentStyle, "_format", _patched_format)
class FaceswapLogger(logging.Logger):
@ -76,11 +86,11 @@ class ColoredFormatter(logging.Formatter):
def __init__(self, fmt: str, pad_newlines: bool = False, **kwargs) -> None:
super().__init__(fmt, **kwargs)
self._use_color = self._get_color_compatibility()
self._level_colors = dict(CRITICAL="\033[31m", # red
ERROR="\033[31m", # red
WARNING="\033[33m", # yellow
INFO="\033[32m", # green
VERBOSE="\033[34m") # blue
self._level_colors = {"CRITICAL": "\033[31m", # red
"ERROR": "\033[31m", # red
"WARNING": "\033[33m", # yellow
"INFO": "\033[32m", # green
"VERBOSE": "\033[34m"} # blue
self._default_color = "\033[0m"
self._newline_padding = self._get_newline_padding(pad_newlines, fmt)
@ -412,7 +422,7 @@ def _file_handler(loglevel,
return handler
def _stream_handler(loglevel: int, is_gui: bool) -> Union[logging.StreamHandler, TqdmHandler]:
def _stream_handler(loglevel: int, is_gui: bool) -> logging.StreamHandler | TqdmHandler:
""" Add a stream handler for the current Faceswap session. The stream handler will only ever
output at a maximum of VERBOSE level to avoid spamming the console.

View file

@ -1,8 +1,6 @@
""" Auto clipper for clipping gradients. """
from typing import List
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
class AutoClipper(): # pylint:disable=too-few-public-methods
@ -22,12 +20,56 @@ class AutoClipper(): # pylint:disable=too-few-public-methods
original paper: https://arxiv.org/abs/2007.14469
"""
def __init__(self, clip_percentile: int, history_size: int = 10000):
self._clip_percentile = clip_percentile
self._clip_percentile = tf.cast(clip_percentile, tf.float64)
self._grad_history = tf.Variable(tf.zeros(history_size), trainable=False)
self._index = tf.Variable(0, trainable=False)
self._history_size = history_size
def __call__(self, grads_and_vars: List[tf.Tensor]) -> List[tf.Tensor]:
def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor:
    """ Compute the value at the clip percentile of the gradient history.

    Parameters
    ----------
    grad_history: :class:`tensorflow.Tensor`
        The gradient history to calculate the clip percentile for

    Returns
    -------
    :class:`tensorflow.Tensor`
        The entry of the ascending-sorted history at the rounded clip percentile position, or
        ``NaN`` if any entry of the history is ``NaN``

    Notes
    -----
    Adapted from
    https://github.com/tensorflow/probability/blob/r0.14/tensorflow_probability/python/stats/quantiles.py
    to remove reliance on the full tensorflow_probability library
    """
    with tf.name_scope("percentile"):
        quantile = self._clip_percentile / 100.
        ascending = tf.sort(grad_history, axis=-1, direction="ASCENDING")
        count = tf.cast(tf.shape(grad_history)[-1], tf.float64)

        # Round the fractional position to the nearest index and keep it inside the history
        nearest = tf.cast(tf.round((count - 1) * quantile), tf.int32)
        last_index = tf.shape(grad_history)[-1] - 1
        position = tf.clip_by_value(nearest, 0, last_index)
        selected = tf.gather(ascending, position, axis=-1)

        # Propagate NaNs: a single NaN anywhere in the history makes the result NaN
        has_nan = tf.reduce_any(tf.math.is_nan(grad_history), axis=None)
        # Append one size-1 dimension per dimension of the clip percentile
        target_shape = tf.pad(tf.shape(has_nan),
                              paddings=[[0, tf.rank(self._clip_percentile)]],
                              constant_values=1)
        has_nan = tf.reshape(has_nan, shape=target_shape)

        nan_value = np.array(np.nan, selected.dtype.as_numpy_dtype)
        return tf.where(has_nan, nan_value, selected)
def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]:
""" Call the AutoClip function.
Parameters
@ -40,8 +82,7 @@ class AutoClipper(): # pylint:disable=too-few-public-methods
assign_idx = tf.math.mod(self._index, self._history_size)
self._grad_history = self._grad_history[assign_idx].assign(total_norm)
self._index = self._index.assign_add(1)
clip_value = tfp.stats.percentile(self._grad_history[: self._index],
q=self._clip_percentile)
clip_value = self._percentile(self._grad_history[: self._index])
return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars]
@classmethod

View file

@ -1,9 +1,9 @@
#!/usr/bin/env python3
""" Custom Feature Map Loss Functions for faceswap.py """
from __future__ import annotations
from dataclasses import dataclass, field
import logging
from typing import Any, Callable, Dict, Optional, List, Tuple
import typing as T
# Ignore linting errors from Tensorflow's thoroughly broken import system
import tensorflow as tf
@ -17,6 +17,9 @@ import numpy as np
from lib.model.nets import AlexNet, SqueezeNet
from lib.utils import GetModel
if T.TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__)
@ -39,10 +42,10 @@ class NetInfo:
"""
model_id: int = 0
model_name: str = ""
net: Optional[Callable] = None
init_kwargs: Dict[str, Any] = field(default_factory=dict)
net: Callable | None = None
init_kwargs: dict[str, T.Any] = field(default_factory=dict)
needs_init: bool = True
outputs: List[Layer] = field(default_factory=list)
outputs: list[Layer] = field(default_factory=list)
class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods
@ -67,7 +70,7 @@ class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods
logger.debug("Initialized: %s ", self.__class__.__name__)
@property
def _nets(self) -> Dict[str, NetInfo]:
def _nets(self) -> dict[str, NetInfo]:
""" :class:`NetInfo`: The Information about the requested net."""
return {
"alex": NetInfo(model_id=15,
@ -176,7 +179,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def _nets(self) -> Dict[str, NetInfo]:
def _nets(self) -> dict[str, NetInfo]:
""" :class:`NetInfo`: The Information about the requested net."""
return {
"alex": NetInfo(model_id=18,
@ -186,7 +189,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods
"vgg16": NetInfo(model_id=20,
model_name="vgg16_lpips_v1.h5")}
def _linear_block(self, net_output_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
""" Build a linear block for a trunk network output.
Parameters
@ -319,7 +322,7 @@ class LPIPSLoss(): # pylint:disable=too-few-public-methods
tf.keras.mixed_precision.set_global_policy("mixed_float16")
logger.debug("Initialized: %s", self.__class__.__name__)
def _process_diffs(self, inputs: List[tf.Tensor]) -> List[tf.Tensor]:
def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]:
""" Perform processing on the Trunk Network outputs.
If :attr:`use_ldip` is enabled, process the diff values through the linear network,

View file

@ -1,10 +1,9 @@
#!/usr/bin/env python3
""" Custom Loss Functions for faceswap.py """
from __future__ import absolute_import
from __future__ import annotations
import logging
from typing import Callable, List, Tuple
import typing as T
import numpy as np
import tensorflow as tf
@ -13,6 +12,9 @@ import tensorflow as tf
from tensorflow.python.keras.engine import compile_utils # pylint:disable=no-name-in-module
from tensorflow.keras import backend as K # pylint:disable=import-error
if T.TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__)
@ -61,7 +63,7 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
self._ave_spectrum = ave_spectrum
self._log_matrix = log_matrix
self._batch_matrix = batch_matrix
self._dims: Tuple[int, int] = (0, 0)
self._dims: tuple[int, int] = (0, 0)
def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
""" Crop the incoming batch of images into patches as defined by :attr:`_patch_factor`.
@ -470,7 +472,7 @@ class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
return retval
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]:
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> list[tf.Tensor]:
""" Obtain the Laplacian Pyramid.
Parameters
@ -564,9 +566,9 @@ class LossWrapper(tf.keras.losses.Loss):
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
super().__init__(name="LossWrapper")
self._loss_functions: List[compile_utils.LossesContainer] = []
self._loss_weights: List[float] = []
self._mask_channels: List[int] = []
self._loss_functions: list[compile_utils.LossesContainer] = []
self._loss_weights: list[float] = []
self._mask_channels: list[int] = []
logger.debug("Initialized: %s", self.__class__.__name__)
def add_loss(self,
@ -628,7 +630,7 @@ class LossWrapper(tf.keras.losses.Loss):
y_true: tf.Tensor,
y_pred: tf.Tensor,
mask_channel: int,
mask_prop: float = 1.0) -> Tuple[tf.Tensor, tf.Tensor]:
mask_prop: float = 1.0) -> tuple[tf.Tensor, tf.Tensor]:
""" Apply the mask to the input y_true and y_pred. If a mask is not required then
return the unmasked inputs.

View file

@ -2,9 +2,7 @@
""" TF Keras implementation of Perceptual Loss Functions for faceswap.py """
import logging
import sys
from typing import Dict, Optional, Tuple
import typing as T
import numpy as np
import tensorflow as tf
@ -14,11 +12,6 @@ from tensorflow.keras import backend as K # pylint:disable=import-error
from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__)
@ -101,7 +94,7 @@ class DSSIMObjective(): # pylint:disable=too-few-public-methods
"""
return K.depthwise_conv2d(image, kernel, strides=(1, 1), padding="valid")
def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
""" Obtain the structural similarity between a batch of true and predicted images.
Parameters
@ -330,8 +323,8 @@ class LDRFLIPLoss(): # pylint:disable=too-few-public-methods
lower_threshold_exponent: float = 0.4,
upper_threshold_exponent: float = 0.95,
epsilon: float = 1e-15,
pixels_per_degree: Optional[float] = None,
color_order: Literal["bgr", "rgb"] = "bgr") -> None:
pixels_per_degree: float | None = None,
color_order: T.Literal["bgr", "rgb"] = "bgr") -> None:
logger.debug("Initializing: %s (computed_distance_exponent '%s', feature_exponent: %s, "
"lower_threshold_exponent: %s, upper_threshold_exponent: %s, epsilon: %s, "
"pixels_per_degree: %s, color_order: %s)", self.__class__.__name__,
@ -525,7 +518,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
self._spatial_filters, self._radius = self._generate_spatial_filters()
self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb")
def _generate_spatial_filters(self) -> Tuple[tf.Tensor, int]:
def _generate_spatial_filters(self) -> tuple[tf.Tensor, int]:
""" Generates spatial contrast sensitivity filters with width depending on the number of
pixels per degree of visual angle of the observer for channels "A", "RG" and "BY"
@ -559,7 +552,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
b1_rg: float,
b2_rg: float,
b1_by: float,
b2_by: float) -> Tuple[np.ndarray, int]:
b2_by: float) -> tuple[np.ndarray, int]:
""" TODO docstring """
max_scale_parameter = max([b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by])
delta_x = 1.0 / self._pixels_per_degree
@ -570,7 +563,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
return domain, radius
@classmethod
def _generate_weights(cls, channel: Dict[str, float], domain: np.ndarray) -> tf.Tensor:
def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> tf.Tensor:
""" TODO docstring """
a_1, b_1, a_2, b_2 = channel["a1"], channel["b1"], channel["a2"], channel["b2"]
grad = (a_1 * np.sqrt(np.pi / b_1) * np.exp(-np.pi ** 2 * domain / b_1) +
@ -694,7 +687,7 @@ class MSSIMLoss(): # pylint:disable=too-few-public-methods
filter_size: int = 11,
filter_sigma: float = 1.5,
max_value: float = 1.0,
power_factors: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
power_factors: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
) -> None:
self.filter_size = filter_size
self.filter_sigma = filter_sigma

View file

@ -31,7 +31,7 @@ class _net(): # pylint:disable=too-few-public-methods
The input shape for the model. Default: ``None``
"""
def __init__(self,
input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
input_shape: tuple[int, int, int] | None = None) -> None:
logger.debug("Initializing: %s (input_shape: %s)", self.__class__.__name__, input_shape)
self._input_shape = (None, None, 3) if input_shape is None else input_shape
assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, (
@ -56,7 +56,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods
input_shape, Tuple, optional
The input shape for the model. Default: ``None``
"""
def __init__(self, input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None:
super().__init__(input_shape)
self._feature_indices = [0, 3, 6, 8, 10] # For naming equivalent to PyTorch
self._filters = [64, 192, 384, 256, 256] # Filters at each block
@ -108,7 +108,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods
name=name)(var_x)
return var_x
def __call__(self) -> Model:
def __call__(self) -> tf.keras.models.Model:
""" Create the AlexNet Model
Returns
@ -189,7 +189,7 @@ class SqueezeNet(_net): # pylint:disable=too-few-public-methods
name=f"{name}.expand3x3")(squeezed)
return layers.Concatenate(axis=-1, name=name)([expand1, expand3])
def __call__(self) -> Model:
def __call__(self) -> tf.keras.models.Model:
""" Create the SqueezeNet Model
Returns

View file

@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG: dict = {}
_NAMES: T.Dict[str, int] = {}
_NAMES: dict[str, int] = {}
def set_config(configuration: dict) -> None:
@ -189,7 +189,7 @@ class Conv2DOutput(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int]],
kernel_size: int | tuple[int],
activation: str = "sigmoid",
padding: str = "same", **kwargs) -> None:
self._name = kwargs.pop("name") if "name" in kwargs else _get_name(
@ -265,11 +265,11 @@ class Conv2DBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
strides: T.Union[int, T.Tuple[int, int]] = 2,
kernel_size: int | tuple[int, int] = 5,
strides: int | tuple[int, int] = 2,
padding: str = "same",
normalization: T.Optional[str] = None,
activation: T.Optional[str] = "leakyrelu",
normalization: str | None = None,
activation: str | None = "leakyrelu",
use_depthwise: bool = False,
relu_alpha: float = 0.1,
**kwargs) -> None:
@ -362,8 +362,8 @@ class SeparableConv2DBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
strides: T.Union[int, T.Tuple[int, int]] = 2, **kwargs) -> None:
kernel_size: int | tuple[int, int] = 5,
strides: int | tuple[int, int] = 2, **kwargs) -> None:
self._name = _get_name(f"separableconv2d_{filters}")
logger.debug("name: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s)",
self._name, filters, kernel_size, strides, kwargs)
@ -434,11 +434,11 @@ class UpscaleBlock(): # pylint:disable=too-few-public-methods
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
kernel_size: int | tuple[int, int] = 3,
padding: str = "same",
scale_factor: int = 2,
normalization: T.Optional[str] = None,
activation: T.Optional[str] = "leakyrelu",
normalization: str | None = None,
activation: str | None = "leakyrelu",
**kwargs) -> None:
self._name = _get_name(f"upscale_{filters}")
logger.debug("name: %s. filters: %s, kernel_size: %s, padding: %s, scale_factor: %s, "
@ -521,9 +521,9 @@ class Upscale2xBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
kernel_size: int | tuple[int, int] = 3,
padding: str = "same",
activation: T.Optional[str] = "leakyrelu",
activation: str | None = "leakyrelu",
interpolation: str = "bilinear",
sr_ratio: float = 0.5,
scale_factor: int = 2,
@ -615,9 +615,9 @@ class UpscaleResizeImagesBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
kernel_size: int | tuple[int, int] = 3,
padding: str = "same",
activation: T.Optional[str] = "leakyrelu",
activation: str | None = "leakyrelu",
scale_factor: int = 2,
interpolation: str = "bilinear") -> None:
self._name = _get_name(f"upscale_ri_{filters}")
@ -700,9 +700,9 @@ class UpscaleDNYBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
kernel_size: int | tuple[int, int] = 3,
padding: str = "same",
activation: T.Optional[str] = "leakyrelu",
activation: str | None = "leakyrelu",
size: int = 2,
interpolation: str = "bilinear",
**kwargs) -> None:
@ -757,7 +757,7 @@ class ResidualBlock(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
filters: int,
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
kernel_size: int | tuple[int, int] = 3,
padding: str = "same",
**kwargs) -> None:
self._name = _get_name(f"residual_{filters}")

View file

@ -1,9 +1,9 @@
#!/usr/bin python3
""" Settings manager for Keras Backend """
from __future__ import annotations
from contextlib import nullcontext
import logging
from typing import Callable, ContextManager, List, Optional, Union
import typing as T
import numpy as np
import tensorflow as tf
@ -14,6 +14,9 @@ from tensorflow.keras.models import load_model as k_load_model, Model # noqa:E5
from lib.utils import get_backend
if T.TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
@ -52,9 +55,9 @@ class KSession():
def __init__(self,
name: str,
model_path: str,
model_kwargs: Optional[dict] = None,
model_kwargs: dict | None = None,
allow_growth: bool = False,
exclude_gpus: Optional[List[int]] = None,
exclude_gpus: list[int] | None = None,
cpu_mode: bool = False) -> None:
logger.trace("Initializing: %s (name: %s, model_path: %s, " # type:ignore
"model_kwargs: %s, allow_growth: %s, exclude_gpus: %s, cpu_mode: %s)",
@ -67,12 +70,12 @@ class KSession():
cpu_mode)
self._model_path = model_path
self._model_kwargs = {} if not model_kwargs else model_kwargs
self._model: Optional[Model] = None
self._model: Model | None = None
logger.trace("Initialized: %s", self.__class__.__name__,) # type:ignore
def predict(self,
feed: Union[List[np.ndarray], np.ndarray],
batch_size: Optional[int] = None) -> Union[List[np.ndarray], np.ndarray]:
feed: list[np.ndarray] | np.ndarray,
batch_size: int | None = None) -> list[np.ndarray] | np.ndarray:
""" Get predictions from the model.
This method is a wrapper for :func:`keras.predict()` function. For Tensorflow backends
@ -98,7 +101,7 @@ class KSession():
def _set_session(self,
allow_growth: bool,
exclude_gpus: list,
cpu_mode: bool) -> ContextManager:
cpu_mode: bool) -> T.ContextManager:
""" Sets the backend session options.
For CPU backends, this hides any GPUs from Tensorflow.

View file

@ -1,19 +1,22 @@
#!/usr/bin/env python3
""" Multithreading/processing utils for faceswap """
from __future__ import annotations
import logging
import typing as T
from multiprocessing import cpu_count
import queue as Queue
import sys
import threading
from types import TracebackType
from typing import Any, Callable, Dict, Generator, List, Tuple, Type, Optional, Set, Union
if T.TYPE_CHECKING:
from collections.abc import Callable, Generator
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_ErrorType = Optional[Union[Tuple[Type[BaseException], BaseException, TracebackType],
Tuple[Any, Any, Any]]]
_THREAD_NAMES: Set[str] = set()
_ErrorType: T.TypeAlias = (tuple[type[BaseException], BaseException, TracebackType] |
tuple[T.Any, T.Any, T.Any] | None)
_THREAD_NAMES: set[str] = set()
def total_cpus():
@ -62,17 +65,17 @@ class FSThread(threading.Thread):
keyword arguments for the target invocation. Default: {}.
"""
_target: Callable
_args: Tuple
_kwargs: Dict[str, Any]
_args: tuple
_kwargs: dict[str, T.Any]
_name: str
def __init__(self,
target: Optional[Callable] = None,
name: Optional[str] = None,
args: Tuple = (),
kwargs: Optional[Dict[str, Any]] = None,
target: Callable | None = None,
name: str | None = None,
args: tuple = (),
kwargs: dict[str, T.Any] | None = None,
*,
daemon: Optional[bool] = None) -> None:
daemon: bool | None = None) -> None:
super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
self.err: _ErrorType = None
@ -124,7 +127,7 @@ class MultiThread():
target: Callable,
*args,
thread_count: int = 1,
name: Optional[str] = None,
name: str | None = None,
**kwargs) -> None:
self._name = _get_name(name if name else target.__name__)
logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
@ -132,7 +135,7 @@ class MultiThread():
logger.trace("args: %s, kwargs: %s", args, kwargs) # type:ignore
self.daemon = True
self._thread_count = thread_count
self._threads: List[FSThread] = []
self._threads: list[FSThread] = []
self._target = target
self._args = args
self._kwargs = kwargs
@ -144,7 +147,7 @@ class MultiThread():
return any(thread.err for thread in self._threads)
@property
def errors(self) -> List[_ErrorType]:
def errors(self) -> list[_ErrorType]:
""" list: List of thread error values """
return [thread.err for thread in self._threads if thread.err]
@ -253,9 +256,9 @@ class BackgroundGenerator(MultiThread):
def __init__(self,
generator: Callable,
prefetch: int = 1,
name: Optional[str] = None,
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None) -> None:
name: str | None = None,
args: tuple | None = None,
kwargs: dict[str, T.Any] | None = None) -> None:
super().__init__(name=name, target=self._run)
self.queue: Queue.Queue = Queue.Queue(prefetch)
self.generator = generator

View file

@ -6,7 +6,6 @@
import logging
import threading
from typing import Dict
from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa
from time import sleep
@ -45,7 +44,7 @@ class _QueueManager():
logger.debug("Initializing %s", self.__class__.__name__)
self.shutdown = threading.Event()
self.queues: Dict[str, EventQueue] = {}
self.queues: dict[str, EventQueue] = {}
logger.debug("Initialized %s", self.__class__.__name__)
def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str:

View file

@ -6,8 +6,8 @@ import locale
import os
import platform
import sys
from subprocess import PIPE, Popen
from typing import List, Optional
import psutil
@ -21,14 +21,14 @@ class _SysInfo(): # pylint:disable=too-few-public-methods
def __init__(self) -> None:
self._state_file = _State().state_file
self._configs = _Configs().configs
self._system = dict(platform=platform.platform(),
system=platform.system().lower(),
machine=platform.machine(),
release=platform.release(),
processor=platform.processor(),
cpu_count=os.cpu_count())
self._python = dict(implementation=platform.python_implementation(),
version=platform.python_version())
self._system = {"platform": platform.platform(),
"system": platform.system().lower(),
"machine": platform.machine(),
"release": platform.release(),
"processor": platform.processor(),
"cpu_count": os.cpu_count()}
self._python = {"implementation": platform.python_implementation(),
"version": platform.python_version()}
self._gpu = self._get_gpu_info()
self._cuda_check = CudaCheck()
@ -66,7 +66,7 @@ class _SysInfo(): # pylint:disable=too-few-public-methods
(hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix))
else:
prefix = os.path.dirname(sys.prefix)
retval = (os.path.basename(prefix) == "envs")
retval = os.path.basename(prefix) == "envs"
return retval
@property
@ -295,7 +295,7 @@ class _Configs(): # pylint:disable=too-few-public-methods
except FileNotFoundError:
return ""
def _parse_configs(self, config_files: List[str]) -> str:
def _parse_configs(self, config_files: list[str]) -> str:
""" Parse the given list of config files into a human readable format.
Parameters
@ -399,7 +399,7 @@ class _State(): # pylint:disable=too-few-public-methods
return len(sys.argv) > 1 and sys.argv[1].lower() == "train"
@staticmethod
def _get_arg(*args: str) -> Optional[str]:
def _get_arg(*args: str) -> str | None:
""" Obtain the value for a given command line option from sys.argv.
Returns

View file

@ -1,16 +1,16 @@
#!/usr/bin/env python3
""" Package for handling alignments files, detected faces and aligned faces along with their
associated objects. """
from typing import Type, TYPE_CHECKING
from __future__ import annotations
import typing as T
from .augmentation import ImageAugmentation
from .generator import PreviewDataGenerator, TrainingDataGenerator
from .preview_cv import PreviewBuffer, TriggerType
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from .preview_cv import PreviewBase
Preview: Type[PreviewBase]
Preview: type[PreviewBase]
try:
from .preview_tk import PreviewTk as Preview

View file

@ -1,8 +1,9 @@
#!/usr/bin/env python3
""" Processes the augmentation of images for feeding into a Faceswap model. """
from __future__ import annotations
from dataclasses import dataclass
import logging
from typing import Dict, Tuple, TYPE_CHECKING
import typing as T
import cv2
import numexpr as ne
@ -11,7 +12,7 @@ from scipy.interpolate import griddata
from lib.image import batch_convert_color
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.config import ConfigValueType
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -56,7 +57,7 @@ class AugConstants:
transform_zoom: float
transform_shift: float
warp_maps: np.ndarray
warp_pad: Tuple[int, int]
warp_pad: tuple[int, int]
warp_slices: slice
warp_lm_edge_anchors: np.ndarray
warp_lm_grids: np.ndarray
@ -79,7 +80,7 @@ class ImageAugmentation():
def __init__(self,
batchsize: int,
processing_size: int,
config: Dict[str, "ConfigValueType"]) -> None:
config: dict[str, ConfigValueType]) -> None:
logger.debug("Initializing %s: (batchsize: %s, processing_size: %s, "
"config: %s)",
self.__class__.__name__, batchsize, processing_size, config)
@ -332,7 +333,7 @@ class ImageAugmentation():
slices = self._constants.warp_slices
rands = np.random.normal(size=(self._batchsize, 2, 5, 5),
scale=self._warp_scale).astype("float32")
batch_maps = ne.evaluate("m + r", local_dict=dict(m=self._constants.warp_maps, r=rands))
batch_maps = ne.evaluate("m + r", local_dict={"m": self._constants.warp_maps, "r": rands})
batch_interp = np.array([[cv2.resize(map_, self._constants.warp_pad)[slices, slices]
for map_ in maps]
for maps in batch_maps])

View file

@ -1,11 +1,11 @@
#!/usr/bin/env python3
""" Holds the data cache for training data generators """
from __future__ import annotations
import logging
import os
import sys
import typing as T
from threading import Lock
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING
import cv2
import numpy as np
@ -16,25 +16,19 @@ from lib.align.aligned_face import CenteringType
from lib.image import read_image_batch, read_image_meta_batch
from lib.utils import FaceswapError
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict
from lib.config import ConfigValueType
logger = logging.getLogger(__name__)
_FACE_CACHES: Dict[str, "_Cache"] = {}
_FACE_CACHES: dict[str, "_Cache"] = {}
def get_cache(side: Literal["a", "b"],
filenames: Optional[List[str]] = None,
config: Optional[Dict[str, "ConfigValueType"]] = None,
size: Optional[int] = None,
coverage_ratio: Optional[float] = None) -> "_Cache":
def get_cache(side: T.Literal["a", "b"],
filenames: list[str] | None = None,
config: dict[str, ConfigValueType] | None = None,
size: int | None = None,
coverage_ratio: float | None = None) -> "_Cache":
""" Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then
create it.
@ -120,24 +114,24 @@ class _Cache():
The coverage ratio that the model is using.
"""
def __init__(self,
filenames: List[str],
config: Dict[str, "ConfigValueType"],
filenames: list[str],
config: dict[str, ConfigValueType],
size: int,
coverage_ratio: float) -> None:
logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)",
self.__class__.__name__, len(filenames), size, coverage_ratio)
self._lock = Lock()
self._cache_info = dict(cache_full=False, has_reset=False)
self._partially_loaded: List[str] = []
self._cache_info = {"cache_full": False, "has_reset": False}
self._partially_loaded: list[str] = []
self._image_count = len(filenames)
self._cache: Dict[str, DetectedFace] = {}
self._aligned_landmarks: Dict[str, np.ndarray] = {}
self._cache: dict[str, DetectedFace] = {}
self._aligned_landmarks: dict[str, np.ndarray] = {}
self._extract_version = 0.0
self._size = size
assert config["centering"] in get_args(CenteringType)
self._centering: CenteringType = cast(CenteringType, config["centering"])
assert config["centering"] in T.get_args(CenteringType)
self._centering: CenteringType = T.cast(CenteringType, config["centering"])
self._config = config
self._coverage_ratio = coverage_ratio
@ -153,7 +147,7 @@ class _Cache():
return self._cache_info["cache_full"]
@property
def aligned_landmarks(self) -> Dict[str, np.ndarray]:
def aligned_landmarks(self) -> dict[str, np.ndarray]:
""" dict: The filename as key, aligned landmarks as value. """
# Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate
# all of the aligned landmarks for the entire cache.
@ -185,7 +179,7 @@ class _Cache():
self._cache_info["has_reset"] = False
return retval
def get_items(self, filenames: List[str]) -> List[DetectedFace]:
def get_items(self, filenames: list[str]) -> list[DetectedFace]:
""" Obtain the cached items for a list of filenames. The returned list is in the same order
as the provided filenames.
@ -202,7 +196,7 @@ class _Cache():
"""
return [self._cache[os.path.basename(filename)] for filename in filenames]
def cache_metadata(self, filenames: List[str]) -> np.ndarray:
def cache_metadata(self, filenames: list[str]) -> np.ndarray:
""" Obtain the batch with metadata for items that need caching and cache DetectedFace
objects to :attr:`_cache`.
@ -267,7 +261,7 @@ class _Cache():
return batch
def pre_fill(self, filenames: List[str], side: Literal["a", "b"]) -> None:
def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None:
""" When warp to landmarks is enabled, the cache must be pre-filled, as each side needs
access to the other side's alignments.
@ -294,7 +288,7 @@ class _Cache():
self._cache[key] = detected_face
self._partially_loaded.append(key)
def _validate_version(self, png_meta: "PNGHeaderDict", filename: str) -> None:
def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None:
""" Validate that there are not a mix of v1.0 extracted faces and v2.x faces.
Parameters
@ -350,7 +344,7 @@ class _Cache():
def _load_detected_face(self,
filename: str,
alignments: "PNGHeaderAlignmentsDict") -> DetectedFace:
alignments: PNGHeaderAlignmentsDict) -> DetectedFace:
""" Load a :class:`DetectedFace` object and load its associated `aligned` property.
Parameters
@ -387,13 +381,13 @@ class _Cache():
The detected face object that holds the masks
"""
masks = [(self._get_face_mask(filename, detected_face))]
for area in get_args(Literal["eye", "mouth"]):
for area in T.get_args(T.Literal["eye", "mouth"]):
masks.append(self._get_localized_mask(filename, detected_face, area))
detected_face.store_training_masks(masks, delete_masks=True)
logger.trace("Stored masks for filename: %s)", filename) # type: ignore
def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> Optional[np.ndarray]:
def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None:
""" Obtain the training sized face mask from the :class:`DetectedFace` for the requested
mask type.
@ -448,7 +442,7 @@ class _Cache():
def _get_localized_mask(self,
filename: str,
detected_face: DetectedFace,
area: Literal["eye", "mouth"]) -> Optional[np.ndarray]:
area: T.Literal["eye", "mouth"]) -> np.ndarray | None:
""" Obtain a localized mask for the given area if it is required for training.
Parameters
@ -486,7 +480,7 @@ class RingBuffer(): # pylint: disable=too-few-public-methods
"""
def __init__(self,
batch_size: int,
image_shape: Tuple[int, int, int],
image_shape: tuple[int, int, int],
buffer_size: int = 2,
dtype: str = "uint8") -> None:
logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, "

View file

@ -1,13 +1,12 @@
#!/usr/bin/env python3
""" Handles Data Augmentation for feeding Faceswap Models """
from __future__ import annotations
import logging
import os
import sys
from concurrent import futures
import typing as T
from concurrent import futures
from random import shuffle, choice
from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING
import cv2
import numpy as np
@ -21,18 +20,14 @@ from lib.utils import FaceswapError
from . import ImageAugmentation
from .cache import get_cache, RingBuffer
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from lib.config import ConfigValueType
from plugins.train.model._base import ModelBase
from .cache import _Cache
logger = logging.getLogger(__name__)
BatchType = Tuple[np.ndarray, List[np.ndarray]]
BatchType = tuple[np.ndarray, list[np.ndarray]]
class DataGenerator():
@ -57,10 +52,10 @@ class DataGenerator():
objects of this size from the iterator.
"""
def __init__(self,
config: Dict[str, "ConfigValueType"],
model: "ModelBase",
side: Literal["a", "b"],
images: List[str],
config: dict[str, ConfigValueType],
model: ModelBase,
side: T.Literal["a", "b"],
images: list[str],
batch_size: int) -> None:
logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore
"batch_size: %s, config: %s)", self.__class__.__name__, model.name, side,
@ -83,7 +78,7 @@ class DataGenerator():
self._buffer = RingBuffer(batch_size,
(self._process_size, self._process_size, self._total_channels),
dtype="uint8")
self._face_cache: "_Cache" = get_cache(side,
self._face_cache: _Cache = get_cache(side,
filenames=images,
config=self._config,
size=self._process_size,
@ -100,12 +95,12 @@ class DataGenerator():
channels += 1
mults = [area for area in ["eye", "mouth"]
if cast(int, self._config[f"{area}_multiplier"]) > 1]
if T.cast(int, self._config[f"{area}_multiplier"]) > 1]
if self._config["penalized_mask_loss"] and mults:
channels += len(mults)
return channels
def _get_output_sizes(self, model: "ModelBase") -> List[int]:
def _get_output_sizes(self, model: ModelBase) -> list[int]:
""" Obtain the size of each output tensor for the model.
Parameters
@ -222,7 +217,7 @@ class DataGenerator():
retval = self._process_batch(img_paths)
yield retval
def _get_images_with_meta(self, filenames: List[str]) -> Tuple[np.ndarray, List[DetectedFace]]:
def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]:
""" Obtain the raw face images with associated :class:`DetectedFace` objects for this
batch.
@ -253,9 +248,9 @@ class DataGenerator():
return raw_faces, detected_faces
def _crop_to_coverage(self,
filenames: List[str],
filenames: list[str],
images: np.ndarray,
detected_faces: List[DetectedFace],
detected_faces: list[DetectedFace],
batch: np.ndarray) -> None:
""" Crops the training image out of the full extract image based on the centering and
coverage used in the user's configuration settings.
@ -286,7 +281,7 @@ class DataGenerator():
for future in futures.as_completed(proc):
batch[proc[future], ..., :3] = future.result()
def _apply_mask(self, detected_faces: List[DetectedFace], batch: np.ndarray) -> None:
def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None:
""" Applies the masks to the 4th channel of the batch.
If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1
@ -312,7 +307,7 @@ class DataGenerator():
logger.trace("side: %s, masks: %s, batch: %s", # type: ignore
self._side, masks.shape, batch.shape)
def _process_batch(self, filenames: List[str]) -> BatchType:
def _process_batch(self, filenames: list[str]) -> BatchType:
""" Prepares data for feeding through subclassed methods.
If this is the first time a face has been loaded, then its meta data is extracted from the
@ -345,9 +340,9 @@ class DataGenerator():
return feed, targets
def process_batch(self,
filenames: List[str],
filenames: list[str],
images: np.ndarray,
detected_faces: List[DetectedFace],
detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType:
""" Override for processing the batch for the current generator.
@ -391,7 +386,7 @@ class DataGenerator():
The input uint8 array
"""
return ne.evaluate("x / c",
local_dict=dict(x=in_array, c=np.float32(255)),
local_dict={"x": in_array, "c": np.float32(255)},
casting="unsafe")
@ -417,10 +412,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
objects of this size from the iterator.
"""
def __init__(self,
config: Dict[str, "ConfigValueType"],
model: "ModelBase",
side: Literal["a", "b"],
images: List[str],
config: dict[str, ConfigValueType],
model: ModelBase,
side: T.Literal["a", "b"],
images: list[str],
batch_size: int) -> None:
super().__init__(config, model, side, images, batch_size)
self._augment_color = not model.command_line_arguments.no_augment_color
@ -434,10 +429,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
self._processing = ImageAugmentation(batch_size,
self._process_size,
self._config)
self._nearest_landmarks: Dict[str, Tuple[str, ...]] = {}
self._nearest_landmarks: dict[str, tuple[str, ...]] = {}
logger.debug("Initialized %s", self.__class__.__name__)
def _create_targets(self, batch: np.ndarray) -> List[np.ndarray]:
def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]:
""" Compile target images, with masks, for the model output sizes.
Parameters
@ -467,9 +462,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return retval
def process_batch(self,
filenames: List[str],
filenames: list[str],
images: np.ndarray,
detected_faces: List[DetectedFace],
detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType:
""" Performs the augmentation and compiles target images and samples.
@ -525,7 +520,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
if self._warp_to_landmarks:
landmarks = np.array([face.aligned.landmarks for face in detected_faces])
batch_dst_pts = self._get_closest_match(filenames, landmarks)
warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts)
warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts}
else:
warp_kwargs = {}
@ -545,7 +540,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return feed, targets
def _get_closest_match(self, filenames: List[str], batch_src_points: np.ndarray) -> np.ndarray:
def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray:
""" Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest
matched 68 point landmarks from the opposite training set.
@ -563,7 +558,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
"""
logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore
"src_points: '%s')", filenames, batch_src_points)
lm_side: Literal["a", "b"] = "a" if self._side == "b" else "b"
lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b"
other_cache = get_cache(lm_side)
landmarks = other_cache.aligned_landmarks
@ -584,9 +579,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return batch_dst_points
def _cache_closest_matches(self,
filenames: List[str],
filenames: list[str],
batch_src_points: np.ndarray,
landmarks: Dict[str, np.ndarray]) -> List[Tuple[str, ...]]:
landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]:
""" Cache the nearest landmarks for this batch
Parameters
@ -602,7 +597,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
logger.trace("Caching closest matches") # type:ignore
dst_landmarks = list(landmarks.items())
dst_points = np.array([lm[1] for lm in dst_landmarks])
batch_closest_matches: List[Tuple[str, ...]] = []
batch_closest_matches: list[tuple[str, ...]] = []
for filename, src_points in zip(filenames, batch_src_points):
closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10]
@ -637,7 +632,7 @@ class PreviewDataGenerator(DataGenerator):
"""
def _create_samples(self,
images: np.ndarray,
detected_faces: List[DetectedFace]) -> List[np.ndarray]:
detected_faces: list[DetectedFace]) -> list[np.ndarray]:
""" Compile the 'sample' images. These are the 100% coverage images which hold the model
output in the preview window.
@ -658,11 +653,12 @@ class PreviewDataGenerator(DataGenerator):
output_size = self._output_sizes[-1]
full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2))
assert self._config["centering"] in get_args(CenteringType)
assert self._config["centering"] in T.get_args(CenteringType)
retval = np.empty((full_size, full_size, 3), dtype="float32")
retval = self._to_float32(np.array([AlignedFace(face.landmarks_xy,
retval = self._to_float32(np.array([
AlignedFace(face.landmarks_xy,
image=images[idx],
centering=cast(CenteringType,
centering=T.cast(CenteringType,
self._config["centering"]),
size=full_size,
dtype="uint8",
@ -673,9 +669,9 @@ class PreviewDataGenerator(DataGenerator):
return [retval]
def process_batch(self,
filenames: List[str],
filenames: list[str],
images: np.ndarray,
detected_faces: List[DetectedFace],
detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType:
""" Creates the full size preview images and the sub-cropped images for feeding the model's
predict function.

View file

@ -4,36 +4,30 @@
If Tkinter is installed, then this will be used to manage the preview image, otherwise we
fallback to opencv's imshow
"""
from __future__ import annotations
import logging
import sys
import typing as T
from threading import Event, Lock
from time import sleep
from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING
import cv2
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
import numpy as np
logger = logging.getLogger(__name__)
TriggerType = Dict[Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]
TriggerKeysType = Literal["m", "r", "s", "enter"]
TriggerNamesType = Literal["toggle_mask", "refresh", "save", "quit"]
TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]
TriggerKeysType = T.Literal["m", "r", "s", "enter"]
TriggerNamesType = T.Literal["toggle_mask", "refresh", "save", "quit"]
class PreviewBuffer():
""" A thread safe class for holding preview images """
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
self._images: Dict[str, "np.ndarray"] = {}
self._images: dict[str, np.ndarray] = {}
self._lock = Lock()
self._updated = Event()
logger.debug("Initialized: %s", self.__class__.__name__)
@ -43,7 +37,7 @@ class PreviewBuffer():
""" bool: ``True`` when new images have been loaded into the preview buffer """
return self._updated.is_set()
def add_image(self, name: str, image: "np.ndarray") -> None:
def add_image(self, name: str, image: np.ndarray) -> None:
""" Add an image to the preview buffer in a thread safe way """
logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape)
with self._lock:
@ -51,7 +45,7 @@ class PreviewBuffer():
logger.debug("Added images: %s", list(self._images))
self._updated.set()
def get_images(self) -> Generator[Tuple[str, "np.ndarray"], None, None]:
def get_images(self) -> Generator[tuple[str, np.ndarray], None, None]:
""" Get the latest images from the preview buffer. When iterator is exhausted clears the
:attr:`updated` event.
@ -86,15 +80,15 @@ class PreviewBase(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
preview_buffer: PreviewBuffer,
triggers: Optional[TriggerType] = None) -> None:
triggers: TriggerType | None = None) -> None:
logger.debug("Initializing %s parent (triggers: %s)",
self.__class__.__name__, triggers)
self._triggers = triggers
self._buffer = preview_buffer
self._keymaps: Dict[TriggerKeysType, TriggerNamesType] = dict(m="toggle_mask",
r="refresh",
s="save",
enter="quit")
self._keymaps: dict[TriggerKeysType, TriggerNamesType] = {"m": "toggle_mask",
"r": "refresh",
"s": "save",
"enter": "quit"}
self._title = ""
logger.debug("Initialized %s parent", self.__class__.__name__)
@ -141,7 +135,7 @@ class PreviewCV(PreviewBase): # pylint:disable=too-few-public-methods
logger.debug("Unable to import Tkinter. Falling back to OpenCV")
super().__init__(preview_buffer, triggers=triggers)
self._triggers: TriggerType = self._triggers
self._windows: List[str] = []
self._windows: list[str] = []
self._lookup = {ord(key): val
for key, val in self._keymaps.items() if key != "enter"}

View file

@ -4,24 +4,25 @@
If Tkinter is installed, then this will be used to manage the preview image, otherwise we
fallback to opencv's imshow
"""
from __future__ import annotations
import logging
import os
import sys
import tkinter as tk
import typing as T
from datetime import datetime
from platform import system
from tkinter import ttk
from math import ceil, floor
from typing import cast, List, Optional, Tuple, TYPE_CHECKING
from PIL import Image, ImageTk
import cv2
from .preview_cv import PreviewBase, TriggerKeysType
if TYPE_CHECKING:
if T.TYPE_CHECKING:
import numpy as np
from .preview_cv import PreviewBuffer, TriggerType
@ -38,18 +39,18 @@ class _Taskbar():
taskbar: :class:`tkinter.ttk.Frame` or ``None``
None if preview is a pop-up window otherwise ttk.Frame if taskbar is managed by the GUI
"""
def __init__(self, parent: tk.Frame, taskbar: Optional[ttk.Frame]) -> None:
def __init__(self, parent: tk.Frame, taskbar: ttk.Frame | None) -> None:
logger.debug("Initializing %s (parent: '%s', taskbar: %s)",
self.__class__.__name__, parent, taskbar)
self._is_standalone = taskbar is None
self._gui_mapped: List[tk.Widget] = []
self._gui_mapped: list[tk.Widget] = []
self._frame = tk.Frame(parent) if taskbar is None else taskbar
self._min_max_scales = (20, 400)
self._vars = dict(save=tk.BooleanVar(),
scale=tk.StringVar(),
slider=tk.IntVar(),
interpolator=tk.IntVar())
self._vars = {"save": tk.BooleanVar(),
"scale": tk.StringVar(),
"slider": tk.IntVar(),
"interpolator": tk.IntVar()}
self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST),
("bicubic", cv2.INTER_CUBIC)]
self._scale = self._add_scale_combo()
@ -261,7 +262,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
def __init__(self,
parent: tk.Frame,
scale_var: tk.StringVar,
screen_dimensions: Tuple[int, int],
screen_dimensions: tuple[int, int],
is_standalone: bool) -> None:
logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)",
self.__class__.__name__, parent, scale_var, screen_dimensions)
@ -272,7 +273,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
self._screen_dimensions = screen_dimensions
self._var_scale = scale_var
self._configure_scrollbars(frame)
self._image: Optional[ImageTk.PhotoImage] = None
self._image: ImageTk.PhotoImage | None = None
self._image_id = self.create_image(self.width / 2,
self.height / 2,
anchor=tk.CENTER,
@ -400,8 +401,8 @@ class _Image():
logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)",
self.__class__.__name__, save_variable, is_standalone)
self._is_standalone = is_standalone
self._source: Optional["np.ndarray"] = None
self._display: Optional[ImageTk.PhotoImage] = None
self._source: np.ndarray | None = None
self._display: ImageTk.PhotoImage | None = None
self._scale = 1.0
self._interpolation = cv2.INTER_NEAREST
@ -416,7 +417,7 @@ class _Image():
return self._display
@property
def source(self) -> "np.ndarray":
def source(self) -> np.ndarray:
""" :class:`numpy.ndarray`: The current source preview image """
assert self._source is not None
return self._source
@ -426,7 +427,7 @@ class _Image():
"""int: The current display scale as a percentage of original image size """
return int(self._scale * 100)
def set_source_image(self, name: str, image: "np.ndarray") -> None:
def set_source_image(self, name: str, image: np.ndarray) -> None:
""" Set the source image to :attr:`source`
Parameters
@ -542,7 +543,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
self._taskbar = taskbar
self._image = image
self._drag_data: List[float] = [0., 0.]
self._drag_data: list[float] = [0., 0.]
self._set_mouse_bindings()
self._set_key_bindings(is_standalone)
logger.debug("Initialized %s", self.__class__.__name__,)
@ -604,7 +605,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
The key press event
"""
move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview
visible = (move_axis()[1] - move_axis()[0])
visible = move_axis()[1] - move_axis()[0]
amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25
logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, " # type: ignore
"amount: %s)", move_axis, visible, amount)
@ -671,10 +672,10 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
Default: `None`
"""
def __init__(self,
preview_buffer: "PreviewBuffer",
parent: Optional[tk.Widget] = None,
taskbar: Optional[ttk.Frame] = None,
triggers: Optional["TriggerType"] = None) -> None:
preview_buffer: PreviewBuffer,
parent: tk.Widget | None = None,
taskbar: ttk.Frame | None = None,
triggers: TriggerType | None = None) -> None:
logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent)
super().__init__(preview_buffer, triggers=triggers)
self._is_standalone = parent is None
@ -745,7 +746,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
logger.info(" Save Preview: Ctrl+s")
logger.info("---------------------------------------------------")
def _get_geometry(self) -> Tuple[int, int]:
def _get_geometry(self) -> tuple[int, int]:
""" Obtain the geometry of the current screen (standalone) or the dimensions of the widget
holding the preview window (GUI).
@ -780,7 +781,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
half_screen = tuple(x // 2 for x in self._screen_dimensions)
min_scales = (half_screen[0] / self._image.source.shape[1],
half_screen[1] / self._image.source.shape[0])
min_scale = min(1.0, min(min_scales))
min_scale = min(1.0, *min_scales)
min_scale = (ceil(min_scale * 10)) * 10
eight_screen = tuple(x * 8 for x in self._screen_dimensions)
@ -884,7 +885,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
if self._triggers is None: # Don't need triggers for GUI
return
keypress = "enter" if event.keysym == "Return" else event.keysym
key = cast(TriggerKeysType, keypress)
key = T.cast(TriggerKeysType, keypress)
logger.debug("Processing keypress '%s'", key)
if key == "r":
print("") # Let log print on different line from loss output

View file

@ -1,11 +1,12 @@
#!/usr/bin python3
""" Utilities available across all scripts """
from __future__ import annotations
import json
import logging
import os
import sys
import tkinter as tk
import typing as T
import warnings
import zipfile
@ -14,18 +15,12 @@ from re import finditer
from socket import timeout as socket_timeout, error as socket_error
from threading import get_ident
from time import time
from typing import cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from urllib import request, error as urlliberror
import numpy as np
from tqdm import tqdm
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from http.client import HTTPResponse
# Global variables
@ -34,8 +29,8 @@ _image_extensions = [ # pylint:disable=invalid-name
_video_extensions = [ # pylint:disable=invalid-name
".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
".ts", ".vob"]
_TF_VERS: Optional[Tuple[int, int]] = None
ValidBackends = Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
_TF_VERS: tuple[int, int] | None = None
ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
class _Backend(): # pylint:disable=too-few-public-methods
@ -44,7 +39,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
If file doesn't exist and a variable hasn't been set, create the config file. """
def __init__(self) -> None:
self._backends: Dict[str, ValidBackends] = {"1": "cpu",
self._backends: dict[str, ValidBackends] = {"1": "cpu",
"2": "directml",
"3": "nvidia",
"4": "apple_silicon",
@ -78,9 +73,9 @@ class _Backend(): # pylint:disable=too-few-public-methods
"""
# Check if environment variable is set, if so use that
if "FACESWAP_BACKEND" in os.environ:
fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
assert fs_backend in get_args(ValidBackends), (
f"Faceswap backend must be one of {get_args(ValidBackends)}")
fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
assert fs_backend in T.get_args(ValidBackends), (
f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
return fs_backend
# Intercept for sphinx docs build
@ -163,11 +158,11 @@ def set_backend(backend: str) -> None:
>>> set_backend("nvidia")
"""
global _FS_BACKEND # pylint:disable=global-statement
backend = cast(ValidBackends, backend.lower())
backend = T.cast(ValidBackends, backend.lower())
_FS_BACKEND = backend
def get_tf_version() -> Tuple[int, int]:
def get_tf_version() -> tuple[int, int]:
""" Obtain the major. minor version of currently installed Tensorflow.
Returns
@ -179,7 +174,7 @@ def get_tf_version() -> Tuple[int, int]:
-------
>>> from lib.utils import get_tf_version
>>> get_tf_version()
(2, 9)
(2, 10)
"""
global _TF_VERS # pylint:disable=global-statement
if _TF_VERS is None:
@ -225,7 +220,7 @@ def get_folder(path: str, make_folder: bool = True) -> str:
return path
def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]:
def get_image_paths(directory: str, extension: str | None = None) -> list[str]:
""" Gets the image paths from a given directory.
The function searches for files with the specified extension(s) in the given directory, and
@ -274,7 +269,7 @@ def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str
return dir_contents
def get_dpi() -> Optional[float]:
def get_dpi() -> float | None:
""" Gets the DPI (dots per inch) of the display screen.
Returns
@ -338,7 +333,7 @@ def convert_to_secs(*args: int) -> int:
return retval
def full_path_split(path: str) -> List[str]:
def full_path_split(path: str) -> list[str]:
""" Split a file path into all of its parts.
Parameters
@ -360,7 +355,7 @@ def full_path_split(path: str) -> List[str]:
['relative', 'path', 'to', 'file.txt']]
"""
logger = logging.getLogger(__name__)
allparts: List[str] = []
allparts: list[str] = []
while True:
parts = os.path.split(path)
if parts[0] == path: # sentinel for absolute paths
@ -410,7 +405,7 @@ def set_system_verbosity(log_level: str):
warnings.simplefilter(action='ignore', category=warncat)
def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None:
def deprecation_warning(function: str, additional_info: str | None = None) -> None:
""" Log a deprecation warning message.
This function logs a warning message to indicate that the specified function has been
@ -436,7 +431,7 @@ def deprecation_warning(function: str, additional_info: Optional[str] = None) ->
logger.warning(msg)
def camel_case_split(identifier: str) -> List[str]:
def camel_case_split(identifier: str) -> list[str]:
""" Split a camelCase string into a list of its individual parts
Parameters
@ -541,7 +536,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
>>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
"""
def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None:
def __init__(self, model_filename: str | list[str], git_model_id: int) -> None:
self.logger = logging.getLogger(__name__)
if not isinstance(model_filename, list):
model_filename = [model_filename]
@ -576,7 +571,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
return retval
@property
def model_path(self) -> Union[str, List[str]]:
def model_path(self) -> str | list[str]:
""" str or list[str]: The model path(s) in the cache folder.
Example
@ -587,7 +582,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
'/path/to/s3fd_keras_v2.h5'
"""
paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths
retval: str | list[str] = paths[0] if len(paths) == 1 else paths
self.logger.trace(retval) # type:ignore[attr-defined]
return retval
@ -662,7 +657,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
self._url_download, self._cache_dir)
sys.exit(1)
def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None:
def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None:
""" Write the model zip file to disk.
Parameters
@ -762,8 +757,8 @@ class DebugTimes():
"""
def __init__(self,
show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None:
self._times: Dict[str, List[float]] = {}
self._steps: Dict[str, float] = {}
self._times: dict[str, list[float]] = {}
self._steps: dict[str, float] = {}
self._interval = 1
self._display = {"min": show_min, "mean": show_mean, "max": show_max}

View file

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-06-11 23:20+0100\n"
"POT-Creation-Date: 2023-06-25 13:39+0100\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@ -227,7 +227,8 @@ msgid ""
msgstr ""
#: plugins/train/_config.py:198 plugins/train/_config.py:223
#: plugins/train/_config.py:238 plugins/train/_config.py:265
#: plugins/train/_config.py:238 plugins/train/_config.py:256
#: plugins/train/_config.py:290
msgid "optimizer"
msgstr ""
@ -274,21 +275,40 @@ msgid ""
"epsilon to 0.001 (1e-3)."
msgstr ""
#: plugins/train/_config.py:258
#: plugins/train/_config.py:262
msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the "
"gradient weights and adjusts the normalization value dynamically to fit the "
"data. Can help prevent NaNs and improve model optimization at the expense of "
"VRAM. Ref: AutoClip: Adaptive Gradient Clipping for Source Separation "
"Networks https://arxiv.org/abs/2007.14469"
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the "
"amount of time it takes to save the model). However, it can be useful to "
"save these weights if you want to guarantee that a resumed model carries off "
"exactly from where it left off, rather than spending a few hundred "
"iterations catching up.\n"
"\t never - Don't save optimizer weights.\n"
"\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have "
"the last saved optimizer state in your model file.\n"
"\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."
msgstr ""
#: plugins/train/_config.py:271 plugins/train/_config.py:283
#: plugins/train/_config.py:297 plugins/train/_config.py:314
#: plugins/train/_config.py:283
msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
"and adjusts the normalization value dynamically to fit the data. Can help "
"prevent NaNs and improve model optimization at the expense of VRAM. Ref: "
"AutoClip: Adaptive Gradient Clipping for Source Separation Networks https://"
"arxiv.org/abs/2007.14469"
msgstr ""
#: plugins/train/_config.py:296 plugins/train/_config.py:308
#: plugins/train/_config.py:322 plugins/train/_config.py:339
msgid "network"
msgstr ""
#: plugins/train/_config.py:273
#: plugins/train/_config.py:298
msgid ""
"Use reflection padding rather than zero padding with convolutions. Each "
"convolution must pad the image boundaries to maintain the proper sizing. "
@ -297,21 +317,21 @@ msgid ""
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
msgstr ""
#: plugins/train/_config.py:286
#: plugins/train/_config.py:311
msgid ""
"Enable the Tensorflow GPU 'allow_growth' configuration "
"option. This option prevents Tensorflow from allocating all of the GPU VRAM "
"at launch but can lead to higher VRAM fragmentation and slower performance. "
"Should only be enabled if you are receiving errors regarding 'cuDNN fails to "
"initialize' when commencing training."
"Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
"prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
"lead to higher VRAM fragmentation and slower performance. Should only be "
"enabled if you are receiving errors regarding 'cuDNN fails to initialize' "
"when commencing training."
msgstr ""
#: plugins/train/_config.py:299
#: plugins/train/_config.py:324
msgid ""
"NVIDIA GPUs can run operations in float16 faster than in "
"float32. Mixed precision allows you to use a mix of float16 with float32, to "
"get the performance benefits from float16 and the numeric stability benefits "
"from float32.\n"
"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
"precision allows you to use a mix of float16 with float32, to get the "
"performance benefits from float16 and the numeric stability benefits from "
"float32.\n"
"\n"
"This is untested on DirectML backend, but will run on most Nvidia models. it "
"will only speed up training on more recent GPUs. Those with compute "
@ -322,7 +342,7 @@ msgid ""
"the most benefit."
msgstr ""
#: plugins/train/_config.py:316
#: plugins/train/_config.py:341
msgid ""
"If a 'NaN' is generated in the model, this means that the model has "
"corrupted and the model is likely to start deteriorating from this point on. "
@ -331,11 +351,11 @@ msgid ""
"rescue your model."
msgstr ""
#: plugins/train/_config.py:329
#: plugins/train/_config.py:354
msgid "convert"
msgstr ""
#: plugins/train/_config.py:331
#: plugins/train/_config.py:356
msgid ""
"[GPU Only]. The number of faces to feed through the model at once when "
"running the Convert process.\n"
@ -345,27 +365,27 @@ msgid ""
"size."
msgstr ""
#: plugins/train/_config.py:350
#: plugins/train/_config.py:375
msgid ""
"Loss configuration options\n"
"Loss is the mechanism by which a Neural Network judges how well it thinks "
"that it is recreating a face."
msgstr ""
#: plugins/train/_config.py:357 plugins/train/_config.py:369
#: plugins/train/_config.py:382 plugins/train/_config.py:402
#: plugins/train/_config.py:414 plugins/train/_config.py:434
#: plugins/train/_config.py:446 plugins/train/_config.py:466
#: plugins/train/_config.py:482 plugins/train/_config.py:498
#: plugins/train/_config.py:515
#: plugins/train/_config.py:382 plugins/train/_config.py:394
#: plugins/train/_config.py:407 plugins/train/_config.py:427
#: plugins/train/_config.py:439 plugins/train/_config.py:459
#: plugins/train/_config.py:471 plugins/train/_config.py:491
#: plugins/train/_config.py:507 plugins/train/_config.py:523
#: plugins/train/_config.py:540
msgid "loss"
msgstr ""
#: plugins/train/_config.py:361
#: plugins/train/_config.py:386
msgid "The loss function to use."
msgstr ""
#: plugins/train/_config.py:373
#: plugins/train/_config.py:398
msgid ""
"The second loss function to use. If using a structural based loss (such as "
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
@ -373,7 +393,7 @@ msgid ""
"function with the loss_weight_2 option."
msgstr ""
#: plugins/train/_config.py:388
#: plugins/train/_config.py:413
msgid ""
"The amount of weight to apply to the second loss function.\n"
"\n"
@ -391,13 +411,13 @@ msgid ""
"\t 0 - Disables the second loss function altogether."
msgstr ""
#: plugins/train/_config.py:406
#: plugins/train/_config.py:431
msgid ""
"The third loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option."
msgstr ""
#: plugins/train/_config.py:420
#: plugins/train/_config.py:445
msgid ""
"The amount of weight to apply to the third loss function.\n"
"\n"
@ -415,13 +435,13 @@ msgid ""
"\t 0 - Disables the third loss function altogether."
msgstr ""
#: plugins/train/_config.py:438
#: plugins/train/_config.py:463
msgid ""
"The fourth loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option."
msgstr ""
#: plugins/train/_config.py:452
#: plugins/train/_config.py:477
msgid ""
"The amount of weight to apply to the fourth loss function.\n"
"\n"
@ -439,7 +459,7 @@ msgid ""
"\t 0 - Disables the fourth loss function altogether."
msgstr ""
#: plugins/train/_config.py:471
#: plugins/train/_config.py:496
msgid ""
"The loss function to use when learning a mask.\n"
"\t MAE - Mean absolute error will guide reconstructions of each pixel "
@ -451,7 +471,7 @@ msgid ""
"susceptible to outliers and typically produces slightly blurrier results."
msgstr ""
#: plugins/train/_config.py:488
#: plugins/train/_config.py:513
msgid ""
"The amount of priority to give to the eyes.\n"
"\n"
@ -464,7 +484,7 @@ msgid ""
"NB: Penalized Mask Loss must be enable to use this option."
msgstr ""
#: plugins/train/_config.py:504
#: plugins/train/_config.py:529
msgid ""
"The amount of priority to give to the mouth.\n"
"\n"
@ -477,7 +497,7 @@ msgid ""
"NB: Penalized Mask Loss must be enable to use this option."
msgstr ""
#: plugins/train/_config.py:517
#: plugins/train/_config.py:542
msgid ""
"Image loss function is weighted by mask presence. For areas of the image "
"without the facial mask, reconstruction errors will be ignored while the "
@ -485,12 +505,12 @@ msgid ""
"attention on the core face area."
msgstr ""
#: plugins/train/_config.py:528 plugins/train/_config.py:570
#: plugins/train/_config.py:584 plugins/train/_config.py:593
#: plugins/train/_config.py:553 plugins/train/_config.py:595
#: plugins/train/_config.py:609 plugins/train/_config.py:618
msgid "mask"
msgstr ""
#: plugins/train/_config.py:531
#: plugins/train/_config.py:556
msgid ""
"The mask to be used for training. If you have selected 'Learn Mask' or "
"'Penalized Mask Loss' you must select a value other than 'none'. The "
@ -528,7 +548,7 @@ msgid ""
"performance."
msgstr ""
#: plugins/train/_config.py:572
#: plugins/train/_config.py:597
msgid ""
"Apply gaussian blur to the mask input. This has the effect of smoothing the "
"edges of the mask, which can help with poorly calculated masks and give less "
@ -538,13 +558,13 @@ msgid ""
"number."
msgstr ""
#: plugins/train/_config.py:586
#: plugins/train/_config.py:611
msgid ""
"Sets pixels that are near white to white and near black to black. Set to 0 "
"for off."
msgstr ""
#: plugins/train/_config.py:595
#: plugins/train/_config.py:620
msgid ""
"Dedicate a portion of the model to learning how to duplicate the input mask. "
"Increases VRAM usage in exchange for learning a quick ability to try to "

View file

@ -7,8 +7,8 @@ msgid ""
msgstr ""
"Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-06-11 23:20+0100\n"
"PO-Revision-Date: 2023-06-20 17:06+0100\n"
"POT-Creation-Date: 2023-06-25 13:39+0100\n"
"PO-Revision-Date: 2023-06-25 13:42+0100\n"
"Last-Translator: \n"
"Language-Team: \n"
"Language: ru_RU\n"
@ -354,7 +354,8 @@ msgstr ""
"повлияет только на запуск новой модели."
#: plugins/train/_config.py:198 plugins/train/_config.py:223
#: plugins/train/_config.py:238 plugins/train/_config.py:265
#: plugins/train/_config.py:238 plugins/train/_config.py:256
#: plugins/train/_config.py:290
msgid "optimizer"
msgstr "оптимизатор"
@ -435,7 +436,41 @@ msgstr ""
"Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе "
"значения \"-3\" эпсилон будет равен 0,001 (1e-3)."
#: plugins/train/_config.py:258
#: plugins/train/_config.py:262
msgid ""
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the "
"amount of time it takes to save the model). However, it can be useful to "
"save these weights if you want to guarantee that a resumed model carries off "
"exactly from where it left off, rather than spending a few hundred "
"iterations catching up.\n"
"\t never - Don't save optimizer weights.\n"
"\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have "
"the last saved optimizer state in your model file.\n"
"\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."
msgstr ""
"Когда сохранять веса оптимизатора. Сохранение весов оптимизатора не является "
"необходимым и увеличит размер файла модели в 3 раза (и соответственно время, "
"необходимое для сохранения модели). Однако может быть полезно сохранить эти "
"веса, если вы хотите гарантировать, что возобновленная модель продолжит "
"работу именно с того места, где она остановилась, а не тратит несколько "
"сотен итераций на догонялки.\n"
"\t never - не сохранять веса оптимизатора.\n"
"\t always - сохранять веса оптимизатора при каждой итерации сохранения. "
"Сохранение модели займет больше времени из-за увеличенного размера файла, но "
"в файле модели всегда будет последнее сохраненное состояние оптимизатора.\n"
"\t exit - сохранять веса оптимизатора только при явном завершении модели. "
"Это может быть, когда модель активно останавливается или когда выполняются "
"целевые итерации. Примечание. Если сеанс обучения завершается по другой "
"причине (например, отключение питания, ошибка нехватки памяти, обнаружение "
"NaN), веса оптимизатора НЕ будут сохранены."
#: plugins/train/_config.py:283
msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
"and adjusts the normalization value dynamically to fit the data. Can help "
@ -449,12 +484,12 @@ msgstr ""
"ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping for Source "
"Separation Networks [ТОЛЬКО на английском] https://arxiv.org/abs/2007.14469"
#: plugins/train/_config.py:271 plugins/train/_config.py:283
#: plugins/train/_config.py:297 plugins/train/_config.py:314
#: plugins/train/_config.py:296 plugins/train/_config.py:308
#: plugins/train/_config.py:322 plugins/train/_config.py:339
msgid "network"
msgstr "сеть"
#: plugins/train/_config.py:273
#: plugins/train/_config.py:298
msgid ""
"Use reflection padding rather than zero padding with convolutions. Each "
"convolution must pad the image boundaries to maintain the proper sizing. "
@ -468,7 +503,7 @@ msgstr ""
"изображения.\n"
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
#: plugins/train/_config.py:286
#: plugins/train/_config.py:311
msgid ""
"Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
"prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
@ -483,7 +518,7 @@ msgstr ""
"случае, если у вас появляются ошибки, рода 'cuDNN fails to initialize'(cuDNN "
"не может инициализироваться) при начале тренировки."
#: plugins/train/_config.py:299
#: plugins/train/_config.py:324
msgid ""
"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
"precision allows you to use a mix of float16 with float32, to get the "
@ -512,7 +547,7 @@ msgstr ""
"ускорение. В основном RTX видеокарты и позже предлагают самое большое "
"ускорение."
#: plugins/train/_config.py:316
#: plugins/train/_config.py:341
msgid ""
"If a 'NaN' is generated in the model, this means that the model has "
"corrupted and the model is likely to start deteriorating from this point on. "
@ -526,11 +561,11 @@ msgstr ""
"NaN. Последнее сохранение не будет содержать в себе NaN, так что у вас будет "
"возможность спасти вашу модель."
#: plugins/train/_config.py:329
#: plugins/train/_config.py:354
msgid "convert"
msgstr "конвертирование"
#: plugins/train/_config.py:331
#: plugins/train/_config.py:356
msgid ""
"[GPU Only]. The number of faces to feed through the model at once when "
"running the Convert process.\n"
@ -546,7 +581,7 @@ msgstr ""
"конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда "
"стоит снизить размер пачки."
#: plugins/train/_config.py:350
#: plugins/train/_config.py:375
msgid ""
"Loss configuration options\n"
"Loss is the mechanism by which a Neural Network judges how well it thinks "
@ -556,20 +591,20 @@ msgstr ""
"Потеря - механизм, по которому Нейронная Сеть судит, насколько хорошо она "
"воспроизводит лицо."
#: plugins/train/_config.py:357 plugins/train/_config.py:369
#: plugins/train/_config.py:382 plugins/train/_config.py:402
#: plugins/train/_config.py:414 plugins/train/_config.py:434
#: plugins/train/_config.py:446 plugins/train/_config.py:466
#: plugins/train/_config.py:482 plugins/train/_config.py:498
#: plugins/train/_config.py:515
#: plugins/train/_config.py:382 plugins/train/_config.py:394
#: plugins/train/_config.py:407 plugins/train/_config.py:427
#: plugins/train/_config.py:439 plugins/train/_config.py:459
#: plugins/train/_config.py:471 plugins/train/_config.py:491
#: plugins/train/_config.py:507 plugins/train/_config.py:523
#: plugins/train/_config.py:540
msgid "loss"
msgstr "потери"
#: plugins/train/_config.py:361
#: plugins/train/_config.py:386
msgid "The loss function to use."
msgstr "Какую функцию потерь стоит использовать."
#: plugins/train/_config.py:373
#: plugins/train/_config.py:398
msgid ""
"The second loss function to use. If using a structural based loss (such as "
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
@ -581,7 +616,7 @@ msgstr ""
"регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес "
"этой функции потерь с помощью параметра loss_weight_2."
#: plugins/train/_config.py:388
#: plugins/train/_config.py:413
msgid ""
"The amount of weight to apply to the second loss function.\n"
"\n"
@ -612,7 +647,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает вторую функцию потерь."
#: plugins/train/_config.py:406
#: plugins/train/_config.py:431
msgid ""
"The third loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option."
@ -620,7 +655,7 @@ msgstr ""
"Третья используемая функция потерь. Вы можете настроить вес этой функции "
"потерь с помощью параметра loss_weight_3."
#: plugins/train/_config.py:420
#: plugins/train/_config.py:445
msgid ""
"The amount of weight to apply to the third loss function.\n"
"\n"
@ -651,7 +686,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает третью функцию потерь."
#: plugins/train/_config.py:438
#: plugins/train/_config.py:463
msgid ""
"The fourth loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option."
@ -659,7 +694,7 @@ msgstr ""
"Четвертая используемая функция потерь. Вы можете настроить вес этой функции "
"потерь с помощью параметра 'loss_weight_4'."
#: plugins/train/_config.py:452
#: plugins/train/_config.py:477
msgid ""
"The amount of weight to apply to the fourth loss function.\n"
"\n"
@ -690,7 +725,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает четвертую функцию потерь."
#: plugins/train/_config.py:471
#: plugins/train/_config.py:496
msgid ""
"The loss function to use when learning a mask.\n"
"\t MAE - Mean absolute error will guide reconstructions of each pixel "
@ -711,7 +746,7 @@ msgstr ""
"данных. Как среднее значение, оно чувствительно к выбросам и обычно дает "
"немного более размытые результаты."
#: plugins/train/_config.py:488
#: plugins/train/_config.py:513
msgid ""
"The amount of priority to give to the eyes.\n"
"\n"
@ -731,7 +766,7 @@ msgstr ""
"\n"
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
#: plugins/train/_config.py:504
#: plugins/train/_config.py:529
msgid ""
"The amount of priority to give to the mouth.\n"
"\n"
@ -751,7 +786,7 @@ msgstr ""
"\n"
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
#: plugins/train/_config.py:517
#: plugins/train/_config.py:542
msgid ""
"Image loss function is weighted by mask presence. For areas of the image "
"without the facial mask, reconstruction errors will be ignored while the "
@ -763,12 +798,12 @@ msgstr ""
"время как область лица с маской является приоритетной. Может повысить общее "
"качество за счет концентрации внимания на основной области лица."
#: plugins/train/_config.py:528 plugins/train/_config.py:570
#: plugins/train/_config.py:584 plugins/train/_config.py:593
#: plugins/train/_config.py:553 plugins/train/_config.py:595
#: plugins/train/_config.py:609 plugins/train/_config.py:618
msgid "mask"
msgstr "маска"
#: plugins/train/_config.py:531
#: plugins/train/_config.py:556
msgid ""
"The mask to be used for training. If you have selected 'Learn Mask' or "
"'Penalized Mask Loss' you must select a value other than 'none'. The "
@ -840,7 +875,7 @@ msgstr ""
"сообщества и для дальнейшего описания нуждается в тестировании. Профильные "
"лица могут иметь низкую производительность."
#: plugins/train/_config.py:572
#: plugins/train/_config.py:597
msgid ""
"Apply gaussian blur to the mask input. This has the effect of smoothing the "
"edges of the mask, which can help with poorly calculated masks and give less "
@ -856,7 +891,7 @@ msgstr ""
"должно быть нечетным, если передано четное число, то оно будет округлено до "
"следующего нечетного числа."
#: plugins/train/_config.py:586
#: plugins/train/_config.py:611
msgid ""
"Sets pixels that are near white to white and near black to black. Set to 0 "
"for off."
@ -864,7 +899,7 @@ msgstr ""
"Устанавливает пиксели, которые почти белые - в белые и которые почти черные "
"- в черные. Установите 0, чтобы выключить."
#: plugins/train/_config.py:595
#: plugins/train/_config.py:620
msgid ""
"Dedicate a portion of the model to learning how to duplicate the input mask. "
"Increases VRAM usage in exchange for learning a quick ability to try to "

View file

@ -1,8 +1,7 @@
#!/usr/bin/env python3
""" Plugin to blend the edges of the face between the swap and the original face. """
import logging
import sys
from typing import List, Optional, Tuple
import typing as T
import cv2
import numpy as np
@ -11,12 +10,6 @@ from lib.align import BlurMask, DetectedFace
from lib.config import FaceswapConfig
from plugins.convert._config import Config
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__)
@ -44,8 +37,8 @@ class Mask(): # pylint:disable=too-few-public-methods
mask_type: str,
output_size: int,
coverage_ratio: float,
configfile: Optional[str] = None,
config: Optional[FaceswapConfig] = None) -> None:
configfile: str | None = None,
config: FaceswapConfig | None = None) -> None:
logger.debug("Initializing %s: (mask_type: '%s', output_size: %s, coverage_ratio: %s, "
"configfile: %s, config: %s)", self.__class__.__name__, mask_type,
coverage_ratio, output_size, configfile, config)
@ -61,8 +54,8 @@ class Mask(): # pylint:disable=too-few-public-methods
self._do_erode = any(amount != 0 for amount in self._erodes)
def _set_config(self,
configfile: Optional[str],
config: Optional[FaceswapConfig]) -> dict:
configfile: str | None,
config: FaceswapConfig | None) -> dict:
""" Set the correct configuration for the plugin based on whether a config file
or pre-loaded config has been passed in.
@ -123,8 +116,8 @@ class Mask(): # pylint:disable=too-few-public-methods
detected_face: DetectedFace,
source_offset: np.ndarray,
target_offset: np.ndarray,
centering: Literal["legacy", "face", "head"],
predicted_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
centering: T.Literal["legacy", "face", "head"],
predicted_mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
""" Obtain the requested mask type and perform any defined mask manipulations.
Parameters
@ -171,8 +164,8 @@ class Mask(): # pylint:disable=too-few-public-methods
def _get_mask(self,
detected_face: DetectedFace,
predicted_mask: Optional[np.ndarray],
centering: Literal["legacy", "face", "head"],
predicted_mask: np.ndarray | None,
centering: T.Literal["legacy", "face", "head"],
source_offset: np.ndarray,
target_offset: np.ndarray) -> np.ndarray:
""" Return the requested mask with any requested blurring applied.
@ -229,7 +222,7 @@ class Mask(): # pylint:disable=too-few-public-methods
def _get_stored_mask(self,
detected_face: DetectedFace,
centering: Literal["legacy", "face", "head"],
centering: T.Literal["legacy", "face", "head"],
source_offset: np.ndarray,
target_offset: np.ndarray) -> np.ndarray:
""" get the requested stored mask from the detected face object.
@ -303,7 +296,7 @@ class Mask(): # pylint:disable=too-few-public-methods
return eroded[..., None]
def _get_erosion_kernels(self, mask: np.ndarray) -> List[np.ndarray]:
def _get_erosion_kernels(self, mask: np.ndarray) -> list[np.ndarray]:
""" Get the erosion kernels for each of the center, left, top right and bottom erosions.
An approximation is made based on the number of positive pixels within the mask to create

View file

@ -4,8 +4,7 @@
import logging
import os
import re
from typing import Any, List, Optional
import typing as T
import numpy as np
@ -14,7 +13,7 @@ from plugins.convert._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_config(plugin_name: str, configfile: Optional[str] = None) -> dict:
def get_config(plugin_name: str, configfile: str | None = None) -> dict:
""" Obtain the configuration settings for the writer plugin.
Parameters
@ -44,7 +43,7 @@ class Output():
The full path to a custom configuration ini file. If ``None`` is passed
then the file is loaded from the default location. Default: ``None``.
"""
def __init__(self, output_folder: str, configfile: Optional[str] = None) -> None:
def __init__(self, output_folder: str, configfile: str | None = None) -> None:
logger.debug("Initializing %s: (output_folder: '%s')",
self.__class__.__name__, output_folder)
self.config: dict = get_config(".".join(self.__module__.split(".")[-2:]),
@ -69,7 +68,7 @@ class Output():
retval = hasattr(self, "frame_order")
return retval
def output_filename(self, filename: str, separate_mask: bool = False) -> List[str]:
def output_filename(self, filename: str, separate_mask: bool = False) -> list[str]:
""" Obtain the full path for the output file, including the correct extension, for the
given input filename.
@ -124,7 +123,7 @@ class Output():
logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore
logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore
def write(self, filename: str, image: Any) -> None:
def write(self, filename: str, image: T.Any) -> None:
""" Override for specific frame writing method.
Parameters
@ -137,7 +136,7 @@ class Output():
"""
raise NotImplementedError
def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument
def pre_encode(self, image: np.ndarray) -> T.Any: # pylint: disable=unused-argument
""" Some writer plugins support the pre-encoding of images prior to saving out. As
patching is done in multiple threads, but writing is done in a single thread, it can
speed up the process to do any pre-encoding as part of the converter process.

View file

@ -1,9 +1,11 @@
#!/usr/bin/env python3
""" Video output writer for faceswap.py converter """
from __future__ import annotations
import os
import typing as T
from math import ceil
from subprocess import CalledProcessError, check_output, STDOUT
from typing import cast, Generator, List, Optional, Tuple
import imageio
import imageio_ffmpeg as im_ffm
@ -11,6 +13,9 @@ import numpy as np
from ._base import Output, logger
if T.TYPE_CHECKING:
from collections.abc import Generator
class Writer(Output):
""" Video output writer using imageio-ffmpeg.
@ -32,7 +37,7 @@ class Writer(Output):
def __init__(self,
output_folder: str,
total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]],
frame_ranges: list[tuple[int, int]] | None,
source_video: str,
**kwargs) -> None:
super().__init__(output_folder, **kwargs)
@ -40,11 +45,11 @@ class Writer(Output):
total_count, frame_ranges, source_video)
self._source_video: str = source_video
self._output_filename: str = self._get_output_filename()
self._frame_ranges: Optional[List[Tuple[int, int]]] = frame_ranges
self.frame_order: List[int] = self._set_frame_order(total_count)
self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame
self._frame_ranges: list[tuple[int, int]] | None = frame_ranges
self.frame_order: list[int] = self._set_frame_order(total_count)
self._output_dimensions: str | None = None # Fix dims on 1st received frame
# Need to know dimensions of first frame, so set writer then
self._writer: Optional[Generator[None, np.ndarray, None]] = None
self._writer: Generator[None, np.ndarray, None] | None = None
@property
def _valid_tunes(self) -> dict:
@ -63,7 +68,7 @@ class Writer(Output):
return retval
@property
def _output_params(self) -> List[str]:
def _output_params(self) -> list[str]:
""" list: The FFMPEG Output parameters """
codec = self.config["codec"]
tune = self.config["tune"]
@ -86,11 +91,11 @@ class Writer(Output):
return output_args
@property
def _audio_codec(self) -> Optional[str]:
def _audio_codec(self) -> str | None:
""" str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default)
or ``None`` if skip muxing has been selected in configuration options, or if frame ranges
have been passed in the command line arguments. """
retval: Optional[str] = "copy"
retval: str | None = "copy"
if self.config["skip_mux"]:
logger.info("Skipping audio muxing due to configuration settings.")
retval = None
@ -169,7 +174,7 @@ class Writer(Output):
logger.info("Outputting to: '%s'", retval)
return retval
def _set_frame_order(self, total_count: int) -> List[int]:
def _set_frame_order(self, total_count: int) -> list[int]:
""" Obtain the full list of frames to be converted in order.
Parameters
@ -191,7 +196,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval)
return retval
def _get_writer(self, frame_dims: Tuple[int, int]) -> Generator[None, np.ndarray, None]:
def _get_writer(self, frame_dims: tuple[int, int]) -> Generator[None, np.ndarray, None]:
""" Add the requested encoding options and return the writer.
Parameters
@ -238,13 +243,13 @@ class Writer(Output):
logger.trace("Received frame: (filename: '%s', shape: %s", # type:ignore[attr-defined]
filename, image.shape)
if not self._output_dimensions:
input_dims = cast(Tuple[int, int], image.shape[:2])
input_dims = T.cast(tuple[int, int], image.shape[:2])
self._set_dimensions(input_dims)
self._writer = self._get_writer(input_dims)
self.cache_frame(filename, image)
self._save_from_cache()
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received.
This protects against different sized images coming in and ensures all images are written
to ffmpeg at the same size. Dimensions are mapped to a macro block size 8.

View file

@ -1,14 +1,15 @@
#!/usr/bin/env python3
""" Animated GIF writer for faceswap.py converter """
from __future__ import annotations
import os
from typing import Optional, List, Tuple, TYPE_CHECKING
import typing as T
import cv2
import imageio
from ._base import Output, logger
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from imageio.core import format as im_format # noqa:F401
@ -31,15 +32,16 @@ class Writer(Output):
def __init__(self,
output_folder: str,
total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]],
frame_ranges: list[tuple[int, int]] | None,
**kwargs) -> None:
logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges)
super().__init__(output_folder, **kwargs)
self.frame_order: List[int] = self._set_frame_order(total_count, frame_ranges)
self._output_dimensions: Optional[Tuple[int, int]] = None # Fix dims on 1st received frame
self.frame_order: list[int] = self._set_frame_order(total_count, frame_ranges)
# Fix dims on 1st received frame
self._output_dimensions: tuple[int, int] | None = None
# Need to know dimensions of first frame, so set writer then
self._writer: Optional[imageio.plugins.pillowmulti.GIFFormat.Writer] = None
self._gif_file: Optional[str] = None # Set filename based on first file seen
self._writer: imageio.plugins.pillowmulti.GIFFormat.Writer | None = None
self._gif_file: str | None = None # Set filename based on first file seen
@property
def _gif_params(self) -> dict:
@ -50,7 +52,7 @@ class Writer(Output):
@staticmethod
def _set_frame_order(total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]]) -> List[int]:
frame_ranges: list[tuple[int, int]] | None) -> list[int]:
""" Obtain the full list of frames to be converted in order.
Parameters
@ -75,7 +77,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval)
return retval
def _get_writer(self) -> "im_format.Format.Writer":
def _get_writer(self) -> im_format.Format.Writer:
""" Obtain the GIF writer with the requested GIF encoding options.
Returns
@ -145,7 +147,7 @@ class Writer(Output):
self._gif_file = retval
logger.info("Outputting to: '%s'", self._gif_file)
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received. This
protects against different sized images coming in and ensures all images get written to the
Gif at the same dimensions. """

View file

@ -2,8 +2,6 @@
""" Image output writer for faceswap.py converter
Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO
"""
from typing import List, Tuple
import cv2
import numpy as np
@ -37,7 +35,7 @@ class Writer(Output):
"transparency. Changing output format to 'png'")
self.config["format"] = "png"
def _get_save_args(self) -> Tuple[int, ...]:
def _get_save_args(self) -> tuple[int, ...]:
""" Obtain the save parameters for the file format.
Returns
@ -46,7 +44,7 @@ class Writer(Output):
The OpenCV specific arguments for the selected file format
"""
filetype = self.config["format"]
args: Tuple[int, ...] = tuple()
args: tuple[int, ...] = tuple()
if filetype == "jpg" and self.config["jpg_quality"] > 0:
args = (cv2.IMWRITE_JPEG_QUALITY,
self.config["jpg_quality"])
@ -56,7 +54,7 @@ class Writer(Output):
logger.debug(args)
return args
def write(self, filename: str, image: List[bytes]) -> None:
def write(self, filename: str, image: list[bytes]) -> None:
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out
the encoded mask to a sub-folder in the output directory.
@ -77,7 +75,7 @@ class Writer(Output):
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
def pre_encode(self, image: np.ndarray) -> List[bytes]:
def pre_encode(self, image: np.ndarray) -> list[bytes]:
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker.
Parameters

View file

@ -1,7 +1,5 @@
#!/usr/bin/env python3
""" Image output writer for faceswap.py converter """
from typing import Dict, List, Union
from io import BytesIO
from PIL import Image
@ -25,7 +23,7 @@ class Writer(Output):
super().__init__(output_folder, **kwargs)
self._check_transparency_format()
# Correct format namings for writing to byte stream
self._format_dict = dict(jpg="JPEG", jp2="JPEG 2000", tif="TIFF")
self._format_dict = {"jpg": "JPEG", "jp2": "JPEG 2000", "tif": "TIFF"}
self._separate_mask = self.config["draw_transparent"] and self.config["separate_mask"]
self._kwargs = self._get_save_kwargs()
@ -38,7 +36,7 @@ class Writer(Output):
"transparency. Changing output format to 'png'")
self.config["format"] = "png"
def _get_save_kwargs(self) -> Dict[str, Union[bool, int, str]]:
def _get_save_kwargs(self) -> dict[str, bool | int | str]:
""" Return the save parameters for the file format
Returns
@ -59,7 +57,7 @@ class Writer(Output):
logger.debug(kwargs)
return kwargs
def write(self, filename: str, image: List[BytesIO]) -> None:
def write(self, filename: str, image: list[BytesIO]) -> None:
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out
the encoded mask to a sub-folder in the output directory.
@ -80,7 +78,7 @@ class Writer(Output):
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
def pre_encode(self, image: np.ndarray) -> List[BytesIO]:
def pre_encode(self, image: np.ndarray) -> list[BytesIO]:
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker
Parameters

View file

@ -2,12 +2,11 @@
""" Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and
:mod:`~plugins.extract.mask` Plugins
"""
from __future__ import annotations
import logging
import sys
import typing as T
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Generator, List, Optional,
Sequence, Union, Tuple, TYPE_CHECKING)
import numpy as np
from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa
@ -18,12 +17,8 @@ from lib.utils import GetModel, FaceswapError
from ._config import Config
from .pipeline import ExtractMedia
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Callable, Generator, Sequence
from queue import Queue
import cv2
from lib.align import DetectedFace
@ -37,7 +32,7 @@ logger = logging.getLogger(__name__)
# TODO Run with warnings mode
def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str, Any]:
def _get_config(plugin_name: str, configfile: str | None = None) -> dict[str, T.Any]:
""" Return the configuration for the requested model
Parameters
@ -56,7 +51,7 @@ def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str,
return Config(plugin_name, configfile=configfile).config_dict
BatchType = Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"]
BatchType = T.Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"]
@dataclass
@ -84,13 +79,12 @@ class ExtractorBatch:
data: dict
Any specific data required during the processing phase for a particular plugin
"""
image: List[np.ndarray] = field(default_factory=list)
detected_faces: Sequence[Union["DetectedFace",
List["DetectedFace"]]] = field(default_factory=list)
filename: List[str] = field(default_factory=list)
image: list[np.ndarray] = field(default_factory=list)
detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list)
filename: list[str] = field(default_factory=list)
feed: np.ndarray = np.array([])
prediction: np.ndarray = np.array([])
data: List[Dict[str, Any]] = field(default_factory=list)
data: list[dict[str, T.Any]] = field(default_factory=list)
class Extractor():
@ -157,10 +151,10 @@ class Extractor():
"""
def __init__(self,
git_model_id: Optional[int] = None,
model_filename: Optional[Union[str, List[str]]] = None,
exclude_gpus: Optional[List[int]] = None,
configfile: Optional[str] = None,
git_model_id: int | None = None,
model_filename: str | list[str] | None = None,
exclude_gpus: list[int] | None = None,
configfile: str | None = None,
instance: int = 0) -> None:
logger.debug("Initializing %s: (git_model_id: %s, model_filename: %s, exclude_gpus: %s, "
"configfile: %s, instance: %s, )", self.__class__.__name__, git_model_id,
@ -176,9 +170,9 @@ class Extractor():
be a list of strings """
# << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> #
self.name: Optional[str] = None
self.name: str | None = None
self.input_size = 0
self.color_format: Literal["BGR", "RGB", "GRAY"] = "BGR"
self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR"
self.vram = 0
self.vram_warnings = 0 # Will run at this with warnings
self.vram_per_batch = 0
@ -187,7 +181,7 @@ class Extractor():
self.queue_size = 1
""" int: Queue size for all internal queues. Set in :func:`initialize()` """
self.model: Optional[Union["KSession", "cv2.dnn.Net"]] = None
self.model: KSession | cv2.dnn.Net | None = None
"""varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """
# For detectors that support batching, this should be set to the calculated batch size
@ -196,26 +190,26 @@ class Extractor():
""" int: Batchsize for feeding this model. The number of images the model should
feed through at once. """
self._queues: Dict[str, "Queue"] = {}
self._queues: dict[str, Queue] = {}
""" dict: in + out queues and internal queues for this plugin, """
self._threads: List[MultiThread] = []
self._threads: list[MultiThread] = []
""" list: Internal threads for this plugin """
self._extract_media: Dict[str, ExtractMedia] = {}
self._extract_media: dict[str, ExtractMedia] = {}
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
processed. Stored at input for pairing back up on output of extractor process """
# << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >>> #
self._plugin_type: Optional[Literal["align", "detect", "recognition", "mask"]] = None
self._plugin_type: T.Literal["align", "detect", "recognition", "mask"] | None = None
""" str: Plugin type. ``detect``, ``align``, ``recognition`` or ``mask`` set in
``<plugin_type>._base`` """
# << Objects for splitting frame's detected faces and rejoining them >>
# << for post-detector plugins >>
self._faces_per_filename: Dict[str, int] = {} # Tracking for recompiling batches
self._rollover: Optional[ExtractMedia] = None # batch rollover items
self._output_faces: List["DetectedFace"] = [] # Recompiled output faces from plugin
self._faces_per_filename: dict[str, int] = {} # Tracking for recompiling batches
self._rollover: ExtractMedia | None = None # batch rollover items
self._output_faces: list[DetectedFace] = [] # Recompiled output faces from plugin
logger.debug("Initialized _base %s", self.__class__.__name__)
@ -361,7 +355,7 @@ class Extractor():
"""
raise NotImplementedError
def get_batch(self, queue: "Queue") -> Tuple[bool, BatchType]:
def get_batch(self, queue: Queue) -> tuple[bool, BatchType]:
""" **Override method** (at `<plugin_type>` level)
This method should be overridden at the `<plugin_type>` level (IE.
@ -403,7 +397,7 @@ class Extractor():
for thread in self._threads:
thread.check_and_raise_error()
def rollover_collector(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia]:
def rollover_collector(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia:
""" For extractors after the Detectors, the number of detected faces per frame vs extractor
batch size mean that faces will need to be split/re-joined with frames. The rollover
collector can be used to rollover items that don't fit in a batch.
@ -425,7 +419,7 @@ class Extractor():
if self._rollover is not None:
logger.trace("Getting from _rollover: (filename: `%s`, faces: %s)", # type:ignore
self._rollover.filename, len(self._rollover.detected_faces))
item: Union[Literal["EOF"], ExtractMedia] = self._rollover
item: T.Literal["EOF"] | ExtractMedia = self._rollover
self._rollover = None
else:
next_item = self._get_item(queue)
@ -442,9 +436,8 @@ class Extractor():
# <<< INIT METHODS >>> #
@classmethod
def _get_model(cls,
git_model_id: Optional[int],
model_filename: Optional[Union[str, List[str]]]
) -> Optional[Union[str, List[str]]]:
git_model_id: int | None,
model_filename: str | list[str] | None) -> str | list[str] | None:
""" Check if model is available, if not, download and unzip it """
if model_filename is None:
logger.debug("No model_filename specified. Returning None")
@ -496,9 +489,9 @@ class Extractor():
self.name, self._plugin_type.title(), self.batchsize)
def _add_queues(self,
in_queue: "Queue",
out_queue: "Queue",
queues: List[str]) -> None:
in_queue: Queue,
out_queue: Queue,
queues: list[str]) -> None:
""" Add the queues
in_queue and out_queue should be previously created queue manager queues.
queues should be a list of queue names """
@ -533,8 +526,8 @@ class Extractor():
def _add_thread(self,
name: str,
function: Callable[[BatchType], BatchType],
in_queue: "Queue",
out_queue: "Queue") -> None:
in_queue: Queue,
out_queue: Queue) -> None:
""" Add a MultiThread thread to self._threads """
logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)",
name, function, in_queue, out_queue)
@ -546,8 +539,8 @@ class Extractor():
logger.debug("Added thread: %s", name)
def _obtain_batch_item(self, function: Callable[[BatchType], BatchType],
in_queue: "Queue",
out_queue: "Queue") -> Optional[BatchType]:
in_queue: Queue,
out_queue: Queue) -> BatchType | None:
""" Obtain the batch item from the in queue for the current process.
Parameters
@ -564,7 +557,7 @@ class Extractor():
:class:`ExtractorBatch` or ``None``
The batch, if one exists, or ``None`` if queue is exhausted
"""
batch: Union[Literal["EOF"], BatchType, ExtractMedia]
batch: T.Literal["EOF"] | BatchType | ExtractMedia
if function.__name__ == "_process_input": # Process input items to batches
exhausted, batch = self.get_batch(in_queue)
if exhausted:
@ -585,8 +578,8 @@ class Extractor():
def _thread_process(self,
function: Callable[[BatchType], BatchType],
in_queue: "Queue",
out_queue: "Queue") -> None:
in_queue: Queue,
out_queue: Queue) -> None:
""" Perform a plugin function in a thread
Parameters
@ -629,7 +622,7 @@ class Extractor():
out_queue.put("EOF")
# <<< QUEUE METHODS >>> #
def _get_item(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia, BatchType]:
def _get_item(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia | BatchType:
""" Yield one item from a queue """
item = queue.get()
if isinstance(item, ExtractMedia):

View file

@ -12,12 +12,12 @@ For each source item, the plugin must pass a dict to finalize containing:
>>> "landmarks": [list of 68 point face landmarks]
>>> "detected_faces": [<list of DetectedFace objects>]}
"""
from __future__ import annotations
import logging
import sys
import typing as T
from dataclasses import dataclass, field
from time import sleep
from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING
import cv2
import numpy as np
@ -28,12 +28,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractMedia, ExtractorBatch
from .processing import AlignedFilter, ReAlign
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType
@ -77,9 +73,9 @@ class AlignerBatch(ExtractorBatch):
The masks used to filter out re-feed values for passing to the re-aligner.
"""
batch_id: int = 0
detected_faces: List["DetectedFace"] = field(default_factory=list)
detected_faces: list[DetectedFace] = field(default_factory=list)
landmarks: np.ndarray = np.array([])
refeeds: List[np.ndarray] = field(default_factory=list)
refeeds: list[np.ndarray] = field(default_factory=list)
second_pass: bool = False
second_pass_masks: np.ndarray = np.array([])
@ -142,11 +138,11 @@ class Aligner(Extractor): # pylint:disable=abstract-method
"""
def __init__(self,
git_model_id: Optional[int] = None,
model_filename: Optional[str] = None,
configfile: Optional[str] = None,
git_model_id: int | None = None,
model_filename: str | None = None,
configfile: str | None = None,
instance: int = 0,
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None,
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
re_feed: int = 0,
re_align: bool = False,
disable_filter: bool = False,
@ -160,9 +156,9 @@ class Aligner(Extractor): # pylint:disable=abstract-method
instance=instance,
**kwargs)
self._plugin_type = "align"
self.realign_centering: "CenteringType" = "face" # override for plugin specific centering
self.realign_centering: CenteringType = "face" # override for plugin specific centering
self._eof_seen = False
self._normalize_method: Optional[Literal["clahe", "hist", "mean"]] = None
self._normalize_method: T.Literal["clahe", "hist", "mean"] | None = None
self._re_feed = re_feed
self._filter = AlignedFilter(feature_filter=self.config["aligner_features"],
min_scale=self.config["aligner_min_scale"],
@ -181,8 +177,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
logger.debug("Initialized %s", self.__class__.__name__)
def set_normalize_method(self,
method: Optional[Literal["none", "clahe", "hist", "mean"]]) -> None:
def set_normalize_method(self, method: T.Literal["none", "clahe", "hist", "mean"] | None
) -> None:
""" Set the normalization method for feeding faces into the aligner.
Parameters
@ -191,14 +187,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method
The normalization method to apply to faces prior to feeding into the model
"""
method = None if method is None or method.lower() == "none" else method
self._normalize_method = cast(Optional[Literal["clahe", "hist", "mean"]], method)
self._normalize_method = T.cast(T.Literal["clahe", "hist", "mean"] | None, method)
def initialize(self, *args, **kwargs) -> None:
""" Add a call to add model input size to the re-aligner """
self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering)
super().initialize(*args, **kwargs)
def _handle_realigns(self, queue: "Queue") -> Optional[Tuple[bool, AlignerBatch]]:
def _handle_realigns(self, queue: Queue) -> tuple[bool, AlignerBatch] | None:
""" Handle any items waiting for a second pass through the aligner.
If EOF has been received and items are still being processed through the first pass
@ -242,7 +238,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return None
def get_batch(self, queue: "Queue") -> Tuple[bool, AlignerBatch]:
def get_batch(self, queue: Queue) -> tuple[bool, AlignerBatch]:
""" Get items for inputting into the aligner from the queue in batches
Items are returned from the ``queue`` in batches of
@ -548,7 +544,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
"\n3) Enable 'Single Process' mode.")
raise FaceswapError(msg) from err
def _process_refeeds(self, batch: AlignerBatch) -> List[AlignerBatch]:
def _process_refeeds(self, batch: AlignerBatch) -> list[AlignerBatch]:
""" Process the output for each selected re-feed
Parameters
@ -562,7 +558,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
List of :class:`AlignerBatch` objects. Each object in the list contains the
results for each selected re-feed
"""
retval: List[AlignerBatch] = []
retval: list[AlignerBatch] = []
if batch.second_pass:
# Re-insert empty sub-patches for re-population in ReAlign for filtered out batches
selected_idx = 0
@ -605,8 +601,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return retval
def _get_refeed_filter_masks(self,
subbatches: List[AlignerBatch],
original_masks: Optional[np.ndarray] = None) -> np.ndarray:
subbatches: list[AlignerBatch],
original_masks: np.ndarray | None = None) -> np.ndarray:
""" Obtain the boolean mask array for masking out failed re-feed results if filter refeed
has been selected
@ -663,7 +659,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
landmarks.shape)
return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32")
def _process_output_first_pass(self, subbatches: List[AlignerBatch]) -> Tuple[np.ndarray,
def _process_output_first_pass(self, subbatches: list[AlignerBatch]) -> tuple[np.ndarray,
np.ndarray]:
""" Process the output from the aligner if this is the first or only pass.
@ -696,7 +692,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return all_landmarks, masks
def _process_output_second_pass(self,
subbatches: List[AlignerBatch],
subbatches: list[AlignerBatch],
masks: np.ndarray) -> np.ndarray:
""" Process the output from the aligner if this is the second pass.

View file

@ -1,21 +1,16 @@
#!/usr/bin/env python3
""" Processing methods for aligner plugins """
from __future__ import annotations
import logging
import sys
import typing as T
from threading import Lock
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np
from lib.align import AlignedFace
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align import DetectedFace
from .aligner import AlignerBatch
from lib.align.aligned_face import CenteringType
@ -72,16 +67,16 @@ class AlignedFilter():
min_scale > 0.0 or
distance > 0.0 or
roll > 0.0)
self._counts: Dict[str, int] = dict(features=0,
min_scale=0,
max_scale=0,
distance=0,
roll=0)
self._counts: dict[str, int] = {"features": 0,
"min_scale": 0,
"max_scale": 0,
"distance": 0,
"roll": 0}
logger.debug("Initialized %s: ", self.__class__.__name__)
def _scale_test(self,
face: AlignedFace,
minimum_dimension: int) -> Optional[Literal["min", "max"]]:
minimum_dimension: int) -> T.Literal["min", "max"] | None:
""" Test if a face is below or above the min/max size thresholds. Returns as soon as a test
fails.
@ -116,9 +111,9 @@ class AlignedFilter():
def _handle_filtered(self,
key: str,
face: "DetectedFace",
faces: List["DetectedFace"],
sub_folders: List[Optional[str]],
face: DetectedFace,
faces: list[DetectedFace],
sub_folders: list[str | None],
sub_folder_index: int) -> None:
""" Add the filtered item to the filter counts.
@ -145,8 +140,8 @@ class AlignedFilter():
faces.append(face)
sub_folders[sub_folder_index] = f"_align_filt_{key}"
def __call__(self, faces: List["DetectedFace"], minimum_dimension: int
) -> Tuple[List["DetectedFace"], List[Optional[str]]]:
def __call__(self, faces: list[DetectedFace], minimum_dimension: int
) -> tuple[list[DetectedFace], list[str | None]]:
""" Apply the filter to the incoming batch
Parameters
@ -165,11 +160,11 @@ class AlignedFilter():
List of ``Nones`` if saving filtered faces has not been selected or list of ``Nones``
and sub folder names corresponding to the filtered face location
"""
sub_folders: List[Optional[str]] = [None for _ in range(len(faces))]
sub_folders: list[str | None] = [None for _ in range(len(faces))]
if not self._active:
return faces, sub_folders
retval: List["DetectedFace"] = []
retval: list[DetectedFace] = []
for idx, face in enumerate(faces):
aligned = AlignedFace(landmarks=face.landmarks_xy, centering="face")
@ -194,8 +189,8 @@ class AlignedFilter():
return retval, sub_folders
def filtered_mask(self,
batch: "AlignerBatch",
skip: Optional[Union[np.ndarray, List[int]]] = None) -> np.ndarray:
batch: AlignerBatch,
skip: np.ndarray | list[int] | None = None) -> np.ndarray:
""" Obtain a list of boolean values for the given batch indicating whether they pass the
filter test.
@ -262,13 +257,14 @@ class ReAlign():
self._active = active
self._do_refeeds = do_refeeds
self._do_filter = do_filter
self._centering: "CenteringType" = "face"
self._centering: CenteringType = "face"
self._size = 0
self._tracked_lock = Lock()
self._tracked_batchs: Dict[int, Dict[Literal["filtered_landmarks"], List[np.ndarray]]] = {}
self._tracked_batchs: dict[int,
dict[T.Literal["filtered_landmarks"], list[np.ndarray]]] = {}
# TODO. Probably does not need to be a list, just alignerbatch
self._queue_lock = Lock()
self._queued: List["AlignerBatch"] = []
self._queued: list[AlignerBatch] = []
logger.debug("Initialized %s", self.__class__.__name__)
@property
@ -301,7 +297,7 @@ class ReAlign():
with self._tracked_lock:
return bool(self._tracked_batchs)
def set_input_size_and_centering(self, input_size: int, centering: "CenteringType") -> None:
def set_input_size_and_centering(self, input_size: int, centering: CenteringType) -> None:
""" Set the input size of the loaded plugin once the model has been loaded
Parameters
@ -344,7 +340,7 @@ class ReAlign():
with self._tracked_lock:
del self._tracked_batchs[batch_id]
def add_batch(self, batch: "AlignerBatch") -> None:
def add_batch(self, batch: AlignerBatch) -> None:
""" Add first pass alignments to the queue for picking up for re-alignment, update their
:attr:`second_pass` attribute to ``True`` and clear attributes not required.
@ -362,7 +358,7 @@ class ReAlign():
batch.data = []
self._queued.append(batch)
def get_batch(self) -> "AlignerBatch":
def get_batch(self) -> AlignerBatch:
""" Retrieve the next batch currently queued for re-alignment
Returns
@ -376,7 +372,7 @@ class ReAlign():
retval.filename)
return retval
def process_batch(self, batch: "AlignerBatch") -> List[np.ndarray]:
def process_batch(self, batch: AlignerBatch) -> list[np.ndarray]:
""" Pre process a batch object for re-aligning through the aligner.
Parameters
@ -391,8 +387,8 @@ class ReAlign():
"""
logger.trace("Processing batch: %s, landmarks: %s", # type: ignore[attr-defined]
batch.filename, [b.shape for b in batch.landmarks])
retval: List[np.ndarray] = []
filtered_landmarks: List[np.ndarray] = []
retval: list[np.ndarray] = []
filtered_landmarks: list[np.ndarray] = []
for landmarks, masks in zip(batch.landmarks, batch.second_pass_masks):
if not np.all(masks): # At least one face has not already been filtered
aligned_faces = [AlignedFace(lms,
@ -415,7 +411,7 @@ class ReAlign():
batch.landmarks = np.array([]) # Clear the old landmarks
return retval
def _transform_to_frame(self, batch: "AlignerBatch") -> np.ndarray:
def _transform_to_frame(self, batch: AlignerBatch) -> np.ndarray:
""" Transform the predicted landmarks from the aligned face image back into frame
co-ordinates
@ -430,14 +426,14 @@ class ReAlign():
:class:`numpy.ndarray`
The landmarks transformed to frame space
"""
faces: List[AlignedFace] = batch.data[0]["aligned_faces"]
faces: list[AlignedFace] = batch.data[0]["aligned_faces"]
retval = np.array([aligned.transform_points(landmarks, invert=True)
for landmarks, aligned in zip(batch.landmarks, faces)])
logger.trace("Transformed points: original max: %s, " # type: ignore[attr-defined]
"new max: %s", batch.landmarks.max(), retval.max())
return retval
def _re_insert_filtered(self, batch: "AlignerBatch", masks: np.ndarray) -> np.ndarray:
def _re_insert_filtered(self, batch: AlignerBatch, masks: np.ndarray) -> np.ndarray:
""" Re-insert landmarks that were filtered out from the re-align process back into the
landmark results
@ -473,7 +469,7 @@ class ReAlign():
return retval
def process_output(self, subbatches: List["AlignerBatch"], batch_masks: np.ndarray) -> None:
def process_output(self, subbatches: list[AlignerBatch], batch_masks: np.ndarray) -> None:
""" Process the output from the re-align pass.
- Transform landmarks from aligned face space to face space

View file

@ -23,15 +23,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations
import logging
from typing import cast, List, Tuple, TYPE_CHECKING
import typing as T
import cv2
import numpy as np
from ._base import Aligner, AlignerBatch, BatchType
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.detected_face import DetectedFace
logger = logging.getLogger(__name__)
@ -89,9 +90,9 @@ class Align(Aligner):
assert isinstance(batch, AlignerBatch)
lfaces, roi, offsets = self.align_image(batch)
batch.feed = np.array(lfaces)[..., :3]
batch.data.append(dict(roi=roi, offsets=offsets))
batch.data.append({"roi": roi, "offsets": offsets})
def _get_box_and_offset(self, face: "DetectedFace") -> Tuple[List[int], int]:
def _get_box_and_offset(self, face: DetectedFace) -> tuple[list[int], int]:
"""Obtain the bounding box and offset from a detected face.
@ -108,17 +109,17 @@ class Align(Aligner):
The offset of the box (difference between half width vs height)
"""
box = cast(List[int], [face.left,
box = T.cast(list[int], [face.left,
face.top,
face.right,
face.bottom])
diff_height_width = cast(int, face.height) - cast(int, face.width)
diff_height_width = T.cast(int, face.height) - T.cast(int, face.width)
offset = int(abs(diff_height_width / 2))
return box, offset
def align_image(self, batch: AlignerBatch) -> Tuple[List[np.ndarray],
List[List[int]],
List[Tuple[int, int]]]:
def align_image(self, batch: AlignerBatch) -> tuple[list[np.ndarray],
list[list[int]],
list[tuple[int, int]]]:
""" Align the incoming image for prediction
Parameters
@ -159,8 +160,8 @@ class Align(Aligner):
@classmethod
def move_box(cls,
box: List[int],
offset: Tuple[int, int]) -> List[int]:
box: list[int],
offset: tuple[int, int]) -> list[int]:
"""Move the box to direction specified by vector offset
Parameters
@ -182,7 +183,7 @@ class Align(Aligner):
return [left, top, right, bottom]
@staticmethod
def get_square_box(box: List[int]) -> List[int]:
def get_square_box(box: list[int]) -> list[int]:
"""Get a square box out of the given box, by expanding it.
Parameters
@ -226,7 +227,7 @@ class Align(Aligner):
return [left, top, right, bottom]
@classmethod
def pad_image(cls, box: List[int], image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
def pad_image(cls, box: list[int], image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
"""Pad image if face-box falls outside of boundaries
Parameters

View file

@ -3,8 +3,9 @@
Code adapted and modified from:
https://github.com/1adrianb/face-alignment
"""
from __future__ import annotations
import logging
from typing import cast, List, TYPE_CHECKING
import typing as T
import cv2
import numpy as np
@ -12,7 +13,7 @@ import numpy as np
from lib.model.session import KSession
from ._base import Aligner, AlignerBatch, BatchType
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align import DetectedFace
logger = logging.getLogger(__name__)
@ -76,10 +77,10 @@ class Align(Aligner):
logger.trace("Aligning faces around center") # type:ignore[attr-defined]
center_scale = self.get_center_scale(batch.detected_faces)
batch.feed = np.array(self.crop(batch, center_scale))[..., :3]
batch.data.append(dict(center_scale=center_scale))
batch.data.append({"center_scale": center_scale})
logger.trace("Aligned image around center") # type:ignore[attr-defined]
def get_center_scale(self, detected_faces: List["DetectedFace"]) -> np.ndarray:
def get_center_scale(self, detected_faces: list[DetectedFace]) -> np.ndarray:
""" Get the center and set scale of bounding box
Parameters
@ -95,11 +96,11 @@ class Align(Aligner):
logger.trace("Calculating center and scale") # type:ignore[attr-defined]
center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32')
for index, face in enumerate(detected_faces):
x_center = (cast(int, face.left) + face.right) / 2.0
y_center = (cast(int, face.top) + face.bottom) / 2.0 - cast(int, face.height) * 0.12
scale = (cast(int, face.width) + cast(int, face.height)) * self.reference_scale
center_scale[index, :, 0] = np.full(68, x_center, dtype='float32')
center_scale[index, :, 1] = np.full(68, y_center, dtype='float32')
x_ctr = (T.cast(int, face.left) + face.right) / 2.0
y_ctr = (T.cast(int, face.top) + face.bottom) / 2.0 - T.cast(int, face.height) * 0.12
scale = (T.cast(int, face.width) + T.cast(int, face.height)) * self.reference_scale
center_scale[index, :, 0] = np.full(68, x_ctr, dtype='float32')
center_scale[index, :, 1] = np.full(68, y_ctr, dtype='float32')
center_scale[index, :, 2] = np.full(68, scale, dtype='float32')
logger.trace("Calculated center and scale: %s", center_scale) # type:ignore[attr-defined]
return center_scale
@ -144,7 +145,7 @@ class Align(Aligner):
dsize=(self.input_size, self.input_size),
interpolation=interp)
def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> List[np.ndarray]:
def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> list[np.ndarray]:
""" Crop image around the center point
Parameters

View file

@ -15,9 +15,11 @@ To get a :class:`~lib.align.DetectedFace` object use the function:
>>> face = self._to_detected_face(<face left>, <face top>, <face right>, <face bottom>)
"""
from __future__ import annotations
import logging
import typing as T
from dataclasses import dataclass, field
from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2
import numpy as np
@ -30,7 +32,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch
from plugins.extract.pipeline import ExtractMedia
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
logger = logging.getLogger(__name__)
@ -53,10 +56,10 @@ class DetectorBatch(ExtractorBatch):
initial_feed: :class:`numpy.ndarray`
Used to hold the initial :attr:`feed` when rotate images is enabled
"""
detected_faces: List[List["DetectedFace"]] = field(default_factory=list)
rotation_matrix: List[np.ndarray] = field(default_factory=list)
scale: List[float] = field(default_factory=list)
pad: List[Tuple[int, int]] = field(default_factory=list)
detected_faces: list[list["DetectedFace"]] = field(default_factory=list)
rotation_matrix: list[np.ndarray] = field(default_factory=list)
scale: list[float] = field(default_factory=list)
pad: list[tuple[int, int]] = field(default_factory=list)
initial_feed: np.ndarray = np.array([])
@ -95,11 +98,11 @@ class Detector(Extractor): # pylint:disable=abstract-method
"""
def __init__(self,
git_model_id: Optional[int] = None,
model_filename: Optional[Union[str, List[str]]] = None,
configfile: Optional[str] = None,
git_model_id: int | None = None,
model_filename: str | list[str] | None = None,
configfile: str | None = None,
instance: int = 0,
rotation: Optional[str] = None,
rotation: str | None = None,
min_size: int = 0,
**kwargs) -> None:
logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__,
@ -117,7 +120,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
logger.debug("Initialized _base %s", self.__class__.__name__)
# <<< QUEUE METHODS >>> #
def get_batch(self, queue: "Queue") -> Tuple[bool, DetectorBatch]:
def get_batch(self, queue: Queue) -> tuple[bool, DetectorBatch]:
""" Get items for inputting to the detector plugin in batches
Items are received as :class:`~plugins.extract.pipeline.ExtractMedia` objects and converted
@ -271,7 +274,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
""" Wrap models predict function in rotations """
assert isinstance(batch, DetectorBatch)
batch.rotation_matrix = [np.array([]) for _ in range(len(batch.feed))]
found_faces: List[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
found_faces: list[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
for angle in self.rotation:
# Rotate the batch and insert placeholders for already found faces
self._rotate_batch(batch, angle)
@ -301,7 +304,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
"degrees",
angle)
found_faces = cast(List[np.ndarray], ([face if not found.any() else found
found_faces = T.cast(list[np.ndarray], ([face if not found.any() else found
for face, found in zip(batch.prediction,
found_faces)]))
@ -317,7 +320,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
# <<< DETECTION IMAGE COMPILATION METHODS >>> #
def _compile_detection_image(self, item: ExtractMedia
) -> Tuple[np.ndarray, float, Tuple[int, int]]:
) -> tuple[np.ndarray, float, tuple[int, int]]:
""" Compile the detection image for feeding into the model
Parameters
@ -345,7 +348,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
image.shape, scale, pad)
return image, scale, pad
def _set_scale(self, image_size: Tuple[int, int]) -> float:
def _set_scale(self, image_size: tuple[int, int]) -> float:
""" Set the scale factor for incoming image
Parameters
@ -362,7 +365,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
logger.trace("Detector scale: %s", scale) # type:ignore[attr-defined]
return scale
def _set_padding(self, image_size: Tuple[int, int], scale: float) -> Tuple[int, int]:
def _set_padding(self, image_size: tuple[int, int], scale: float) -> tuple[int, int]:
""" Set the image padding for non-square images
Parameters
@ -382,7 +385,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
return pad_left, pad_top
@staticmethod
def _scale_image(image: np.ndarray, image_size: Tuple[int, int], scale: float) -> np.ndarray:
def _scale_image(image: np.ndarray, image_size: tuple[int, int], scale: float) -> np.ndarray:
""" Scale the image and optionally pad to given size
Parameters
@ -439,8 +442,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
return image
# <<< FINALIZE METHODS >>> #
def _remove_zero_sized_faces(self, batch_faces: List[List[DetectedFace]]
) -> List[List[DetectedFace]]:
def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]]
) -> list[list[DetectedFace]]:
""" Remove items from batch_faces where detected face is of zero size or face falls
entirely outside of image
@ -463,8 +466,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
logger.trace("Output sizes: %s", [len(face) for face in retval]) # type: ignore
return retval
def _filter_small_faces(self, detected_faces: List[List[DetectedFace]]
) -> List[List[DetectedFace]]:
def _filter_small_faces(self, detected_faces: list[list[DetectedFace]]
) -> list[list[DetectedFace]]:
""" Filter out any faces smaller than the min size threshold
Parameters
@ -493,7 +496,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
# <<< IMAGE ROTATION METHODS >>> #
@staticmethod
def _get_rotation_angles(rotation: Optional[str]) -> List[int]:
def _get_rotation_angles(rotation: str | None) -> list[int]:
""" Set the rotation angles.
Parameters
@ -544,8 +547,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
batch.initial_feed = batch.feed.copy()
return
feeds: List[np.ndarray] = []
rotmats: List[np.ndarray] = []
feeds: list[np.ndarray] = []
rotmats: list[np.ndarray] = []
for img, faces, rotmat in zip(batch.initial_feed,
batch.prediction,
batch.rotation_matrix):
@ -605,7 +608,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
def _rotate_image_by_angle(self,
image: np.ndarray,
angle: int) -> Tuple[np.ndarray, np.ndarray]:
angle: int) -> tuple[np.ndarray, np.ndarray]:
""" Rotate an image by a given angle.
Parameters

View file

@ -34,7 +34,7 @@ class Detect(Detector):
self.kwargs = self._validate_kwargs()
self.color_format = "RGB"
def _validate_kwargs(self) -> T.Dict[str, T.Union[int, float, T.List[float]]]:
def _validate_kwargs(self) -> dict[str, int | float | list[float]]:
""" Validate that config options are correct. If not reset to default """
valid = True
threshold = [self.config["threshold_1"],
@ -164,7 +164,7 @@ class PNet(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
cpu_mode: bool,
input_size: int,
min_size: int,
@ -185,10 +185,10 @@ class PNet(KSession):
self._pnet_scales = self._calculate_scales(min_size, factor)
self._pnet_sizes = [(int(input_size * scale), int(input_size * scale))
for scale in self._pnet_scales]
self._pnet_input: T.Optional[T.List[np.ndarray]] = None
self._pnet_input: list[np.ndarray] | None = None
@staticmethod
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
""" Keras P-Network Definition for MTCNN """
input_ = Input(shape=(None, None, 3))
var_x = Conv2D(10, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@ -204,7 +204,7 @@ class PNet(KSession):
def _calculate_scales(self,
minsize: int,
factor: float) -> T.List[float]:
factor: float) -> list[float]:
""" Calculate multi-scale
Parameters
@ -231,7 +231,7 @@ class PNet(KSession):
logger.trace(scales) # type:ignore
return scales
def __call__(self, images: np.ndarray) -> T.List[np.ndarray]:
def __call__(self, images: np.ndarray) -> list[np.ndarray]:
""" first stage - fast proposal network (p-net) to obtain face candidates
Parameters
@ -245,8 +245,8 @@ class PNet(KSession):
List of face candidates from P-Net
"""
batch_size = images.shape[0]
rectangles: T.List[T.List[T.List[T.Union[int, float]]]] = [[] for _ in range(batch_size)]
scores: T.List[T.List[np.ndarray]] = [[] for _ in range(batch_size)]
rectangles: list[list[list[int | float]]] = [[] for _ in range(batch_size)]
scores: list[list[np.ndarray]] = [[] for _ in range(batch_size)]
if self._pnet_input is None:
self._pnet_input = [np.empty((batch_size, rheight, rwidth, 3), dtype="float32")
@ -278,7 +278,7 @@ class PNet(KSession):
class_probabilities: np.ndarray,
roi: np.ndarray,
size: int,
scale: float) -> T.Tuple[np.ndarray, np.ndarray]:
scale: float) -> tuple[np.ndarray, np.ndarray]:
""" Detect face position and calibrate bounding box on 12net feature map (matrix version)
Parameters
@ -344,7 +344,7 @@ class RNet(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
cpu_mode: bool,
input_size: int,
threshold: float) -> None:
@ -360,7 +360,7 @@ class RNet(KSession):
self._threshold = threshold
@staticmethod
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
""" Keras R-Network Definition for MTCNN """
input_ = Input(shape=(24, 24, 3))
var_x = Conv2D(28, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@ -383,8 +383,8 @@ class RNet(KSession):
def __call__(self,
images: np.ndarray,
rectangle_batch: T.List[np.ndarray],
) -> T.List[np.ndarray]:
rectangle_batch: list[np.ndarray],
) -> list[np.ndarray]:
""" second stage - refinement of face candidates with r-net
Parameters
@ -399,7 +399,7 @@ class RNet(KSession):
List
List of :class:`numpy.ndarray` refined face candidates from R-Net
"""
ret: T.List[np.ndarray] = []
ret: list[np.ndarray] = []
for idx, (rectangles, image) in enumerate(zip(rectangle_batch, images)):
if not np.any(rectangles):
ret.append(np.array([]))
@ -474,7 +474,7 @@ class ONet(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
cpu_mode: bool,
input_size: int,
threshold: float) -> None:
@ -490,7 +490,7 @@ class ONet(KSession):
self._threshold = threshold
@staticmethod
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
""" Keras O-Network for MTCNN """
input_ = Input(shape=(48, 48, 3))
var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@ -516,8 +516,8 @@ class ONet(KSession):
def __call__(self,
images: np.ndarray,
rectangle_batch: T.List[np.ndarray]
) -> T.List[T.Tuple[np.ndarray, np.ndarray]]:
rectangle_batch: list[np.ndarray]
) -> list[tuple[np.ndarray, np.ndarray]]:
""" Third stage - further refinement and facial landmarks positions with o-net
Parameters
@ -532,7 +532,7 @@ class ONet(KSession):
List
List of refined final candidates, scores and landmark points from O-Net
"""
ret: T.List[T.Tuple[np.ndarray, np.ndarray]] = []
ret: list[tuple[np.ndarray, np.ndarray]] = []
for idx, rectangles in enumerate(rectangle_batch):
if not np.any(rectangles):
ret.append((np.empty((0, 5)), np.empty(0)))
@ -552,7 +552,7 @@ class ONet(KSession):
def _filter_face_48net(self, class_probabilities: np.ndarray,
roi: np.ndarray,
points: np.ndarray,
rectangles: np.ndarray) -> T.Tuple[np.ndarray, np.ndarray]:
rectangles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
""" Filter face position and calibrate bounding box on 48net's output
Parameters
@ -623,13 +623,13 @@ class MTCNN(): # pylint: disable=too-few-public-methods
Default: `0.709`
"""
def __init__(self,
model_path: T.List[str],
model_path: list[str],
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
cpu_mode: bool,
input_size: int = 640,
minsize: int = 20,
threshold: T.Optional[T.List[float]] = None,
threshold: list[float] | None = None,
factor: float = 0.709) -> None:
logger.debug("Initializing: %s: (model_path: '%s', allow_growth: %s, exclude_gpus: %s, "
"input_size: %s, minsize: %s, threshold: %s, factor: %s)",
@ -660,7 +660,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods
logger.debug("Initialized: %s", self.__class__.__name__)
def detect_faces(self, batch: np.ndarray) -> T.Tuple[np.ndarray, T.Tuple[np.ndarray]]:
def detect_faces(self, batch: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray]]:
"""Detects faces in an image, and returns bounding boxes and points for them.
Parameters
@ -684,7 +684,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods
def nms(rectangles: np.ndarray,
scores: np.ndarray,
threshold: float,
method: str = "iom") -> T.Tuple[np.ndarray, np.ndarray]:
method: str = "iom") -> tuple[np.ndarray, np.ndarray]:
""" apply non-maximum suppression on ROIs in same scale (matrix version)
Parameters

View file

@ -125,10 +125,10 @@ class L2Norm(keras.layers.Layer):
class SliceO2K(keras.layers.Layer):
""" Custom Keras Slice layer generated by onnx2keras. """
def __init__(self,
starts: T.List[int],
ends: T.List[int],
axes: T.Optional[T.List[int]] = None,
steps: T.Optional[T.List[int]] = None,
starts: list[int],
ends: list[int],
axes: list[int] | None = None,
steps: list[int] | None = None,
**kwargs) -> None:
self._starts = starts
self._ends = ends
@ -136,7 +136,7 @@ class SliceO2K(keras.layers.Layer):
self._steps = steps
super().__init__(**kwargs)
def _get_slices(self, dimensions: int) -> T.List[T.Tuple[int, ...]]:
def _get_slices(self, dimensions: int) -> list[tuple[int, ...]]:
""" Obtain slices for the given number of dimensions.
Parameters
@ -154,7 +154,7 @@ class SliceO2K(keras.layers.Layer):
assert len(axes) == len(steps) == len(self._starts) == len(self._ends)
return list(zip(axes, self._starts, self._ends, steps))
def compute_output_shape(self, input_shape: T.Tuple[int, ...]) -> T.Tuple[int, ...]:
def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
"""Computes the output shape of the layer.
Assumes that the layer will be built to match that input shape provided.
@ -230,7 +230,7 @@ class S3fd(KSession):
model_path: str,
model_kwargs: dict,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
confidence: float) -> None:
logger.debug("Initializing: %s: (model_path: '%s', model_kwargs: %s, allow_growth: %s, "
"exclude_gpus: %s, confidence: %s)", self.__class__.__name__, model_path,
@ -246,7 +246,7 @@ class S3fd(KSession):
self.average_img = np.array([104.0, 117.0, 123.0])
logger.debug("Initialized: %s", self.__class__.__name__)
def model_definition(self) -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
def model_definition(self) -> tuple[list[Tensor], list[Tensor]]:
""" Keras S3FD Model Definition, adapted from FAN pytorch implementation. """
input_ = Input(shape=(640, 640, 3))
var_x = self.conv_block(input_, 64, 1, 2)
@ -396,7 +396,7 @@ class S3fd(KSession):
batch = batch - self.average_img
return batch
def finalize_predictions(self, bounding_boxes_scales: T.List[np.ndarray]) -> np.ndarray:
def finalize_predictions(self, bounding_boxes_scales: list[np.ndarray]) -> np.ndarray:
""" Process the output from the model to obtain faces
Parameters
@ -413,7 +413,7 @@ class S3fd(KSession):
ret.append(finallist)
return np.array(ret, dtype="object")
def _post_process(self, bboxlist: T.List[np.ndarray]) -> np.ndarray:
def _post_process(self, bboxlist: list[np.ndarray]) -> np.ndarray:
""" Perform post processing on output
TODO: do this on the batch.
"""

View file

@ -12,9 +12,11 @@ For each source item, the plugin must pass a dict to finalize containing:
>>> {"filename": <filename of source frame>,
>>> "detected_faces": <list of bounding box dicts from lib/plugins/extract/detect/_base>}
"""
from __future__ import annotations
import logging
import typing as T
from dataclasses import dataclass, field
from typing import Generator, List, Optional, Tuple, TYPE_CHECKING
import cv2
import numpy as np
@ -25,7 +27,8 @@ from lib.align import AlignedFace, transform_image
from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch, ExtractMedia
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType
@ -44,9 +47,9 @@ class MaskerBatch(ExtractorBatch):
roi_masks: list
The region of interest masks for the batch
"""
detected_faces: List["DetectedFace"] = field(default_factory=list)
roi_masks: List[np.ndarray] = field(default_factory=list)
feed_faces: List[AlignedFace] = field(default_factory=list)
detected_faces: list[DetectedFace] = field(default_factory=list)
roi_masks: list[np.ndarray] = field(default_factory=list)
feed_faces: list[AlignedFace] = field(default_factory=list)
class Masker(Extractor): # pylint:disable=abstract-method
@ -77,9 +80,9 @@ class Masker(Extractor): # pylint:disable=abstract-method
"""
def __init__(self,
git_model_id: Optional[int] = None,
model_filename: Optional[str] = None,
configfile: Optional[str] = None,
git_model_id: int | None = None,
model_filename: str | None = None,
configfile: str | None = None,
instance: int = 0,
**kwargs) -> None:
logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile)
@ -93,11 +96,11 @@ class Masker(Extractor): # pylint:disable=abstract-method
self._plugin_type = "mask"
self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-")
self._storage_centering: "CenteringType" = "face" # Centering to store the mask at
self._storage_centering: CenteringType = "face" # Centering to store the mask at
self._storage_size = 128 # Size to store masks at. Leave this at default
logger.debug("Initialized %s", self.__class__.__name__)
def get_batch(self, queue: "Queue") -> Tuple[bool, MaskerBatch]:
def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]:
""" Get items for inputting into the masker from the queue in batches
Items are returned from the ``queue`` in batches of

View file

@ -49,7 +49,7 @@ class Mask(Masker):
# Separate storage for face and head masks
self._storage_name = f"{self._storage_name}_{self._storage_centering}"
def _check_weights_selection(self, configfile: T.Optional[str]) -> T.Tuple[bool, int]:
def _check_weights_selection(self, configfile: str | None) -> tuple[bool, int]:
""" Check which weights have been selected.
This is required for passing along the correct file name for the corresponding weights
@ -73,7 +73,7 @@ class Mask(Masker):
version = 1 if not is_faceswap else 2 if config.get("include_hair") else 3
return is_faceswap, version
def _get_segment_indices(self) -> T.List[int]:
def _get_segment_indices(self) -> list[int]:
""" Obtain the segment indices to include within the face mask area based on user
configuration settings.
@ -163,7 +163,7 @@ class Mask(Masker):
# SOFTWARE.
_NAME_TRACKER: T.Set[str] = set()
_NAME_TRACKER: set[str] = set()
def _get_name(name: str, start_idx: int = 1) -> str:
@ -554,7 +554,7 @@ class BiSeNet(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]],
exclude_gpus: list[int] | None,
input_size: int,
num_classes: int,
cpu_mode: bool) -> None:
@ -569,7 +569,7 @@ class BiSeNet(KSession):
self.define_model(self._model_definition)
self.load_model_weights()
def _model_definition(self) -> T.Tuple[Tensor, T.List[Tensor]]:
def _model_definition(self) -> tuple[Tensor, list[Tensor]]:
""" Definition of the BiSeNet Model.
Returns

View file

@ -1,14 +1,15 @@
#!/usr/bin/env python3
""" Components Mask for faceswap.py """
from __future__ import annotations
import logging
from typing import List, Tuple, TYPE_CHECKING
import typing as T
import cv2
import numpy as np
from ._base import BatchType, Masker
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.aligned_face import AlignedFace
logger = logging.getLogger(__name__)
@ -36,7 +37,7 @@ class Mask(Masker):
def predict(self, feed: np.ndarray) -> np.ndarray:
""" Run model to get predictions """
faces: List["AlignedFace"] = feed[1]
faces: list[AlignedFace] = feed[1]
feed = feed[0]
for mask, face in zip(feed, faces):
parts = self.parse_parts(np.array(face.landmarks))
@ -51,7 +52,7 @@ class Mask(Masker):
return
@staticmethod
def parse_parts(landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
def parse_parts(landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
""" Component face hull mask """
r_jaw = (landmarks[0:9], landmarks[17:18])
l_jaw = (landmarks[8:17], landmarks[26:27])

View file

@ -1,7 +1,8 @@
#!/usr/bin/env python3
""" Extended Mask for faceswap.py """
from __future__ import annotations
import logging
from typing import List, Tuple, TYPE_CHECKING
import typing as T
import cv2
import numpy as np
@ -9,7 +10,7 @@ from ._base import BatchType, Masker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.aligned_face import AlignedFace
@ -35,7 +36,7 @@ class Mask(Masker):
def predict(self, feed: np.ndarray) -> np.ndarray:
""" Run model to get predictions """
faces: List["AlignedFace"] = feed[1]
faces: list[AlignedFace] = feed[1]
feed = feed[0]
for mask, face in zip(feed, faces):
parts = self.parse_parts(np.array(face.landmarks))
@ -78,7 +79,7 @@ class Mask(Masker):
landmarks[17:22] = top_l + ((top_l - bot_l) // 2)
landmarks[22:27] = top_r + ((top_r - bot_r) // 2)
def parse_parts(self, landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
def parse_parts(self, landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
""" Extended face hull mask """
self._adjust_mask_top(landmarks)

View file

@ -13,7 +13,7 @@ Model file sourced from...
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
"""
import logging
from typing import cast
import typing as T
import numpy as np
from lib.model.session import KSession
@ -52,7 +52,7 @@ class Mask(Masker):
def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """
assert isinstance(batch, MaskerBatch)
batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3]
batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
for feed in batch.feed_faces], dtype="float32") / 255.0
logger.trace("feed shape: %s", batch.feed.shape) # type: ignore

View file

@ -94,7 +94,7 @@ class VGGClear(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]]):
exclude_gpus: list[int] | None):
super().__init__("VGG Obstructed",
model_path,
allow_growth=allow_growth,
@ -103,7 +103,7 @@ class VGGClear(KSession):
self.load_model_weights()
@classmethod
def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
def _model_definition(cls) -> tuple[Tensor, Tensor]:
""" Definition of the VGG Clear Model.
Returns
@ -210,7 +210,7 @@ class _ScorePool(): # pylint:disable=too-few-public-methods
crop: tuple
The amount of 2D cropping to apply. Tuple of `ints`
"""
def __init__(self, level: int, scale: float, crop: T.Tuple[int, int]):
def __init__(self, level: int, scale: float, crop: tuple[int, int]):
self._name = f"_pool{level}"
self._cropping = (crop, crop)
self._scale = scale

View file

@ -90,7 +90,7 @@ class VGGObstructed(KSession):
def __init__(self,
model_path: str,
allow_growth: bool,
exclude_gpus: T.Optional[T.List[int]]) -> None:
exclude_gpus: list[int] | None) -> None:
super().__init__("VGG Obstructed",
model_path,
allow_growth=allow_growth,
@ -99,7 +99,7 @@ class VGGObstructed(KSession):
self.load_model_weights()
@classmethod
def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
def _model_definition(cls) -> tuple[Tensor, Tensor]:
""" Definition of the VGG Obstructed Model.
Returns

View file

@ -8,10 +8,9 @@ together.
This module sets up a pipeline for the extraction workflow, loading detect, align and mask
plugins either in parallel or in series, giving easy access to input and output.
"""
from __future__ import annotations
import logging
import sys
from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import typing as T
import cv2
@ -20,13 +19,9 @@ from lib.queue_manager import EventQueue, queue_manager, QueueEmpty
from lib.utils import get_backend
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
if T.TYPE_CHECKING:
import numpy as np
from collections.abc import Generator
from lib.align.alignments import PNGHeaderSourceDict
from lib.align.detected_face import DetectedFace
from plugins.extract._base import Extractor as PluginExtractor
@ -102,16 +97,16 @@ class Extractor():
:attr:`final_pass` to indicate to the caller which phase is being processed
"""
def __init__(self,
detector: Optional[str],
aligner: Optional[str],
masker: Optional[Union[str, List[str]]],
recognition: Optional[str] = None,
configfile: Optional[str] = None,
detector: str | None,
aligner: str | None,
masker: str | list[str] | None,
recognition: str | None = None,
configfile: str | None = None,
multiprocess: bool = False,
exclude_gpus: Optional[List[int]] = None,
rotate_images: Optional[str] = None,
exclude_gpus: list[int] | None = None,
rotate_images: str | None = None,
min_size: int = 0,
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None,
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
re_feed: int = 0,
re_align: bool = False,
disable_filter: bool = False) -> None:
@ -122,8 +117,9 @@ class Extractor():
recognition, configfile, multiprocess, exclude_gpus, rotate_images, min_size,
normalize_method, re_feed, re_align, disable_filter)
self._instance = _get_instance()
maskers = [cast(Optional[str],
masker)] if not isinstance(masker, list) else cast(List[Optional[str]], masker)
maskers = [T.cast(str | None,
masker)] if not isinstance(masker, list) else T.cast(list[str | None],
masker)
self._flow = self._set_flow(detector, aligner, maskers, recognition)
self._exclude_gpus = exclude_gpus
# We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting
@ -220,13 +216,13 @@ class Extractor():
return retval
@property
def aligner(self) -> "Aligner":
def aligner(self) -> Aligner:
""" The currently selected aligner plugin """
assert self._align is not None
return self._align
@property
def recognition(self) -> "Identity":
def recognition(self) -> Identity:
""" The currently selected recognition plugin """
assert self._recognition is not None
return self._recognition
@ -237,7 +233,7 @@ class Extractor():
self._phase_index = 0
def set_batchsize(self,
plugin_type: Literal["align", "detect"],
plugin_type: T.Literal["align", "detect"],
batchsize: int) -> None:
""" Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`.
@ -311,7 +307,7 @@ class Extractor():
# <<< INTERNAL METHODS >>> #
@property
def _parallel_scaling(self) -> Dict[int, float]:
def _parallel_scaling(self) -> dict[int, float]:
""" dict: key is number of parallel plugins being loaded, value is the scaling factor that
the total base vram for those plugins should be scaled by
@ -335,7 +331,7 @@ class Extractor():
return retval
@property
def _vram_per_phase(self) -> Dict[str, float]:
def _vram_per_phase(self) -> dict[str, float]:
""" dict: The amount of vram required for each phase in :attr:`_flow`. """
retval = {}
for phase in self._flow:
@ -359,7 +355,7 @@ class Extractor():
return retval
@property
def _current_phase(self) -> List[str]:
def _current_phase(self) -> list[str]:
""" list: The current phase from :attr:`_phases` that is running through the extractor. """
retval = self._phases[self._phase_index]
logger.trace(retval) # type: ignore
@ -384,7 +380,7 @@ class Extractor():
return retval
@property
def _all_plugins(self) -> List["PluginExtractor"]:
def _all_plugins(self) -> list[PluginExtractor]:
""" Return list of all plugin objects in this pipeline """
retval = []
for phase in self._flow:
@ -396,7 +392,7 @@ class Extractor():
return retval
@property
def _active_plugins(self) -> List["PluginExtractor"]:
def _active_plugins(self) -> list[PluginExtractor]:
""" Return the plugins that are currently active based on pass """
retval = []
for phase in self._current_phase:
@ -407,10 +403,10 @@ class Extractor():
return retval
@staticmethod
def _set_flow(detector: Optional[str],
aligner: Optional[str],
masker: List[Optional[str]],
recognition: Optional[str]) -> List[str]:
def _set_flow(detector: str | None,
aligner: str | None,
masker: list[str | None],
recognition: str | None) -> list[str]:
""" Set the flow list based on the input plugins
Parameters
@ -441,7 +437,7 @@ class Extractor():
return retval
@staticmethod
def _get_plugin_type_and_index(flow_phase: str) -> Tuple[str, Optional[int]]:
def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]:
""" Obtain the plugin type and index for the plugin for the given flow phase.
When multiple plugins for the same phase are allowed (e.g. Mask) this will return
@ -463,14 +459,14 @@ class Extractor():
"""
sidx = flow_phase.split("_")[-1]
if sidx.isdigit():
idx: Optional[int] = int(sidx)
idx: int | None = int(sidx)
plugin_type = "_".join(flow_phase.split("_")[:-1])
else:
plugin_type = flow_phase
idx = None
return plugin_type, idx
def _add_queues(self) -> Dict[str, EventQueue]:
def _add_queues(self) -> dict[str, EventQueue]:
""" Add the required processing queues to Queue Manager """
queues = {}
tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow]
@ -483,7 +479,7 @@ class Extractor():
return queues
@staticmethod
def _get_vram_stats() -> Dict[str, Union[int, str]]:
def _get_vram_stats() -> dict[str, int | str]:
""" Obtain statistics on available VRAM and subtract a constant buffer from available vram.
Returns
@ -494,7 +490,7 @@ class Extractor():
vram_buffer = 256 # Leave a buffer for VRAM allocation
gpu_stats = GPUStats()
stats = gpu_stats.get_card_most_free()
retval: Dict[str, Union[int, str]] = {"count": gpu_stats.device_count,
retval: dict[str, int | str] = {"count": gpu_stats.device_count,
"device": stats.device,
"vram_free": int(stats.free - vram_buffer),
"vram_total": int(stats.total)}
@ -521,13 +517,13 @@ class Extractor():
self._vram_stats["device"],
self._vram_stats["vram_free"],
self._vram_stats["vram_total"])
if cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required:
if T.cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required:
logger.warning("Not enough free VRAM for parallel processing. "
"Switching to serial")
return False
return True
def _set_phases(self, multiprocess: bool) -> List[List[str]]:
def _set_phases(self, multiprocess: bool) -> list[list[str]]:
""" If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit
into VRAM, otherwise return the single flow.
@ -541,9 +537,9 @@ class Extractor():
list:
The jobs to be undertaken split into phases that fit into GPU RAM
"""
phases: List[List[str]] = []
current_phase: List[str] = []
available = cast(int, self._vram_stats["vram_free"])
phases: list[list[str]] = []
current_phase: list[str] = []
available = T.cast(int, self._vram_stats["vram_free"])
for phase in self._flow:
num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0])
num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0
@ -576,12 +572,12 @@ class Extractor():
# << INTERNAL PLUGIN HANDLING >> #
def _load_align(self,
aligner: Optional[str],
configfile: Optional[str],
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]],
aligner: str | None,
configfile: str | None,
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None,
re_feed: int,
re_align: bool,
disable_filter: bool) -> Optional["Aligner"]:
disable_filter: bool) -> Aligner | None:
""" Set global arguments and load aligner plugin
Parameters
@ -619,10 +615,10 @@ class Extractor():
return plugin
def _load_detect(self,
detector: Optional[str],
rotation: Optional[str],
detector: str | None,
rotation: str | None,
min_size: int,
configfile: Optional[str]) -> Optional["Detector"]:
configfile: str | None) -> Detector | None:
""" Set global arguments and load detector plugin """
if detector is None or detector.lower() == "none":
logger.debug("No detector selected. Returning None")
@ -637,8 +633,8 @@ class Extractor():
return plugin
def _load_mask(self,
masker: Optional[str],
configfile: Optional[str]) -> Optional["Masker"]:
masker: str | None,
configfile: str | None) -> Masker | None:
""" Set global arguments and load masker plugin
Parameters
@ -664,8 +660,8 @@ class Extractor():
return plugin
def _load_recognition(self,
recognition: Optional[str],
configfile: Optional[str]) -> Optional["Identity"]:
recognition: str | None,
configfile: str | None) -> Identity | None:
""" Set global arguments and load recognition plugin """
if recognition is None or recognition.lower() == "none":
logger.debug("No recognition selected. Returning None")
@ -716,16 +712,16 @@ class Extractor():
gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling
if plugins_required + batch_required <= cast(int, self._vram_stats["vram_free"]):
if plugins_required + batch_required <= T.cast(int, self._vram_stats["vram_free"]):
logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
"vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
return
# Hacky split across plugins that use vram
available_vram = (cast(int, self._vram_stats["vram_free"])
available_vram = (T.cast(int, self._vram_stats["vram_free"])
- plugins_required) // len(gpu_plugins)
self._set_plugin_batchsize(gpu_plugins, available_vram)
def _set_plugin_batchsize(self, gpu_plugins: List[str], available_vram: float) -> None:
def _set_plugin_batchsize(self, gpu_plugins: list[str], available_vram: float) -> None:
""" Set the batch size for the given plugin based on given available vram.
Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to
zero division error.
@ -802,20 +798,20 @@ class ExtractMedia():
def __init__(self,
filename: str,
image: "np.ndarray",
detected_faces: Optional[List["DetectedFace"]] = None,
image: np.ndarray,
detected_faces: list[DetectedFace] | None = None,
is_aligned: bool = False) -> None:
logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore
"detected_faces: %s, is_aligned: %s)", self.__class__.__name__, filename,
image.shape, detected_faces, is_aligned)
self._filename = filename
self._image: Optional["np.ndarray"] = image
self._image_shape = cast(Tuple[int, int, int], image.shape)
self._detected_faces: List["DetectedFace"] = ([] if detected_faces is None
self._image: np.ndarray | None = image
self._image_shape = T.cast(tuple[int, int, int], image.shape)
self._detected_faces: list[DetectedFace] = ([] if detected_faces is None
else detected_faces)
self._is_aligned = is_aligned
self._frame_metadata: Optional["PNGHeaderSourceDict"] = None
self._sub_folders: List[Optional[str]] = []
self._frame_metadata: PNGHeaderSourceDict | None = None
self._sub_folders: list[str | None] = []
@property
def filename(self) -> str:
@ -823,23 +819,23 @@ class ExtractMedia():
return self._filename
@property
def image(self) -> "np.ndarray":
def image(self) -> np.ndarray:
""" :class:`numpy.ndarray`: The source frame for this object. """
assert self._image is not None
return self._image
@property
def image_shape(self) -> Tuple[int, int, int]:
def image_shape(self) -> tuple[int, int, int]:
""" tuple: The shape of the stored :attr:`image`. """
return self._image_shape
@property
def image_size(self) -> Tuple[int, int]:
def image_size(self) -> tuple[int, int]:
""" tuple: The (`height`, `width`) of the stored :attr:`image`. """
return self._image_shape[:2]
@property
def detected_faces(self) -> List["DetectedFace"]:
def detected_faces(self) -> list[DetectedFace]:
"""list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """
return self._detected_faces
@ -849,7 +845,7 @@ class ExtractMedia():
return self._is_aligned
@property
def frame_metadata(self) -> "PNGHeaderSourceDict":
def frame_metadata(self) -> PNGHeaderSourceDict:
""" dict: The frame metadata that has been added from an aligned image. This property
should only be called after :func:`add_frame_metadata` has been called when processing
an aligned face. For all other instances an assertion error will be raised.
@ -863,13 +859,13 @@ class ExtractMedia():
return self._frame_metadata
@property
def sub_folders(self) -> List[Optional[str]]:
def sub_folders(self) -> list[str | None]:
""" list: The sub_folders that the faces should be output to. Used when binning filter
output is enabled. The list corresponds to the list of detected faces
"""
return self._sub_folders
def get_image_copy(self, color_format: Literal["BGR", "RGB", "GRAY"]) -> "np.ndarray":
def get_image_copy(self, color_format: T.Literal["BGR", "RGB", "GRAY"]) -> np.ndarray:
""" Get a copy of the image in the requested color format.
Parameters
@ -887,7 +883,7 @@ class ExtractMedia():
image = getattr(self, f"_image_as_{color_format.lower()}")()
return image
def add_detected_faces(self, faces: List["DetectedFace"]) -> None:
def add_detected_faces(self, faces: list[DetectedFace]) -> None:
""" Add detected faces to the object. Called at the end of each extraction phase.
Parameters
@ -900,7 +896,7 @@ class ExtractMedia():
[(face.left, face.right, face.top, face.bottom) for face in faces])
self._detected_faces = faces
def add_sub_folders(self, folders: List[Optional[str]]) -> None:
def add_sub_folders(self, folders: list[str | None]) -> None:
""" Add detected faces to the object. Called at the end of each extraction phase.
Parameters
@ -922,7 +918,7 @@ class ExtractMedia():
del self._image
self._image = None
def set_image(self, image: "np.ndarray") -> None:
def set_image(self, image: np.ndarray) -> None:
""" Add the image back into :attr:`image`
Required for multi-phase extraction adds the image back to this object.
@ -936,7 +932,7 @@ class ExtractMedia():
self._filename, image.shape)
self._image = image
def add_frame_metadata(self, metadata: "PNGHeaderSourceDict") -> None:
def add_frame_metadata(self, metadata: PNGHeaderSourceDict) -> None:
""" Add the source frame metadata from an aligned PNG's header data.
metadata: dict
@ -944,11 +940,11 @@ class ExtractMedia():
"""
logger.trace("Adding PNG Source data for '%s': %s", # type:ignore
self._filename, metadata)
dims = cast(Tuple[int, int], metadata["source_frame_dims"])
dims = T.cast(tuple[int, int], metadata["source_frame_dims"])
self._image_shape = (*dims, 3)
self._frame_metadata = metadata
def _image_as_bgr(self) -> "np.ndarray":
def _image_as_bgr(self) -> np.ndarray:
""" Get a copy of the source frame in BGR format.
Returns
@ -957,7 +953,7 @@ class ExtractMedia():
A copy of :attr:`image` in BGR color format """
return self.image[..., :3].copy()
def _image_as_rgb(self) -> "np.ndarray":
def _image_as_rgb(self) -> np.ndarray:
""" Get a copy of the source frame in RGB format.
Returns
@ -966,7 +962,7 @@ class ExtractMedia():
A copy of :attr:`image` in RGB color format """
return self.image[..., 2::-1].copy()
def _image_as_gray(self) -> "np.ndarray":
def _image_as_gray(self) -> np.ndarray:
""" Get a copy of the source frame in gray-scale format.
Returns

View file

@ -17,7 +17,6 @@ To get a :class:`~lib.align.DetectedFace` object use the function:
"""
from __future__ import annotations
import logging
import sys
import typing as T
from dataclasses import dataclass, field
@ -31,13 +30,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch
from plugins.extract.pipeline import ExtractMedia
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align.aligned_face import CenteringType
@ -50,8 +44,8 @@ class RecogBatch(ExtractorBatch):
Inherits from :class:`~plugins.extract._base.ExtractorBatch`
"""
detected_faces: T.List["DetectedFace"] = field(default_factory=list)
feed_faces: T.List[AlignedFace] = field(default_factory=list)
detected_faces: list[DetectedFace] = field(default_factory=list)
feed_faces: list[AlignedFace] = field(default_factory=list)
class Identity(Extractor): # pylint:disable=abstract-method
@ -82,9 +76,9 @@ class Identity(Extractor): # pylint:disable=abstract-method
"""
def __init__(self,
git_model_id: T.Optional[int] = None,
model_filename: T.Optional[str] = None,
configfile: T.Optional[str] = None,
git_model_id: int | None = None,
model_filename: str | None = None,
configfile: str | None = None,
instance: int = 0,
**kwargs):
logger.debug("Initializing %s", self.__class__.__name__)
@ -119,7 +113,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
logger.debug("Obtained detected face: (filename: %s, detected_face: %s)",
item.filename, item.detected_faces)
def get_batch(self, queue: Queue) -> T.Tuple[bool, RecogBatch]:
def get_batch(self, queue: Queue) -> tuple[bool, RecogBatch]:
""" Get items for inputting into the recognition from the queue in batches
Items are returned from the ``queue`` in batches of
@ -226,7 +220,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
"\n3) Enable 'Single Process' mode.")
raise FaceswapError(msg) from err
def finalize(self, batch: BatchType) -> T.Generator[ExtractMedia, None, None]:
def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
""" Finalize the output from Masker
This should be called as the final task of each `plugin`.
@ -301,8 +295,8 @@ class IdentityFilter():
def __init__(self, save_output: bool) -> None:
logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output)
self._save_output = save_output
self._filter: T.Optional[np.ndarray] = None
self._nfilter: T.Optional[np.ndarray] = None
self._filter: np.ndarray | None = None
self._nfilter: np.ndarray | None = None
self._threshold = 0.0
self._filter_enabled: bool = False
self._nfilter_enabled: bool = False
@ -357,7 +351,7 @@ class IdentityFilter():
return retval
def _get_matches(self,
filter_type: Literal["filter", "nfilter"],
filter_type: T.Literal["filter", "nfilter"],
identities: np.ndarray) -> np.ndarray:
""" Obtain the average and minimum distances for each face against the source identities
to test against
@ -386,9 +380,9 @@ class IdentityFilter():
return retval
def _filter_faces(self,
faces: T.List[DetectedFace],
sub_folders: T.List[T.Optional[str]],
should_filter: T.List[bool]) -> T.List[DetectedFace]:
faces: list[DetectedFace],
sub_folders: list[str | None],
should_filter: list[bool]) -> list[DetectedFace]:
""" Filter the detected faces, either removing filtered faces from the list of detected
faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if
saving output is enabled.
@ -410,7 +404,7 @@ class IdentityFilter():
The filtered list of detected face objects, if saving filtered faces has not been
selected or the full list of detected faces
"""
retval: T.List[DetectedFace] = []
retval: list[DetectedFace] = []
self._counts += sum(should_filter)
for idx, face in enumerate(faces):
fldr = sub_folders[idx]
@ -429,8 +423,8 @@ class IdentityFilter():
return retval
def __call__(self,
faces: T.List[DetectedFace],
sub_folders: T.List[T.Optional[str]]) -> T.List[DetectedFace]:
faces: list[DetectedFace],
sub_folders: list[str | None]) -> list[DetectedFace]:
""" Call the identity filter function
Parameters
@ -459,14 +453,14 @@ class IdentityFilter():
logger.trace("All faces already filtered: %s", sub_folders) # type: ignore
return faces
should_filter: T.List[np.ndarray] = []
for f_type in get_args(Literal["filter", "nfilter"]):
should_filter: list[np.ndarray] = []
for f_type in T.get_args(T.Literal["filter", "nfilter"]):
if not getattr(self, f"_{f_type}_enabled"):
continue
should_filter.append(self._get_matches(f_type, identities))
# If any of the filter or nfilter evaluate to 'should filter' then filter out face
final_filter: T.List[bool] = np.array(should_filter).max(axis=0).tolist()
final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist()
logger.trace("should_filter: %s, final_filter: %s", # type: ignore
should_filter, final_filter)
return self._filter_faces(faces, sub_folders, final_filter)

View file

@ -1,10 +1,9 @@
#!/usr/bin python3
""" VGG_Face2 inference and sorting """
from __future__ import annotations
import logging
import sys
from typing import cast, Dict, Generator, List, Tuple, Optional
import typing as T
import numpy as np
import psutil
@ -15,11 +14,8 @@ from lib.model.session import KSession
from lib.utils import FaceswapError
from ._base import BatchType, RecogBatch, Identity
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from collections.abc import Generator
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -64,7 +60,7 @@ class Recognition(Identity):
def init_model(self) -> None:
""" Initialize VGG Face 2 Model. """
assert isinstance(self.model_path, str)
model_kwargs = dict(custom_objects={'L2_normalize': L2_normalize})
model_kwargs = {"custom_objects": {"L2_normalize": L2_normalize}}
self.model = KSession(self.name,
self.model_path,
model_kwargs=model_kwargs,
@ -76,7 +72,7 @@ class Recognition(Identity):
def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """
assert isinstance(batch, RecogBatch)
batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3]
batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
for feed in batch.feed_faces],
dtype="float32") - self._average_img
logger.trace("feed shape: %s", batch.feed.shape) # type:ignore
@ -121,15 +117,15 @@ class Cluster(): # pylint: disable=too-few-public-methods
def __init__(self,
predictions: np.ndarray,
method: Literal["single", "centroid", "median", "ward"],
threshold: Optional[float] = None) -> None:
method: T.Literal["single", "centroid", "median", "ward"],
threshold: float | None = None) -> None:
logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)",
self.__class__.__name__, predictions.shape, method, threshold)
self._num_predictions = predictions.shape[0]
self._should_output_bins = threshold is not None
self._threshold = 0.0 if threshold is None else threshold
self._bins: Dict[int, int] = {}
self._bins: dict[int, int] = {}
self._iterator = self._integer_iterator()
self._result_linkage = self._do_linkage(predictions, method)
@ -192,7 +188,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
def _do_linkage(self,
predictions: np.ndarray,
method: Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
method: T.Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
""" Use FastCluster to perform vector or standard linkage
Parameters
@ -218,7 +214,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
def _process_leaf_node(self,
current_index: int,
current_bin: int) -> List[Tuple[int, int]]:
current_bin: int) -> list[tuple[int, int]]:
""" Process the output when we have hit a leaf node """
if not self._should_output_bins:
return [(current_index, 0)]
@ -263,7 +259,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
tree: np.ndarray,
points: int,
current_index: int,
current_bin: int = 0) -> List[Tuple[int, int]]:
current_bin: int = 0) -> list[tuple[int, int]]:
""" Seriation method for sorted similarity.
Seriation computes the order implied by a hierarchical tree (dendrogram).
@ -298,7 +294,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
return serate_left + serate_right # type: ignore
def __call__(self) -> List[Tuple[int, int]]:
def __call__(self) -> list[tuple[int, int]]:
""" Process the linkages.
Transforms a distance matrix into a sorted distance matrix according to the order implied

View file

@ -1,13 +1,14 @@
#!/usr/bin/env python3
""" Plugin loader for Faceswap extract, training and convert tasks """
from __future__ import annotations
import logging
import os
import sys
from importlib import import_module
from typing import Callable, List, Type, TYPE_CHECKING
import typing as T
if TYPE_CHECKING:
from importlib import import_module
if T.TYPE_CHECKING:
from collections.abc import Callable
from plugins.extract.detect._base import Detector
from plugins.extract.align._base import Aligner
from plugins.extract.mask._base import Masker
@ -15,11 +16,6 @@ if TYPE_CHECKING:
from plugins.train.model._base import ModelBase
from plugins.train.trainer._base import TrainerBase
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -36,7 +32,7 @@ class PluginLoader():
>>> aligner = PluginLoader.get_aligner('cv2-dnn')
"""
@staticmethod
def get_detector(name: str, disable_logging: bool = False) -> Type["Detector"]:
def get_detector(name: str, disable_logging: bool = False) -> type[Detector]:
""" Return requested detector plugin
Parameters
@ -55,7 +51,7 @@ class PluginLoader():
return PluginLoader._import("extract.detect", name, disable_logging)
@staticmethod
def get_aligner(name: str, disable_logging: bool = False) -> Type["Aligner"]:
def get_aligner(name: str, disable_logging: bool = False) -> type[Aligner]:
""" Return requested aligner plugin
Parameters
@ -74,7 +70,7 @@ class PluginLoader():
return PluginLoader._import("extract.align", name, disable_logging)
@staticmethod
def get_masker(name: str, disable_logging: bool = False) -> Type["Masker"]:
def get_masker(name: str, disable_logging: bool = False) -> type[Masker]:
""" Return requested masker plugin
Parameters
@ -93,7 +89,7 @@ class PluginLoader():
return PluginLoader._import("extract.mask", name, disable_logging)
@staticmethod
def get_recognition(name: str, disable_logging: bool = False) -> Type["Identity"]:
def get_recognition(name: str, disable_logging: bool = False) -> type[Identity]:
""" Return requested recognition plugin
Parameters
@ -112,7 +108,7 @@ class PluginLoader():
return PluginLoader._import("extract.recognition", name, disable_logging)
@staticmethod
def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]:
def get_model(name: str, disable_logging: bool = False) -> type[ModelBase]:
""" Return requested training model plugin
Parameters
@ -131,7 +127,7 @@ class PluginLoader():
return PluginLoader._import("train.model", name, disable_logging)
@staticmethod
def get_trainer(name: str, disable_logging: bool = False) -> Type["TrainerBase"]:
def get_trainer(name: str, disable_logging: bool = False) -> type[TrainerBase]:
""" Return requested training trainer plugin
Parameters
@ -198,9 +194,9 @@ class PluginLoader():
return getattr(module, ttl)
@staticmethod
def get_available_extractors(extractor_type: Literal["align", "detect", "mask"],
def get_available_extractors(extractor_type: T.Literal["align", "detect", "mask"],
add_none: bool = False,
extend_plugin: bool = False) -> List[str]:
extend_plugin: bool = False) -> list[str]:
""" Return a list of available extractors of the given type
Parameters
@ -243,7 +239,7 @@ class PluginLoader():
return extractors
@staticmethod
def get_available_models() -> List[str]:
def get_available_models() -> list[str]:
""" Return a list of available training models
Returns
@ -273,7 +269,7 @@ class PluginLoader():
return 'original' if 'original' in models else models[0]
@staticmethod
def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> List[str]:
def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> list[str]:
""" Return a list of available converter plugins in the given category
Parameters

View file

@ -249,6 +249,31 @@ class Config(FaceswapConfig):
"NB: The value given here is the 'exponent' to the epsilon. For example, "
"choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the epsilon "
"to 0.001 (1e-3)."))
self.add_item(
section=section,
title="save_optimizer",
datatype=str,
group=_("optimizer"),
default="exit",
fixed=False,
gui_radio=True,
choices=["never", "always", "exit"],
info=_(
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the amount "
"of time it takes to save the model). However, it can be useful to save these "
"weights if you want to guarantee that a resumed model carries off exactly from "
"where it left off, rather than spending a few hundred iterations catching up."
"\n\t never - Don't save optimizer weights."
"\n\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have the "
"last saved optimizer state in your model file."
"\n\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."))
self.add_item(
section=section,
title="autoclip",

View file

@ -21,11 +21,6 @@ from tensorflow.keras.models import load_model, Model as KModel # noqa:E501 #
from lib.model.backup_restore import Backup
from lib.utils import FaceswapError
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from tensorflow import keras
from .model import ModelBase
@ -35,7 +30,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_all_sub_models(
model: keras.models.Model,
models: T.Optional[T.List[keras.models.Model]] = None) -> T.List[keras.models.Model]:
models: list[keras.models.Model] | None = None) -> list[keras.models.Model]:
""" For a given model, return all sub-models that occur (recursively) as children.
Parameters
@ -85,12 +80,12 @@ class IO():
plugin: ModelBase,
model_dir: str,
is_predict: bool,
save_optimizer: Literal["never", "always", "exit"]) -> None:
save_optimizer: T.Literal["never", "always", "exit"]) -> None:
self._plugin = plugin
self._is_predict = is_predict
self._model_dir = model_dir
self._save_optimizer = save_optimizer
self._history: T.List[T.List[float]] = [[], []] # Loss histories per save iteration
self._history: list[list[float]] = [[], []] # Loss histories per save iteration
self._backup = Backup(self._model_dir, self._plugin.name)
@property
@ -106,12 +101,12 @@ class IO():
return os.path.isfile(self._filename)
@property
def history(self) -> T.List[T.List[float]]:
def history(self) -> list[list[float]]:
""" list: list of loss histories per side for the current save iteration. """
return self._history
@property
def multiple_models_in_folder(self) -> T.Optional[T.List[str]]:
def multiple_models_in_folder(self) -> list[str] | None:
""" :list: or ``None`` If there are multiple model types in the requested folder, or model
types that don't correspond to the requested plugin type, then returns the list of plugin
names that exist in the folder, otherwise returns ``None`` """
@ -210,7 +205,7 @@ class IO():
msg += f" - Average loss since last save: {', '.join(lossmsg)}"
logger.info(msg)
def _get_save_averages(self) -> T.List[float]:
def _get_save_averages(self) -> list[float]:
""" Return the average loss since the last save iteration and reset historical loss """
logger.debug("Getting save averages")
if not all(loss for loss in self._history):
@ -222,7 +217,7 @@ class IO():
logger.debug("Average losses since last save: %s", retval)
return retval
def _should_backup(self, save_averages: T.List[float]) -> bool:
def _should_backup(self, save_averages: list[float]) -> bool:
""" Check whether the loss averages for this save iteration is the lowest that has been
seen.
@ -301,7 +296,7 @@ class Weights():
logger.debug("Initialized %s", self.__class__.__name__)
@classmethod
def _check_weights_file(cls, weights_file: str) -> T.Optional[str]:
def _check_weights_file(cls, weights_file: str) -> str | None:
""" Validate that we have a valid path to a .h5 file.
Parameters
@ -403,7 +398,7 @@ class Weights():
"different settings than you have set for your current model.",
skipped_ops)
def _get_weights_model(self) -> T.List[keras.models.Model]:
def _get_weights_model(self) -> list[keras.models.Model]:
""" Obtain a list of all sub-models contained within the weights model.
Returns
@ -429,7 +424,7 @@ class Weights():
def _load_layer_weights(self,
layer: keras.layers.Layer,
sub_weights: keras.layers.Layer,
model_name: str) -> Literal[-1, 0, 1]:
model_name: str) -> T.Literal[-1, 0, 1]:
""" Load the weights for a single layer.
Parameters

View file

@ -29,18 +29,12 @@ from plugins.train._config import Config
from .io import IO, get_all_sub_models, Weights
from .settings import Loss, Optimizer, Settings
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
import argparse
from lib.config import ConfigValueType
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG: T.Dict[str, ConfigValueType] = {}
_CONFIG: dict[str, ConfigValueType] = {}
class ModelBase():
@ -79,13 +73,13 @@ class ModelBase():
self.__class__.__name__, model_dir, arguments, predict)
# Input shape must be set within the plugin after initializing
self.input_shape: T.Tuple[int, ...] = ()
self.input_shape: tuple[int, ...] = ()
self.trainer = "original" # Override for plugin specific trainer
self.color_order: Literal["bgr", "rgb"] = "bgr" # Override for image color channel order
self.color_order: T.Literal["bgr", "rgb"] = "bgr" # Override for image color channel order
self._args = arguments
self._is_predict = predict
self._model: T.Optional[tf.keras.models.Model] = None
self._model: tf.keras.models.Model | None = None
self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None
self._load_config()
@ -100,14 +94,7 @@ class ModelBase():
"use. Please select a mask or disable 'Learn Mask'.")
self._mixed_precision = self.config["mixed_precision"]
# self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
# TODO - Re-enable saving of optimizer once this bug is fixed:
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
# File "h5py/h5d.pyx", line 87, in h5py.h5d.create
# ValueError: Unable to create dataset (name already exists)
self._io = IO(self, model_dir, self._is_predict, "never")
self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
self._check_multiple_models()
self._state = State(model_dir,
@ -175,16 +162,16 @@ class ModelBase():
return self.name
@property
def input_shapes(self) -> T.List[T.Tuple[None, int, int, int]]:
def input_shapes(self) -> list[tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the inputs to the model. """
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(inputs))
shapes = [T.cast(tuple[None, int, int, int], K.int_shape(inputs))
for inputs in self.model.inputs]
return shapes
@property
def output_shapes(self) -> T.List[T.Tuple[None, int, int, int]]:
def output_shapes(self) -> list[tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the outputs of the model. """
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(output))
shapes = [T.cast(tuple[None, int, int, int], K.int_shape(output))
for output in self.model.outputs]
return shapes
@ -333,7 +320,7 @@ class ModelBase():
a list of 2 shape tuples of 3 dimensions. """
assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple"
def _get_inputs(self) -> T.List[tf.keras.layers.Input]:
def _get_inputs(self) -> list[tf.keras.layers.Input]:
""" Obtain the standardized inputs for the model.
The inputs will be returned for the "A" and "B" sides in the shape as defined by
@ -352,7 +339,7 @@ class ModelBase():
logger.debug("inputs: %s", inputs)
return inputs
def build_model(self, inputs: T.List[tf.keras.layers.Input]) -> tf.keras.models.Model:
def build_model(self, inputs: list[tf.keras.layers.Input]) -> tf.keras.models.Model:
""" Override for Model Specific autoencoder builds.
Parameters
@ -427,7 +414,7 @@ class ModelBase():
self._state.add_session_loss_names(self._loss.names)
logger.debug("Compiled Model: %s", self.model)
def _legacy_mapping(self) -> T.Optional[dict]:
def _legacy_mapping(self) -> dict | None:
""" The mapping of separate model files to single model layers for transferring of legacy
weights.
@ -439,7 +426,7 @@ class ModelBase():
"""
return None
def add_history(self, loss: T.List[float]) -> None:
def add_history(self, loss: list[float]) -> None:
""" Add the current iteration's loss history to :attr:`_io.history`.
Called from the trainer after each iteration, for tracking loss drop over time between
@ -482,18 +469,18 @@ class State():
self._filename = os.path.join(model_dir, filename)
self._name = model_name
self._iterations = 0
self._mixed_precision_layers: T.List[str] = []
self._mixed_precision_layers: list[str] = []
self._rebuild_model = False
self._sessions: T.Dict[int, dict] = {}
self._lowest_avg_loss: T.Dict[str, float] = {}
self._config: T.Dict[str, ConfigValueType] = {}
self._sessions: dict[int, dict] = {}
self._lowest_avg_loss: dict[str, float] = {}
self._config: dict[str, ConfigValueType] = {}
self._load(config_changeable_items)
self._session_id = self._new_session_id()
self._create_new_session(no_logs, config_changeable_items)
logger.debug("Initialized %s:", self.__class__.__name__)
@property
def loss_names(self) -> T.List[str]:
def loss_names(self) -> list[str]:
""" list: The loss names for the current session """
return self._sessions[self._session_id]["loss_names"]
@ -518,7 +505,7 @@ class State():
return self._session_id
@property
def mixed_precision_layers(self) -> T.List[str]:
def mixed_precision_layers(self) -> list[str]:
"""list: Layers that can be switched between mixed-float16 and float32. """
return self._mixed_precision_layers
@ -564,7 +551,7 @@ class State():
"iterations": 0,
"config": config_changeable_items}
def add_session_loss_names(self, loss_names: T.List[str]) -> None:
def add_session_loss_names(self, loss_names: list[str]) -> None:
""" Add the session loss names to the sessions dictionary.
The loss names are used for Tensorboard logging
@ -593,7 +580,7 @@ class State():
self._iterations += 1
self._sessions[self._session_id]["iterations"] += 1
def add_mixed_precision_layers(self, layers: T.List[str]) -> None:
def add_mixed_precision_layers(self, layers: list[str]) -> None:
""" Add the list of model's layers that are compatible for mixed precision to the
state dictionary """
logger.debug("Storing mixed precision layers: %s", layers)
@ -655,7 +642,7 @@ class State():
legacy_update = self._update_legacy_config()
# Add any new items to state config for legacy purposes where the new default may be
# detrimental to an existing model.
legacy_defaults: T.Dict[str, T.Union[str, int, bool]] = {"centering": "legacy",
legacy_defaults: dict[str, str | int | bool] = {"centering": "legacy",
"mask_loss_function": "mse",
"l2_reg_term": 100,
"optimizer": "adam",
@ -807,7 +794,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
""" :class:`keras.models.Model`: The Faceswap model, compiled for inference. """
return self._model
def _get_nodes(self, nodes: np.ndarray) -> T.List[T.Tuple[str, int]]:
def _get_nodes(self, nodes: np.ndarray) -> list[tuple[str, int]]:
""" Given in input list of nodes from a :attr:`keras.models.Model.get_config` dictionary,
filters the layer name(s) and output index of the node, splitting to the correct output
index in the event of multiple inputs.
@ -849,7 +836,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
logger.debug("Compiling inference model. saved_model: %s", saved_model)
struct = self._get_filtered_structure()
model_inputs = self._get_inputs(saved_model.inputs)
compiled_layers: T.Dict[str, tf.keras.layers.Layer] = {}
compiled_layers: dict[str, tf.keras.layers.Layer] = {}
for layer in saved_model.layers:
if layer.name not in struct:
logger.debug("Skipping unused layer: '%s'", layer.name)

View file

@ -14,7 +14,6 @@ from __future__ import annotations
from dataclasses import dataclass, field
import logging
import platform
import sys
import typing as T
from contextlib import nullcontext
@ -28,12 +27,9 @@ from lib.model import losses, optimizers
from lib.model.autoclip import AutoClipper
from lib.utils import get_backend
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from collections.abc import Callable
from contextlib import AbstractContextManager as ContextManager
from argparse import Namespace
from .model import State
@ -58,9 +54,9 @@ class LossClass:
kwargs: dict
Any keyword arguments to supply to the loss function at initialization.
"""
function: T.Union[T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor], T.Any] = k_losses.mae
function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | T.Any = k_losses.mae
init: bool = True
kwargs: T.Dict[str, T.Any] = field(default_factory=dict)
kwargs: dict[str, T.Any] = field(default_factory=dict)
class Loss():
@ -73,13 +69,13 @@ class Loss():
color_order: str
Color order of the model. One of `"BGR"` or `"RGB"`
"""
def __init__(self, config: dict, color_order: Literal["bgr", "rgb"]) -> None:
def __init__(self, config: dict, color_order: T.Literal["bgr", "rgb"]) -> None:
logger.debug("Initializing %s: (color_order: %s)", self.__class__.__name__, color_order)
self._config = config
self._mask_channels = self._get_mask_channels()
self._inputs: T.List[tf.keras.layers.Layer] = []
self._names: T.List[str] = []
self._funcs: T.Dict[str, T.Callable] = {}
self._inputs: list[tf.keras.layers.Layer] = []
self._names: list[str] = []
self._funcs: dict[str, Callable] = {}
self._loss_dict = {"ffl": LossClass(function=losses.FocalFrequencyLoss),
"flip": LossClass(function=losses.LDRFLIPLoss,
@ -104,7 +100,7 @@ class Loss():
logger.debug("Initialized: %s", self.__class__.__name__)
@property
def names(self) -> T.List[str]:
def names(self) -> list[str]:
""" list: The list of loss names for the model. """
return self._names
@ -114,14 +110,14 @@ class Loss():
return self._funcs
@property
def _mask_inputs(self) -> T.Optional[list]:
def _mask_inputs(self) -> list | None:
""" list: The list of input tensors to the model that contain the mask. Returns ``None``
if there is no mask input to the model. """
mask_inputs = [inp for inp in self._inputs if inp.name.startswith("mask")]
return None if not mask_inputs else mask_inputs
@property
def _mask_shapes(self) -> T.Optional[T.List[tuple]]:
def _mask_shapes(self) -> list[tuple] | None:
""" list: The list of shape tuples for the mask input tensors for the model. Returns
``None`` if there is no mask input. """
if self._mask_inputs is None:
@ -141,7 +137,7 @@ class Loss():
self._set_loss_functions(model.output_names)
self._names.insert(0, "total")
def _set_loss_names(self, outputs: T.List[tf.Tensor]) -> None:
def _set_loss_names(self, outputs: list[tf.Tensor]) -> None:
""" Name the losses based on model output.
This is used for correct naming in the state file, for display purposes only.
@ -173,7 +169,7 @@ class Loss():
self._names.append(f"{name}_{side}{suffix}")
logger.debug(self._names)
def _get_function(self, name: str) -> T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
def _get_function(self, name: str) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
""" Obtain the requested Loss function
Parameters
@ -191,7 +187,7 @@ class Loss():
logger.debug("Obtained loss function `%s` (%s)", name, retval)
return retval
def _set_loss_functions(self, output_names: T.List[str]):
def _set_loss_functions(self, output_names: list[str]):
""" Set the loss functions and their associated weights.
Adds the loss functions to the :attr:`functions` dictionary.
@ -251,7 +247,7 @@ class Loss():
mask_channel=mask_channel)
channel_idx += 1
def _get_mask_channels(self) -> T.List[int]:
def _get_mask_channels(self) -> list[int]:
""" Obtain the channels from the face targets that the masks reside in from the training
data generator.
@ -311,8 +307,8 @@ class Optimizer(): # pylint:disable=too-few-public-methods
{"beta_1": 0.5, "beta_2": 0.99, "epsilon": epsilon}),
"rms-prop": (optimizers.RMSprop, {"epsilon": epsilon})}
optimizer_info = valid_optimizers[optimizer]
self._optimizer: T.Callable = optimizer_info[0]
self._kwargs: T.Dict[str, T.Any] = optimizer_info[1]
self._optimizer: Callable = optimizer_info[0]
self._kwargs: dict[str, T.Any] = optimizer_info[1]
self._configure(learning_rate, autoclip)
logger.verbose("Using %s optimizer", optimizer.title()) # type:ignore[attr-defined]
@ -411,7 +407,7 @@ class Settings():
return mixedprecision.LossScaleOptimizer(optimizer) # pylint:disable=no-member
@classmethod
def _set_tf_settings(cls, allow_growth: bool, exclude_devices: T.List[int]) -> None:
def _set_tf_settings(cls, allow_growth: bool, exclude_devices: list[int]) -> None:
""" Specify Devices to place operations on and Allow TensorFlow to manage VRAM growth.
Enables the Tensorflow allow_growth option if requested in the command line arguments
@ -480,8 +476,8 @@ class Settings():
return True
def _get_strategy(self,
strategy: Literal["default", "central-storage", "mirrored"]
) -> T.Optional[tf.distribute.Strategy]:
strategy: T.Literal["default", "central-storage", "mirrored"]
) -> tf.distribute.Strategy | None:
""" If we are running on Nvidia backend and the strategy is not ``None`` then return
the correct tensorflow distribution strategy, otherwise return ``None``.
@ -565,7 +561,7 @@ class Settings():
return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0")
def _get_mixed_precision_layers(self, layers: T.List[dict]) -> T.List[str]:
def _get_mixed_precision_layers(self, layers: list[dict]) -> list[str]:
""" Obtain the names of the layers in a mixed precision model that have their dtype policy
explicitly set to mixed-float16.
@ -595,7 +591,7 @@ class Settings():
logger.debug("Skipping unsupported layer: %s %s", layer["name"], dtype)
return retval
def _switch_precision(self, layers: T.List[dict], compatible: T.List[str]) -> None:
def _switch_precision(self, layers: list[dict], compatible: list[str]) -> None:
""" Switch a model's datatype between mixed-float16 and float32.
Parameters
@ -624,9 +620,9 @@ class Settings():
config["dtype"] = policy
def get_mixed_precision_layers(self,
build_func: T.Callable[[T.List[tf.keras.layers.Layer]],
build_func: Callable[[list[tf.keras.layers.Layer]],
tf.keras.models.Model],
inputs: T.List[tf.keras.layers.Layer]) -> T.List[str]:
inputs: list[tf.keras.layers.Layer]) -> list[str]:
""" Get and store the mixed precision layers from a full precision enabled model.
Parameters
@ -699,7 +695,7 @@ class Settings():
del model
return new_model
def strategy_scope(self) -> T.ContextManager:
def strategy_scope(self) -> ContextManager:
""" Return the strategy scope if we have set a strategy, otherwise return a null
context.

View file

@ -4,7 +4,6 @@
# pylint: disable=too-many-lines
from __future__ import annotations
import logging
import sys
import typing as T
from dataclasses import dataclass
@ -27,16 +26,10 @@ from lib.utils import get_tf_version, FaceswapError
from ._base import ModelBase, get_all_sub_models
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING:
from tensorflow import keras
from tensorflow import Tensor
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -65,14 +58,14 @@ class _EncoderInfo:
"""
keras_name: str
default_size: int
tf_min: T.Tuple[int, int] = (2, 0)
scaling: T.Tuple[int, int] = (0, 1)
tf_min: tuple[int, int] = (2, 0)
scaling: tuple[int, int] = (0, 1)
min_size: int = 32
enforce_for_weights: bool = False
color_order: Literal["bgr", "rgb"] = "rgb"
color_order: T.Literal["bgr", "rgb"] = "rgb"
_MODEL_MAPPING: T.Dict[str, _EncoderInfo] = {
_MODEL_MAPPING: dict[str, _EncoderInfo] = {
"densenet121": _EncoderInfo(
keras_name="DenseNet121", default_size=224),
"densenet169": _EncoderInfo(
@ -238,7 +231,7 @@ class Model(ModelBase):
model = new_model
return model
def _select_freeze_layers(self) -> T.List[str]:
def _select_freeze_layers(self) -> list[str]:
""" Process the selected frozen layers and replace the `keras_encoder` option with the
actual keras model name
@ -262,7 +255,7 @@ class Model(ModelBase):
logger.debug("Removing 'keras_encoder' for '%s'", arch)
return retval
def _get_input_shape(self) -> T.Tuple[int, int, int]:
def _get_input_shape(self) -> tuple[int, int, int]:
""" Obtain the input shape for the model.
Input shape is calculated from the selected Encoder's input size, scaled to the user
@ -316,7 +309,7 @@ class Model(ModelBase):
f"minimum version required is {tf_min} whilst you have version "
f"{tf_ver} installed.")
def build_model(self, inputs: T.List[Tensor]) -> keras.models.Model:
def build_model(self, inputs: list[Tensor]) -> keras.models.Model:
""" Create the model's structure.
Parameters
@ -341,7 +334,7 @@ class Model(ModelBase):
autoencoder = KModel(inputs, outputs, name=self.model_name)
return autoencoder
def _build_encoders(self, inputs: T.List[Tensor]) -> T.Dict[str, keras.models.Model]:
def _build_encoders(self, inputs: list[Tensor]) -> dict[str, keras.models.Model]:
""" Build the encoders for Phaze-A
Parameters
@ -362,7 +355,7 @@ class Model(ModelBase):
def _build_fully_connected(
self,
inputs: T.Dict[str, keras.models.Model]) -> T.Dict[str, T.List[keras.models.Model]]:
inputs: dict[str, keras.models.Model]) -> dict[str, list[keras.models.Model]]:
""" Build the fully connected layers for Phaze-A
Parameters
@ -407,8 +400,8 @@ class Model(ModelBase):
def _build_g_blocks(
self,
inputs: T.Dict[str, T.List[keras.models.Model]]
) -> T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]:
inputs: dict[str, list[keras.models.Model]]
) -> dict[str, list[keras.models.Model] | keras.models.Model]:
""" Build the g-block layers for Phaze-A.
If a g-block has not been selected for this model, then the original `inters` models are
@ -440,10 +433,9 @@ class Model(ModelBase):
logger.debug("G-Blocks: %s", retval)
return retval
def _build_decoders(
self,
inputs: T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]
) -> T.Dict[str, keras.models.Model]:
def _build_decoders(self,
inputs: dict[str, list[keras.models.Model] | keras.models.Model]
) -> dict[str, keras.models.Model]:
""" Build the decoders for Phaze-A
Parameters
@ -519,12 +511,12 @@ def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str)
return var_x
def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny", "upscale_fast",
"upscale_hybrid", "upsample2d"],
def _get_upscale_layer(method: T.Literal["resize_images", "subpixel", "upscale_dny",
"upscale_fast", "upscale_hybrid", "upsample2d"],
filters: int,
activation: T.Optional[str] = None,
upsamples: T.Optional[int] = None,
interpolation: T.Optional[str] = None) -> keras.layers.Layer:
activation: str | None = None,
upsamples: int | None = None,
interpolation: str | None = None) -> keras.layers.Layer:
""" Obtain an instance of the requested upscale method.
Parameters
@ -550,7 +542,7 @@ def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny
The selected configured upscale layer
"""
if method == "upsample2d":
kwargs: T.Dict[str, T.Union[str, int]] = {}
kwargs: dict[str, str | int] = {}
if upsamples:
kwargs["size"] = upsamples
if interpolation:
@ -571,7 +563,7 @@ def _get_curve(start_y: int,
end_y: int,
num_points: int,
scale: float,
mode: Literal["full", "cap_max", "cap_min"] = "full") -> T.List[int]:
mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]:
""" Obtain a curve.
For the given start and end y values, return the y co-ordinates of a curve for the given
@ -660,13 +652,13 @@ class Encoder(): # pylint:disable=too-few-public-methods
config: dict
The model configuration options
"""
def __init__(self, input_shape: T.Tuple[int, ...], config: dict) -> None:
def __init__(self, input_shape: tuple[int, ...], config: dict) -> None:
self.input_shape = input_shape
self._config = config
self._input_shape = input_shape
@property
def _model_kwargs(self) -> T.Dict[str, T.Dict[str, T.Union[str, bool]]]:
def _model_kwargs(self) -> dict[str, dict[str, str | bool]]:
""" dict: Configuration option for architecture mapped to optional kwargs. """
return {"mobilenet": {"alpha": self._config["mobilenet_width"],
"depth_multiplier": self._config["mobilenet_depth"],
@ -677,7 +669,7 @@ class Encoder(): # pylint:disable=too-few-public-methods
"include_preprocessing": False}}
@property
def _selected_model(self) -> T.Tuple[_EncoderInfo, dict]:
def _selected_model(self) -> tuple[_EncoderInfo, dict]:
""" tuple(:class:`_EncoderInfo`, dict): The selected encoder model and its associated
keyword arguments """
arch = self._config["enc_architecture"]
@ -832,7 +824,7 @@ class FullyConnected(): # pylint:disable=too-few-public-methods
The user configuration dictionary
"""
def __init__(self,
side: Literal["a", "b", "both", "gblock", "shared"],
side: T.Literal["a", "b", "both", "gblock", "shared"],
input_shape: tuple,
config: dict) -> None:
logger.debug("Initializing: %s (side: %s, input_shape: %s)",
@ -992,12 +984,12 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods
and the Decoder. ``None`` will generate the full Upscale chain. An end index of -1 will
generate the layers from the starting index to the final upscale. Default: ``None``
"""
_filters: T.List[int] = []
_filters: list[int] = []
def __init__(self,
side: Literal["a", "b", "both", "shared"],
side: T.Literal["a", "b", "both", "shared"],
config: dict,
layer_indicies: T.Optional[T.Tuple[int, int]] = None) -> None:
layer_indicies: tuple[int, int] | None = None) -> None:
logger.debug("Initializing: %s (side: %s, layer_indicies: %s)",
self.__class__.__name__, side, layer_indicies)
self._side = side
@ -1126,7 +1118,7 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods
relu_alpha=0.2)(var_x)
return var_x
def __call__(self, inputs: T.Union[Tensor, T.List[Tensor]]) -> T.Union[Tensor, T.List[Tensor]]:
def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]:
""" Upscale Network.
Parameters
@ -1203,8 +1195,8 @@ class GBlock(): # pylint:disable=too-few-public-methods
The user configuration dictionary
"""
def __init__(self,
side: Literal["a", "b", "both"],
input_shapes: T.Union[list, tuple],
side: T.Literal["a", "b", "both"],
input_shapes: list | tuple,
config: dict) -> None:
logger.debug("Initializing: %s (side: %s, input_shapes: %s)",
self.__class__.__name__, side, input_shapes)
@ -1284,8 +1276,8 @@ class Decoder(): # pylint:disable=too-few-public-methods
The user configuration dictionary
"""
def __init__(self,
side: Literal["a", "b", "both"],
input_shape: T.Tuple[int, int, int],
side: T.Literal["a", "b", "both"],
input_shape: tuple[int, int, int],
config: dict) -> None:
logger.debug("Initializing: %s (side: %s, input_shape: %s)",
self.__class__.__name__, side, input_shape)

View file

@ -39,8 +39,6 @@
" the value saved in the state file with the updated value in config. If not
" provided this will default to True.
"""
from typing import List
_HELPTEXT: str = (
"Phaze-A Model by TorzDF, with thanks to BirbFakes.\n"
@ -48,7 +46,7 @@ _HELPTEXT: str = (
"inspiration from Nvidia's StyleGAN for the Decoder. It is highly recommended to research to "
"understand the parameters better.")
_ENCODERS: List[str] = sorted([
_ENCODERS: list[str] = sorted([
"densenet121", "densenet169", "densenet201", "efficientnet_b0", "efficientnet_b1",
"efficientnet_b2", "efficientnet_b3", "efficientnet_b4", "efficientnet_b5", "efficientnet_b6",
"efficientnet_b7", "efficientnet_v2_b0", "efficientnet_v2_b1", "efficientnet_v2_b2",

View file

@ -9,7 +9,6 @@ with "original" unique code split out to the original plugin.
from __future__ import annotations
import logging
import os
import sys
import time
import typing as T
@ -23,23 +22,19 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
from lib.image import hex_to_rgb
from lib.training import PreviewDataGenerator, TrainingDataGenerator
from lib.training.generator import BatchType, DataGenerator
from lib.utils import FaceswapError, get_folder, get_image_paths, get_tf_version
from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.train._config import Config
if T.TYPE_CHECKING:
from collections.abc import Callable, Generator
from plugins.train.model._base import ModelBase
from lib.config import ConfigValueType
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def _get_config(plugin_name: str,
configfile: T.Optional[str] = None) -> T.Dict[str, ConfigValueType]:
configfile: str | None = None) -> dict[str, ConfigValueType]:
""" Return the configuration for the requested trainer.
Parameters
@ -80,9 +75,9 @@ class TrainerBase():
def __init__(self,
model: ModelBase,
images: T.Dict[Literal["a", "b"], T.List[str]],
images: dict[T.Literal["a", "b"], list[str]],
batch_size: int,
configfile: T.Optional[str]) -> None:
configfile: str | None) -> None:
logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
self.__class__.__name__, model, batch_size)
self._model = model
@ -111,7 +106,7 @@ class TrainerBase():
self._images)
logger.debug("Initialized %s", self.__class__.__name__)
def _get_config(self, configfile: T.Optional[str]) -> T.Dict[str, ConfigValueType]:
def _get_config(self, configfile: str | None) -> dict[str, ConfigValueType]:
""" Get the saved training config options. Override any global settings with the setting
provided from the model's saved config.
@ -173,10 +168,9 @@ class TrainerBase():
self._samples.toggle_mask_display()
def train_one_step(self,
viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
"input_b",
"output"], str]]) -> None:
viewer: Callable[[np.ndarray, str], None] | None,
timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
str] | None) -> None:
""" Running training on a batch of images for each side.
Triggered from the training cycle in :class:`scripts.train.Train`.
@ -217,7 +211,7 @@ class TrainerBase():
model_inputs, model_targets = self._feeder.get_batch()
try:
loss: T.List[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
except tf_errors.ResourceExhaustedError as err:
msg = ("You do not have enough GPU memory available to train the selected model at "
"the selected settings. You can try a number of things:"
@ -236,7 +230,7 @@ class TrainerBase():
self._model.snapshot()
self._update_viewers(viewer, timelapse_kwargs)
def _log_tensorboard(self, loss: T.List[float]) -> None:
def _log_tensorboard(self, loss: list[float]) -> None:
""" Log current loss to Tensorboard log files
Parameters
@ -250,19 +244,18 @@ class TrainerBase():
logs = {log[0]: log[1]
for log in zip(self._model.state.loss_names, loss)}
if get_tf_version() > (2, 7):
# Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
# ref: https://github.com/keras-team/keras/issues/16173
with tf.summary.record_if(True), self._tensorboard._train_writer.as_default(): # noqa pylint:disable=protected-access,not-context-manager
with tf.summary.record_if(True), self._tensorboard._train_writer.as_default(): # noqa:E501 pylint:disable=protected-access,not-context-manager
for name, value in logs.items():
tf.summary.scalar(
"batch_" + name,
value,
step=self._tensorboard._train_step) # pylint:disable=protected-access
else:
self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
# TODO revert this code if fixed in tensorflow
# self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
def _collate_and_store_loss(self, loss: T.List[float]) -> T.List[float]:
def _collate_and_store_loss(self, loss: list[float]) -> list[float]:
""" Collate the loss into totals for each side.
The losses are summed into a total for each side. Loss totals are added to
@ -298,7 +291,7 @@ class TrainerBase():
logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore
return combined_loss
def _print_loss(self, loss: T.List[float]) -> None:
def _print_loss(self, loss: list[float]) -> None:
""" Outputs the loss for the current iteration to the console.
Parameters
@ -318,10 +311,9 @@ class TrainerBase():
"line: %s, error: %s", output, str(err))
def _update_viewers(self,
viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
"input_b",
"output"], str]]) -> None:
viewer: Callable[[np.ndarray, str], None] | None,
timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
str] | None) -> None:
""" Update the preview viewer and timelapse output
Parameters
@ -371,10 +363,10 @@ class _Feeder():
The configuration for this trainer
"""
def __init__(self,
images: T.Dict[Literal["a", "b"], T.List[str]],
images: dict[T.Literal["a", "b"], list[str]],
model: ModelBase,
batch_size: int,
config: T.Dict[str, ConfigValueType]) -> None:
config: dict[str, ConfigValueType]) -> None:
logger.debug("Initializing %s: num_images: %s, batch_size: %s, config: %s)",
self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size,
config)
@ -383,16 +375,16 @@ class _Feeder():
self._batch_size = batch_size
self._config = config
self._feeds = {side: self._load_generator(side, False).minibatch_ab()
for side in get_args(Literal["a", "b"])}
for side in T.get_args(T.Literal["a", "b"])}
self._display_feeds = {"preview": self._set_preview_feed(), "timelapse": {}}
logger.debug("Initialized %s:", self.__class__.__name__)
def _load_generator(self,
side: Literal["a", "b"],
side: T.Literal["a", "b"],
is_display: bool,
batch_size: T.Optional[int] = None,
images: T.Optional[T.List[str]] = None) -> DataGenerator:
batch_size: int | None = None,
images: list[str] | None = None) -> DataGenerator:
""" Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder.
Parameters
@ -424,7 +416,7 @@ class _Feeder():
self._batch_size if batch_size is None else batch_size)
return retval
def _set_preview_feed(self) -> T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]]:
def _set_preview_feed(self) -> dict[T.Literal["a", "b"], Generator[BatchType, None, None]]:
""" Set the preview feed for this feeder.
Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically
@ -436,10 +428,10 @@ class _Feeder():
The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as
value.
"""
retval: T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]] = {}
retval: dict[T.Literal["a", "b"], Generator[BatchType, None, None]] = {}
num_images = self._config.get("preview_images", 14)
assert isinstance(num_images, int)
for side in get_args(Literal["a", "b"]):
for side in T.get_args(T.Literal["a", "b"]):
logger.debug("Setting preview feed: (side: '%s')", side)
preview_images = min(max(num_images, 2), 16)
batchsize = min(len(self._images[side]), preview_images)
@ -448,7 +440,7 @@ class _Feeder():
batch_size=batchsize).minibatch_ab()
return retval
def get_batch(self) -> T.Tuple[T.List[T.List[np.ndarray]], ...]:
def get_batch(self) -> tuple[list[list[np.ndarray]], ...]:
""" Get the feed data and the targets for each training side for feeding into the model's
train function.
@ -459,8 +451,8 @@ class _Feeder():
model_targets: list
The targets for the model for each side A and B
"""
model_inputs: T.List[T.List[np.ndarray]] = []
model_targets: T.List[T.List[np.ndarray]] = []
model_inputs: list[list[np.ndarray]] = []
model_targets: list[list[np.ndarray]] = []
for side in ("a", "b"):
side_feed, side_targets = next(self._feeds[side])
if self._model.config["learn_mask"]: # Add the face mask as its own target
@ -473,7 +465,7 @@ class _Feeder():
return model_inputs, model_targets
def generate_preview(self, is_timelapse: bool = False
) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
""" Generate the images for preview window or timelapse
Parameters
@ -490,15 +482,15 @@ class _Feeder():
"""
logger.debug("Generating preview (is_timelapse: %s)", is_timelapse)
batchsizes: T.List[int] = []
feed: T.Dict[Literal["a", "b"], np.ndarray] = {}
samples: T.Dict[Literal["a", "b"], np.ndarray] = {}
masks: T.Dict[Literal["a", "b"], np.ndarray] = {}
batchsizes: list[int] = []
feed: dict[T.Literal["a", "b"], np.ndarray] = {}
samples: dict[T.Literal["a", "b"], np.ndarray] = {}
masks: dict[T.Literal["a", "b"], np.ndarray] = {}
# MyPy can't recurse into nested dicts to get the type :(
iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
self._display_feeds["timelapse" if is_timelapse else "preview"])
for side in get_args(Literal["a", "b"]):
for side in T.get_args(T.Literal["a", "b"]):
side_feed, side_samples = next(iterator[side])
batchsizes.append(len(side_samples[0]))
samples[side] = side_samples[0]
@ -513,10 +505,10 @@ class _Feeder():
def compile_sample(self,
image_count: int,
feed: T.Dict[Literal["a", "b"], np.ndarray],
samples: T.Dict[Literal["a", "b"], np.ndarray],
masks: T.Dict[Literal["a", "b"], np.ndarray]
) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
feed: dict[T.Literal["a", "b"], np.ndarray],
samples: dict[T.Literal["a", "b"], np.ndarray],
masks: dict[T.Literal["a", "b"], np.ndarray]
) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
""" Compile the preview samples for display.
Parameters
@ -542,8 +534,8 @@ class _Feeder():
num_images = self._config.get("preview_images", 14)
assert isinstance(num_images, int)
num_images = min(image_count, num_images)
retval: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
for side in get_args(Literal["a", "b"]):
retval: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
for side in T.get_args(T.Literal["a", "b"]):
logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images)
retval[side] = [feed[side][0:num_images],
samples[side][0:num_images],
@ -552,7 +544,7 @@ class _Feeder():
return retval
def set_timelapse_feed(self,
images: T.Dict[Literal["a", "b"], T.List[str]],
images: dict[T.Literal["a", "b"], list[str]],
batch_size: int) -> None:
""" Set the time-lapse feed for this feeder.
@ -570,10 +562,10 @@ class _Feeder():
images, batch_size)
# MyPy can't recurse into nested dicts to get the type :(
iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
self._display_feeds["timelapse"])
for side in get_args(Literal["a", "b"]):
for side in T.get_args(T.Literal["a", "b"]):
imgs = images[side]
logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs))
@ -615,7 +607,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
self.__class__.__name__, model, coverage_ratio, mask_opacity, mask_color)
self._model = model
self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"]
self.images: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
self.images: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
self._coverage_ratio = coverage_ratio
self._mask_opacity = mask_opacity / 100.0
self._mask_color = np.array(hex_to_rgb(mask_color))[..., 2::-1] / 255.
@ -639,8 +631,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
A compiled preview image ready for display or saving
"""
logger.debug("Showing sample")
feeds: T.Dict[Literal["a", "b"], np.ndarray] = {}
for idx, side in enumerate(get_args(Literal["a", "b"])):
feeds: dict[T.Literal["a", "b"], np.ndarray] = {}
for idx, side in enumerate(T.get_args(T.Literal["a", "b"])):
feed = self.images[side][0]
input_shape = self._model.model.input_shape[idx][1:]
if input_shape[0] / feed.shape[1] != 1.0:
@ -653,7 +645,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
@classmethod
def _resize_sample(cls,
side: Literal["a", "b"],
side: T.Literal["a", "b"],
sample: np.ndarray,
target_size: int) -> np.ndarray:
""" Resize a given image to the target size.
@ -684,7 +676,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
return retval
def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> T.Dict[str, np.ndarray]:
def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> dict[str, np.ndarray]:
""" Feed the samples to the model and return predictions
Parameters
@ -700,7 +692,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` of predictions received from the model
"""
logger.debug("Getting Predictions")
preds: T.Dict[str, np.ndarray] = {}
preds: dict[str, np.ndarray] = {}
standard = self._model.model.predict([feed_a, feed_b], verbose=0)
swapped = self._model.model.predict([feed_b, feed_a], verbose=0)
@ -719,7 +711,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
return preds
def _compile_preview(self, predictions: T.Dict[str, np.ndarray]) -> np.ndarray:
def _compile_preview(self, predictions: dict[str, np.ndarray]) -> np.ndarray:
""" Compile predictions and images into the final preview image.
Parameters
@ -732,8 +724,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
:class:`numpy.ndarray`
A compiled preview image ready for display or saving
"""
figures: T.Dict[Literal["a", "b"], np.ndarray] = {}
headers: T.Dict[Literal["a", "b"], np.ndarray] = {}
figures: dict[T.Literal["a", "b"], np.ndarray] = {}
headers: dict[T.Literal["a", "b"], np.ndarray] = {}
for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
@ -761,9 +753,9 @@ class _Samples(): # pylint:disable=too-few-public-methods
return np.clip(figure * 255, 0, 255).astype('uint8')
def _to_full_frame(self,
side: Literal["a", "b"],
samples: T.List[np.ndarray],
predictions: T.List[np.ndarray]) -> T.List[np.ndarray]:
side: T.Literal["a", "b"],
samples: list[np.ndarray],
predictions: list[np.ndarray]) -> list[np.ndarray]:
""" Patch targets and prediction images into images of model output size.
Parameters
@ -803,10 +795,10 @@ class _Samples(): # pylint:disable=too-few-public-methods
return images
def _process_full(self,
side: Literal["a", "b"],
side: T.Literal["a", "b"],
images: np.ndarray,
prediction_size: int,
color: T.Tuple[float, float, float]) -> np.ndarray:
color: tuple[float, float, float]) -> np.ndarray:
""" Add a frame overlay to preview images indicating the region of interest.
This applies the red border that appears in the preview images.
@ -847,7 +839,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Overlayed background. Shape: %s", images.shape)
return images
def _compile_masked(self, faces: T.List[np.ndarray], masks: np.ndarray) -> T.List[np.ndarray]:
def _compile_masked(self, faces: list[np.ndarray], masks: np.ndarray) -> list[np.ndarray]:
""" Add the mask to the faces for masked preview.
Places an opaque red layer over areas of the face that are masked out.
@ -866,7 +858,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` faces with the opaque mask layer applied
"""
orig_masks = 1 - np.rint(masks)
masks3: T.Union[T.List[np.ndarray], np.ndarray] = []
masks3: list[np.ndarray] | np.ndarray = []
if faces[-1].shape[-1] == 4: # Mask contained in alpha channel of predictions
pred_masks = [1 - np.rint(face[..., -1])[..., None] for face in faces[-2:]]
@ -875,7 +867,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
else:
masks3 = np.repeat(np.expand_dims(orig_masks, axis=0), 3, axis=0)
retval: T.List[np.ndarray] = []
retval: list[np.ndarray] = []
alpha = 1.0 - self._mask_opacity
for previews, compiled_masks in zip(faces, masks3):
overlays = previews.copy()
@ -910,7 +902,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
return backgrounds
@classmethod
def _get_headers(cls, side: Literal["a", "b"], width: int) -> np.ndarray:
def _get_headers(cls, side: T.Literal["a", "b"], width: int) -> np.ndarray:
""" Set header row for the final preview frame
Parameters
@ -958,8 +950,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
@classmethod
def _duplicate_headers(cls,
headers: T.Dict[Literal["a", "b"], np.ndarray],
columns: int) -> T.Dict[Literal["a", "b"], np.ndarray]:
headers: dict[T.Literal["a", "b"], np.ndarray],
columns: int) -> dict[T.Literal["a", "b"], np.ndarray]:
""" Duplicate headers for the number of columns displayed for each side.
Parameters
@ -1008,7 +1000,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
mask_opacity: int,
mask_color: str,
feeder: _Feeder,
image_paths: T.Dict[Literal["a", "b"], T.List[str]]) -> None:
image_paths: dict[T.Literal["a", "b"], list[str]]) -> None:
logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, "
"mask_opacity: %s, mask_color: %s, feeder: %s, image_paths: %s)",
self.__class__.__name__, model, coverage_ratio, image_count, mask_opacity,
@ -1042,8 +1034,8 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
logger.debug("Time-lapse output set to '%s'", self._output_file)
# Rewrite paths to pull from the training images so mask and face data can be accessed
images: T.Dict[Literal["a", "b"], T.List[str]] = {}
for side, input_ in zip(get_args(Literal["a", "b"]), (input_a, input_b)):
images: dict[T.Literal["a", "b"], list[str]] = {}
for side, input_ in zip(T.get_args(T.Literal["a", "b"]), (input_a, input_b)):
training_path = os.path.dirname(self._image_paths[side][0])
images[side] = [os.path.join(training_path, os.path.basename(pth))
for pth in get_image_paths(input_)]
@ -1054,7 +1046,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
self._feeder.set_timelapse_feed(images, batchsize)
logger.debug("Set up time-lapse")
def output_timelapse(self, timelapse_kwargs: T.Dict[Literal["input_a",
def output_timelapse(self, timelapse_kwargs: dict[T.Literal["input_a",
"input_b",
"output"], str]) -> None:
""" Generate the time-lapse samples and output the created time-lapse to the specified
@ -1068,7 +1060,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
"""
logger.debug("Ouputting time-lapse")
if not self._output_file:
self._setup(**T.cast(T.Dict[str, str], timelapse_kwargs))
self._setup(**T.cast(dict[str, str], timelapse_kwargs))
logger.debug("Getting time-lapse samples")
self._samples.images = self._feeder.generate_preview(is_timelapse=True)

View file

@ -1,19 +1,14 @@
tqdm>=4.64
# TESTED WITH PY3.10
tqdm>=4.65
psutil>=5.9.0
numexpr>=2.7.3; python_version < '3.9' # >=2.8.0 conflicts in Conda
numexpr>=2.8.3; python_version >= '3.9'
opencv-python>=4.6.0.0
pillow>=9.2.0
scikit-learn==1.0.2; python_version < '3.9' # AMD needs version 1.0.2 and 1.1.0 not available in Python 3.7
scikit-learn>=1.1.0; python_version >= '3.9'
numexpr>=2.8.4
numpy>=1.25.0
opencv-python>=4.7.0.0
pillow>=9.4.0
scikit-learn>=1.2.2
fastcluster>=1.2.6
matplotlib>=3.4.3,<3.6.0; python_version < '3.9' # >=3.5.0 conflicts in Conda
matplotlib>=3.5.1,<3.6.0; python_version >= '3.9'
imageio>=2.19.3
imageio-ffmpeg>=0.4.7
matplotlib>=3.7.1
imageio>=2.26.0
imageio-ffmpeg>=0.4.8
ffmpy>=0.3.0
# Exclude badly numbered Python2 version of nvidia-ml-py
nvidia-ml-py>=11.515,<300
tensorflow-probability<0.17
typing-extensions>=4.0.0
pywin32>=228 ; sys_platform == "win32"

View file

@ -1,11 +1,7 @@
protobuf>= 3.19.0,<3.20.0 # TF has started pulling in incompatible protobuf
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-macos>=2.8.0,<2.11.0
tensorflow-deps>=2.8.0,<2.11.0
tensorflow-metal>=0.4.0,<0.7.0
libblas # Conda only
-r _requirements_base.txt
tensorflow-macos>=2.10.0,<2.11.0
tensorflow-deps>=2.10.0,<2.11.0
tensorflow-metal>=0.6.0,<0.7.0
# These next 2 should have been installed, but some users complain of errors
decorator
cloudpickle

View file

@ -1,5 +1,2 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.7.0,<2.11.0
tensorflow-cpu>=2.10.0,<2.11.0

View file

@ -1,7 +1,4 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.10.0,<2.11.0
tensorflow-directml-plugin
comtypes

View file

@ -1,6 +1,5 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-gpu>=2.7.0,<2.11.0
# Exclude badly numbered Python2 version of nvidia-ml-py
nvidia-ml-py>=11.525,<300
pynvx==1.0.0 ; sys_platform == "darwin"
tensorflow>=2.10.0,<2.11.0

View file

@ -1,5 +1,2 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-rocm>=2.10.0,<2.11.0

View file

@ -26,13 +26,9 @@ from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.extract.pipeline import Extractor, ExtractMedia
from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if T.TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Callable
from plugins.convert.writer._base import Output
from plugins.train.model._base import ModelBase
from lib.align.aligned_face import CenteringType
@ -61,8 +57,8 @@ class ConvertItem:
The swapped faces returned from the model's predict function
"""
inbound: ExtractMedia
feed_faces: T.List[AlignedFace] = field(default_factory=list)
reference_faces: T.List[AlignedFace] = field(default_factory=list)
feed_faces: list[AlignedFace] = field(default_factory=list)
reference_faces: list[AlignedFace] = field(default_factory=list)
swapped_faces: np.ndarray = np.array([])
@ -307,8 +303,8 @@ class DiskIO():
# Extractor for on the fly detection
self._extractor = self._load_extractor()
self._queues: T.Dict[Literal["load", "save"], EventQueue] = {}
self._threads: T.Dict[Literal["load", "save"], MultiThread] = {}
self._queues: dict[T.Literal["load", "save"], EventQueue] = {}
self._threads: dict[T.Literal["load", "save"], MultiThread] = {}
self._init_threads()
logger.debug("Initialized %s", self.__class__.__name__)
@ -324,13 +320,13 @@ class DiskIO():
return self._writer.config.get("draw_transparent", False)
@property
def pre_encode(self) -> T.Optional[T.Callable[[np.ndarray], T.List[bytes]]]:
def pre_encode(self) -> Callable[[np.ndarray], list[bytes]] | None:
""" python function: Selected writer's pre-encode function, if it has one,
otherwise ``None`` """
dummy = np.zeros((20, 20, 3), dtype="uint8")
test = self._writer.pre_encode(dummy)
retval: T.Optional[T.Callable[[np.ndarray],
T.List[bytes]]] = None if test is None else self._writer.pre_encode
retval: Callable[[np.ndarray],
list[bytes]] | None = None if test is None else self._writer.pre_encode
logger.debug("Writer pre_encode function: %s", retval)
return retval
@ -384,7 +380,7 @@ class DiskIO():
return PluginLoader.get_converter("writer", self._args.writer)(*args,
configfile=configfile)
def _get_frame_ranges(self) -> T.Optional[T.List[T.Tuple[int, int]]]:
def _get_frame_ranges(self) -> list[tuple[int, int]] | None:
""" Obtain the frame ranges that are to be converted.
If frame ranges have been specified, then split the command line formatted arguments into
@ -422,7 +418,7 @@ class DiskIO():
logger.debug("frame ranges: %s", retval)
return retval
def _load_extractor(self) -> T.Optional[Extractor]:
def _load_extractor(self) -> Extractor | None:
""" Load the CV2-DNN Face Extractor Chain.
For On-The-Fly conversion we use a CPU based extractor to avoid stacking the GPU.
@ -467,12 +463,12 @@ class DiskIO():
Creates the load and save queues and the load and save threads. Starts the threads.
"""
logger.debug("Initializing DiskIO Threads")
for task in get_args(Literal["load", "save"]):
for task in T.get_args(T.Literal["load", "save"]):
self._add_queue(task)
self._start_thread(task)
logger.debug("Initialized DiskIO Threads")
def _add_queue(self, task: Literal["load", "save"]) -> None:
def _add_queue(self, task: T.Literal["load", "save"]) -> None:
""" Add the queue to queue_manager and to :attr:`self._queues` for the given task.
Parameters
@ -490,7 +486,7 @@ class DiskIO():
self._queues[task] = queue_manager.get_queue(q_name)
logger.debug("Added queue for task: '%s'", task)
def _start_thread(self, task: Literal["load", "save"]) -> None:
def _start_thread(self, task: T.Literal["load", "save"]) -> None:
""" Create the thread for the given task, add it to :attr:`self._threads` and start it.
Parameters
@ -571,7 +567,7 @@ class DiskIO():
logger.trace("idx: %s, skipframe: %s", idx, skipframe) # type: ignore
return skipframe
def _get_detected_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
def _get_detected_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
""" Return the detected faces for the given image.
If we have an alignments file, then the detected faces are created from that file. If
@ -597,7 +593,7 @@ class DiskIO():
logger.trace("Got %s faces for: '%s'", len(detected_faces), filename) # type:ignore
return detected_faces
def _alignments_faces(self, frame_name: str, image: np.ndarray) -> T.List[DetectedFace]:
def _alignments_faces(self, frame_name: str, image: np.ndarray) -> list[DetectedFace]:
""" Return detected faces from an alignments file.
Parameters
@ -644,7 +640,7 @@ class DiskIO():
tqdm.write(f"No alignment found for {frame_name}, skipping")
return have_alignments
def _detect_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
def _detect_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
""" Extract the face from a frame for On-The-Fly conversion.
Pulls detected faces out of the Extraction pipeline.
@ -779,7 +775,7 @@ class Predict():
""" int: The size in pixels of the Faceswap model output. """
return self._sizes["output"]
def _get_io_sizes(self) -> T.Dict[str, int]:
def _get_io_sizes(self) -> dict[str, int]:
""" Obtain the input size and output size of the model.
Returns
@ -896,9 +892,9 @@ class Predict():
"""
faces_seen = 0
consecutive_no_faces = 0
batch: T.List[ConvertItem] = []
batch: list[ConvertItem] = []
while True:
item: T.Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
item: T.Literal["EOF"] | ConvertItem = self._in_queue.get()
if item == "EOF":
logger.debug("EOF Received")
if batch: # Process out any remaining items
@ -938,7 +934,7 @@ class Predict():
self._out_queue.put("EOF")
logger.debug("Load queue complete")
def _process_batch(self, batch: T.List[ConvertItem], faces_seen: int):
def _process_batch(self, batch: list[ConvertItem], faces_seen: int):
""" Predict faces on the given batch of images and queue out to patch thread
Parameters
@ -1001,7 +997,7 @@ class Predict():
logger.trace("Loaded aligned faces: '%s'", item.inbound.filename) # type:ignore
@staticmethod
def _compile_feed_faces(feed_faces: T.List[AlignedFace]) -> np.ndarray:
def _compile_feed_faces(feed_faces: list[AlignedFace]) -> np.ndarray:
""" Compile a batch of faces for feeding into the Predictor.
Parameters
@ -1020,7 +1016,7 @@ class Predict():
logger.trace("Compiled Feed faces. Shape: %s", retval.shape) # type:ignore
return retval
def _predict(self, feed_faces: np.ndarray, batch_size: T.Optional[int] = None) -> np.ndarray:
def _predict(self, feed_faces: np.ndarray, batch_size: int | None = None) -> np.ndarray:
""" Run the Faceswap models' prediction function.
Parameters
@ -1045,7 +1041,7 @@ class Predict():
logger.trace("Input shape(s): %s", [item.shape for item in feed]) # type:ignore
inbound = self._model.model.predict(feed, verbose=0, batch_size=batch_size)
predicted: T.List[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]
predicted: list[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]
if self._model.color_order.lower() == "rgb":
predicted[0] = predicted[0][..., ::-1]
@ -1062,7 +1058,7 @@ class Predict():
logger.trace("Final shape: %s", retval.shape) # type:ignore
return retval
def _queue_out_frames(self, batch: T.List[ConvertItem], swapped_faces: np.ndarray) -> None:
def _queue_out_frames(self, batch: list[ConvertItem], swapped_faces: np.ndarray) -> None:
""" Compile the batch back to original frames and put to the Out Queue.
For batching, faces are split away from their frames. This compiles all detected faces
@ -1108,7 +1104,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
arguments: Namespace,
input_images: T.List[np.ndarray],
input_images: list[np.ndarray],
alignments: Alignments) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._args = arguments
@ -1131,7 +1127,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
self._alignments.filter_faces(accept_dict, filter_out=False)
logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count)
def _get_face_metadata(self) -> T.Dict[str, T.List[int]]:
def _get_face_metadata(self) -> dict[str, list[int]]:
""" Check for the existence of an aligned directory for identifying which faces in the
target frames should be swapped. If it exists, scan the folder for face's metadata
@ -1140,7 +1136,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
dict
Dictionary of source frame names with a list of associated face indices to be skipped
"""
retval: T.Dict[str, T.List[int]] = {}
retval: dict[str, list[int]] = {}
input_aligned_dir = self._args.input_aligned_dir
if input_aligned_dir is None:

View file

@ -2,13 +2,13 @@
""" Main entry point to the extract process of FaceSwap """
from __future__ import annotations
import logging
import os
import sys
import typing as T
from argparse import Namespace
from multiprocessing import Process
from typing import List, Dict, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np
from tqdm import tqdm
@ -20,7 +20,7 @@ from lib.utils import get_folder, _image_extensions, _video_extensions
from plugins.extract.pipeline import Extractor, ExtractMedia
from scripts.fsmedia import Alignments, PostProcess, finalize
if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderAlignmentsDict
# tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning # TODO?
@ -75,7 +75,7 @@ class Extract(): # pylint:disable=too-few-public-methods
self._args.nfilter,
self._extractor)
def _get_input_locations(self) -> List[str]:
def _get_input_locations(self) -> list[str]:
""" Obtain the full path to input locations. Will be a list of locations if batch mode is
selected, or a list containing a single location if batch mode is not selected.
@ -194,8 +194,8 @@ class Filter():
"""
def __init__(self,
threshold: float,
filter_files: Optional[List[str]],
nfilter_files: Optional[List[str]],
filter_files: list[str] | None,
nfilter_files: list[str] | None,
extractor: Extractor) -> None:
logger.debug("Initializing %s: (threshold: %s, filter_files: %s, nfilter_files: %s "
"extractor: %s)", self.__class__.__name__, threshold, filter_files,
@ -208,8 +208,8 @@ class Filter():
logger.debug("Filter not selected. Exiting %s", self.__class__.__name__)
return
self._embeddings: List[np.ndarray] = [np.array([]) for _ in self._filter_files]
self._nembeddings: List[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
self._embeddings: list[np.ndarray] = [np.array([]) for _ in self._filter_files]
self._nembeddings: list[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
self._extractor = extractor
self._get_embeddings()
@ -243,7 +243,7 @@ class Filter():
return retval
@classmethod
def _files_from_folder(cls, input_location: List[str]) -> List[str]:
def _files_from_folder(cls, input_location: list[str]) -> list[str]:
""" Test whether the input location is a folder and if so, return the list of contained
image files, otherwise return the original input location
@ -274,8 +274,8 @@ class Filter():
return retval
def _validate_inputs(self,
filter_files: Optional[List[str]],
nfilter_files: Optional[List[str]]) -> Tuple[List[str], List[str]]:
filter_files: list[str] | None,
nfilter_files: list[str] | None) -> tuple[list[str], list[str]]:
""" Validates that the given filter/nfilter files exist, are image files and are unique
Parameters
@ -293,7 +293,7 @@ class Filter():
List of full paths to nfilter files
"""
error = False
retval: List[List[str]] = []
retval: list[list[str]] = []
for files in (filter_files, nfilter_files):
filt_files = [] if files is None else self._files_from_folder(files)
@ -322,7 +322,7 @@ class Filter():
return filters, nfilters
@classmethod
def _identity_from_extracted(cls, filename) -> Tuple[np.ndarray, bool]:
def _identity_from_extracted(cls, filename) -> tuple[np.ndarray, bool]:
""" Test whether the given image is a faceswap extracted face and contains identity
information. If so, return the identity embedding
@ -404,7 +404,7 @@ class Filter():
embeddings[idx] = identities
return
def _identity_from_extractor(self, file_list: List[str], aligned: List[str]) -> None:
def _identity_from_extractor(self, file_list: list[str], aligned: list[str]) -> None:
""" Obtain the identity embeddings from the extraction pipeline
Parameters
@ -425,7 +425,7 @@ class Filter():
for phase in range(self._extractor.passes):
is_final = self._extractor.final_pass
detected_faces: Dict[str, ExtractMedia] = {}
detected_faces: dict[str, ExtractMedia] = {}
self._extractor.launch()
desc = "Obtaining reference face Identity"
if self._extractor.passes > 1:
@ -450,8 +450,8 @@ class Filter():
def _get_embeddings(self) -> None:
""" Obtain the embeddings for the given filter lists """
needs_extraction: List[str] = []
aligned: List[str] = []
needs_extraction: list[str] = []
aligned: list[str] = []
for files, embed in zip((self._filter_files, self._nfilter_files),
(self._embeddings, self._nembeddings)):
@ -494,14 +494,14 @@ class PipelineLoader():
image files that exist in :attr:`path` that are aligned faceswap images
"""
def __init__(self,
path: Union[str, List[str]],
path: str | list[str],
extractor: Extractor,
aligned_filenames: Optional[List[str]] = None) -> None:
aligned_filenames: list[str] | None = None) -> None:
logger.debug("Initializing %s: (path: %s, extractor: %s, aligned_filenames: %s)",
self.__class__.__name__, path, extractor, aligned_filenames)
self._images = ImagesLoader(path, fast_count=True)
self._extractor = extractor
self._threads: List[MultiThread] = []
self._threads: list[MultiThread] = []
self._aligned_filenames = [] if aligned_filenames is None else aligned_filenames
logger.debug("Initialized %s", self.__class__.__name__)
@ -512,7 +512,7 @@ class PipelineLoader():
return self._images.is_video
@property
def file_list(self) -> List[str]:
def file_list(self) -> list[str]:
""" list: A full list of files in the source location. If the input is a video
then this is a list of dummy filenames corresponding to an alignments file """
return self._images.file_list
@ -523,7 +523,7 @@ class PipelineLoader():
items that are to be skipped from the :attr:`skip_list`)"""
return self._images.process_count
def add_skip_list(self, skip_list: List[int]) -> None:
def add_skip_list(self, skip_list: list[int]) -> None:
""" Add a skip list to the :class:`ImagesLoader`
Parameters
@ -538,7 +538,7 @@ class PipelineLoader():
""" Launch the image loading pipeline """
self._threaded_redirector("load")
def reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
def reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
""" Reload images for multiple pipeline passes """
self._threaded_redirector("reload", (detected_faces, ))
@ -552,7 +552,7 @@ class PipelineLoader():
for thread in self._threads:
thread.join()
def _threaded_redirector(self, task: str, io_args: Optional[tuple] = None) -> None:
def _threaded_redirector(self, task: str, io_args: tuple | None = None) -> None:
""" Redirect image input/output tasks to relevant queues in background thread
Parameters
@ -587,7 +587,7 @@ class PipelineLoader():
load_queue.put("EOF")
logger.debug("Load Images: Complete")
def _reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
""" Reload the images and pair to detected face
When the extraction pipeline is running in serial mode, images are reloaded from disk,
@ -652,7 +652,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
logger.debug("Initialized %s", self.__class__.__name__)
@property
def _save_interval(self) -> Optional[int]:
def _save_interval(self) -> int | None:
""" int: The number of frames to be processed between each saving of the alignments file if
it has been provided, otherwise ``None`` """
if hasattr(self._args, "save_interval"):
@ -718,7 +718,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
as_bytes=True)
for phase in range(self._extractor.passes):
is_final = self._extractor.final_pass
detected_faces: Dict[str, ExtractMedia] = {}
detected_faces: dict[str, ExtractMedia] = {}
self._extractor.launch()
self._loader.check_thread_error()
ph_desc = "Extraction" if self._extractor.passes == 1 else self._extractor.phase_text
@ -774,7 +774,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
if not self._verify_output and faces_count > 1:
self._verify_output = True
def _output_faces(self, saver: Optional[ImagesSaver], extract_media: ExtractMedia) -> None:
def _output_faces(self, saver: ImagesSaver | None, extract_media: ExtractMedia) -> None:
""" Output faces to save thread
Set the face filename based on the frame name and put the face to the
@ -798,14 +798,14 @@ class _Extract(): # pylint:disable=too-few-public-methods
output_filename = f"{filename}_{real_face_id}.png"
aligned = face.aligned.face
assert aligned is not None
meta: PNGHeaderDict = dict(
alignments=face.to_png_meta(),
source=dict(alignments_version=self._alignments.version,
original_filename=output_filename,
face_index=real_face_id,
source_filename=os.path.basename(extract_media.filename),
source_is_video=self._loader.is_video,
source_frame_dims=extract_media.image_size))
meta: PNGHeaderDict = {
"alignments": face.to_png_meta(),
"source": {"alignments_version": self._alignments.version,
"original_filename": output_filename,
"face_index": real_face_id,
"source_filename": os.path.basename(extract_media.filename),
"source_is_video": self._loader.is_video,
"source_frame_dims": extract_media.image_size}}
image = encode_image(aligned, ".png", metadata=meta)
sub_folder = extract_media.sub_folders[face_id]
@ -820,6 +820,6 @@ class _Extract(): # pylint:disable=too-few-public-methods
continue
final_faces.append(face.to_alignment())
self._alignments.data[os.path.basename(extract_media.filename)] = dict(faces=final_faces,
video_meta={})
self._alignments.data[os.path.basename(extract_media.filename)] = {"faces": final_faces,
"video_meta": {}}
del extract_media

Some files were not shown because too many files have changed in this diff Show more