
Rebase code (#1326)

* Remove tensorflow_probability requirement

* setup.py - fix progress bars

* requirements.txt: Remove pre python 3.9 packages

* update apple requirements.txt

* update INSTALL.md

* Remove python<3.9 code

* setup.py - fix Windows Installer

* typing: python3.9 compliant

* Update pytest and readthedocs python versions

* typing fixes

* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests

* Update dependencies

* Downgrade imageio dep for Windows

* typing: merge optional unions and fixes

* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests

* train: re-enable optimizer saving

* Update dockerfiles

* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling

* bugfix: Patch logging to prevent Autograph errors

* Update dockerfiles

* setup.py - stdout to utf-8

* Add more OSes to GitHub Actions

* Suppress macOS end-to-end test
torzdf 2023-06-27 11:27:47 +01:00 committed by GitHub
parent e4ba12ad2a
commit 6a3b674bef
130 changed files with 3035 additions and 3028 deletions
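Much of that churn is the mechanical typing update named in the bullets above ("typing to python 3.10 spec"). A minimal sketch of the before/after pattern, with illustrative names rather than code from the diff: Python 3.10 allows builtin generics and `|` unions everywhere, so the piecemeal `from typing import ...` imports collapse into a single namespace import.

```python
import typing as T  # the commit replaces per-name typing imports with one namespace import

# Before (Python 3.7/3.8 compatible):
#   from typing import Dict, Optional, Union
#   def lookup(key: str, table: Dict[str, Union[int, float]]) -> Optional[Union[int, float]]: ...

# After (Python 3.10 spec, as applied throughout this commit):
def lookup(key: str, table: dict[str, int | float]) -> int | float | None:
    """Builtin generics (dict, list) and | unions replace typing.Dict/Optional/Union."""
    return table.get(key)
```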


@@ -8,17 +8,85 @@ on:
       - "**/README.md"
 jobs:
-  build_linux:
+  build_conda:
+    name: conda (${{ matrix.os }}, ${{ matrix.backend }})
+    runs-on: ${{ matrix.os }}
+    defaults:
+      run:
+        shell: bash -el {0}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: ["ubuntu-latest", "macos-latest", "windows-latest"]
+        backend: ["nvidia", "cpu"]
+        include:
+          - os: "ubuntu-latest"
+            backend: "rocm"
+          - os: "windows-latest"
+            backend: "directml"
+        exclude:
+          # pynvx does not currently build for Python3.10 and without CUDA it may not build at all
+          - os: "macos-latest"
+            backend: "nvidia"
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set cache date
+        run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV
+      - name: Cache conda
+        uses: actions/cache@v3
+        env:
+          # Increase this value to manually reset cache
+          CACHE_NUMBER: 1
+          REQ_FILE: ./requirements/requirements_${{ matrix.backend }}.txt
+        with:
+          path: ~/conda_pkgs_dir
+          key: ${{ runner.os }}-${{ matrix.backend }}-conda-${{ env.CACHE_NUMBER }}-${{ env.DATE }}-${{ hashFiles('./requirements/requirements.txt', env.REQ_FILE) }}
+      - name: Set up Conda
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          python-version: "3.10"
+          auto-update-conda: true
+          activate-environment: faceswap
+      - name: Conda info
+        run: conda info && conda list
+      - name: Install
+        run: |
+          python setup.py --installer --${{ matrix.backend }}
+          pip install flake8 pylint mypy pytest pytest-mock wheel pytest-xvfb
+          pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
+      - name: Lint with flake8
+        run: |
+          # stop the build if there are Python syntax errors or undefined names
+          flake8 . --select=E9,F63,F7,F82 --show-source
+          flake8 . --exit-zero
+      - name: MyPy Typing
+        continue-on-error: true
+        run: |
+          mypy .
+      - name: SysInfo
+        run: python -c "from lib.sysinfo import sysinfo ; print(sysinfo)"
+      - name: Simple Tests
+        # These backends will fail as GPU drivers not available
+        if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml'
+        run: |
+          FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
+      - name: End to End Tests
+        # These backends will fail as GPU drivers not available
+        # macOS fails on first extract test with 'died with <Signals.SIGSEGV: 11>'
+        if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml' && matrix.os != 'macos-latest'
+        run: |
+          FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;
+  build_linux:
+    name: "pip (ubuntu-latest, ${{ matrix.backend }})"
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.7", "3.8", "3.9"]
+        python-version: ["3.10"]
         backend: ["cpu"]
         include:
-          - kbackend: "tensorflow"
-            backend: "cpu"
+          - backend: "cpu"
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
@@ -33,6 +101,8 @@ jobs:
           pip install flake8 pylint mypy pytest pytest-mock pytest-xvfb wheel
           pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
           pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
+      - name: List installed packages
+        run: pip freeze
       - name: Lint with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
@@ -45,17 +115,18 @@ jobs:
           mypy .
       - name: Simple Tests
         run: |
-          FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" py.test -v tests/;
+          FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
       - name: End to End Tests
         run: |
-          FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" python tests/simple_tests.py;
+          FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;
   build_windows:
+    name: "pip (windows-latest, ${{ matrix.backend }})"
     runs-on: windows-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9"]
+        python-version: ["3.10"]
         backend: ["cpu", "directml"]
         include:
           - backend: "cpu"
@@ -74,6 +145,8 @@ jobs:
          pip install flake8 pylint mypy pytest pytest-mock wheel
          pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
          pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
+      - name: List installed packages
+        run: pip freeze
       - name: Set Backend EnvVar
         run: echo "FACESWAP_BACKEND=${{ matrix.backend }}" | Out-File -FilePath $env:GITHUB_ENV -Append
       - name: Lint with flake8
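Every test step in this workflow selects a backend solely through the `FACESWAP_BACKEND` environment variable. A hypothetical sketch of how such a switch can be consumed on the Python side (the repository's actual backend resolution lives in its own module and may differ):

```python
import os

# Backend values used in the CI matrix above
_VALID_BACKENDS = ("cpu", "directml", "nvidia", "rocm")

def get_backend() -> str:
    """Read the backend requested by the environment, defaulting to cpu."""
    backend = os.environ.get("FACESWAP_BACKEND", "cpu").lower()
    if backend not in _VALID_BACKENDS:
        raise ValueError(f"Unknown FACESWAP_BACKEND: {backend!r}")
    return backend
```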


@@ -12,7 +12,7 @@ DIR_CONDA="$HOME/miniconda3"
 CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda"
 CONDA_TO_PATH=false
 ENV_NAME="faceswap"
-PYENV_VERSION="3.9"
+PYENV_VERSION="3.10"
 DIR_FACESWAP="$HOME/faceswap"
 VERSION="nvidia"
@@ -363,7 +363,7 @@ delete_env() {
 }

 create_env() {
-    # Create Python 3.8 env for faceswap
+    # Create Python 3.10 env for faceswap
     delete_env
     info "Creating Conda Virtual Environment..."
     yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -q python="$PYENV_VERSION" -y


@@ -22,7 +22,7 @@ InstallDir $PROFILE\faceswap
 # Install cli flags
 !define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$PROFILE\MiniConda3"
 !define flagsRepo "--depth 1 --no-single-branch ${wwwRepo}"
-!define flagsEnv "-y python=3.9"
+!define flagsEnv "-y python=3.10"
 # Folders
 Var ProgramData


@@ -7,9 +7,9 @@ version: 2
 # Set the version of Python and other tools you might need
 build:
-  os: ubuntu-20.04
+  os: ubuntu-22.04
   tools:
-    python: "3.8"
+    python: "3.10"
 # Build documentation in the docs/ directory with Sphinx
 sphinx:


@@ -1,19 +1,19 @@
-FROM tensorflow/tensorflow:2.8.2
+FROM ubuntu:22.04

 # To disable tzdata and others from asking for input
 ENV DEBIAN_FRONTEND noninteractive
+ENV FACESWAP_BACKEND cpu

-RUN apt-get update -qq -y \
-    && apt-get install -y software-properties-common \
-    && add-apt-repository -y ppa:jonathonf/ffmpeg-4 \
-    && apt-get update -qq -y \
-    && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk ffmpeg git \
-    && apt-get clean \
-    && rm -rf /var/lib/apt/lists/*
-COPY ./requirements/_requirements_base.txt /opt/
-RUN pip3 install --upgrade pip
-RUN pip3 --no-cache-dir install -r /opt/_requirements_base.txt && rm /opt/_requirements_base.txt
-WORKDIR "/srv"
+RUN apt-get update -qq -y
+RUN apt-get upgrade -y
+RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git
+
+RUN ln -s $(which python3) /usr/local/bin/python
+
+RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
+WORKDIR "/faceswap"
+
+RUN python -m pip install --upgrade pip
+RUN python -m pip --no-cache-dir install -r ./requirements/requirements_cpu.txt

 CMD ["/bin/bash"]


@@ -1,29 +1,19 @@
-FROM nvidia/cuda:11.7.0-runtime-ubuntu18.04
-ARG DEBIAN_FRONTEND=noninteractive
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV FACESWAP_BACKEND nvidia

-#install python3.8
-RUN apt-get update
-RUN apt-get install software-properties-common -y
-RUN add-apt-repository ppa:deadsnakes/ppa -y
-RUN apt-get update
-RUN apt-get install python3.8 -y
-RUN apt-get install python3.8-distutils -y
-RUN apt-get install python3.8-tk -y
-RUN apt-get install curl -y
-RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-RUN python3.8 get-pip.py
-RUN rm get-pip.py
-
-# install requirements
-RUN apt-get install ffmpeg git -y
-COPY ./requirements/_requirements_base.txt /opt/
-COPY ./requirements/requirements_nvidia.txt /opt/
-RUN python3.8 -m pip --no-cache-dir install -r /opt/requirements_nvidia.txt && rm /opt/_requirements_base.txt && rm /opt/requirements_nvidia.txt
-RUN python3.8 -m pip install jupyter matplotlib tqdm
-RUN python3.8 -m pip install jupyter_http_over_ws
-RUN jupyter serverextension enable --py jupyter_http_over_ws
-RUN alias python=python3.8
-RUN echo "alias python=python3.8" >> /root/.bashrc
-WORKDIR "/notebooks"
-CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"]
+RUN apt-get update -qq -y
+RUN apt-get upgrade -y
+RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git
+
+RUN ln -s $(which python3) /usr/local/bin/python
+
+RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
+WORKDIR "/faceswap"
+
+RUN python -m pip install --upgrade pip
+RUN python -m pip --no-cache-dir install -r ./requirements/requirements_nvidia.txt
+CMD ["/bin/bash"]


@@ -39,12 +39,9 @@
   - [Setup](#setup-2)
   - [About some of the options](#about-some-of-the-options)
 - [Docker Install Guide](#docker-install-guide)
-  - [Docker General](#docker-general)
-  - [CUDA with Docker in 20 minutes.](#cuda-with-docker-in-20-minutes)
-  - [CUDA with Docker on Arch Linux](#cuda-with-docker-on-arch-linux)
-    - [Install docker](#install-docker)
-  - [A successful setup log, without docker.](#a-successful-setup-log-without-docker)
-  - [Run the project](#run-the-project)
+  - [Docker CPU](#docker-cpu)
+  - [Docker Nvidia](#docker-nvidia)
+  - [Run the project](#run-the-project)
 - [Notes](#notes)

@@ -115,7 +112,7 @@ Reboot your PC, so that everything you have just installed gets registered.
   - Select "Create" at the bottom
   - In the pop up:
     - Give it the name: faceswap
-    - **IMPORTANT**: Select python version 3.8
+    - **IMPORTANT**: Select python version 3.10
   - Hit "Create" (NB: This may take a while as it will need to download Python)
 ![Anaconda virtual env setup](https://i.imgur.com/CLIDDfa.png)

@@ -195,7 +192,7 @@ $ source ~/miniforge3/bin/activate
 ## Setup
 ### Create and Activate the Environment
 ```sh
-$ conda create --name faceswap python=3.9
+$ conda create --name faceswap python=3.10
 $ conda activate faceswap
 ```

@@ -225,7 +222,7 @@ Obtain git for your distribution from the [git website](https://git-scm.com/downloads).
 The recommended install method is to use a Conda3 Environment as this will handle the installation of Nvidia's CUDA and cuDNN straight into your Conda Environment. This is by far the easiest and most reliable way to setup the project.
 - MiniConda3 is recommended: [MiniConda3](https://docs.conda.io/en/latest/miniconda.html)
-Alternatively you can install Python (>= 3.7-3.9 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn) for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
+Alternatively you can install Python (3.10 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn) for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
 - Python distributions:
   - apt/yum install python3 (Linux)
   - [Installer](https://www.python.org/downloads/release/python-368/) (Windows)

@@ -260,153 +257,83 @@ If setup fails for any reason you can still manually install the packages listed
 # Docker Install Guide
-## Docker General
-<details>
-  <summary>Click to expand!</summary>
-
-### CUDA with Docker in 20 minutes.
-1. Install Docker
-   https://www.docker.com/community-edition
-2. Install Nvidia-Docker & Restart Docker Service
-   https://github.com/NVIDIA/nvidia-docker
-3. Build Docker Image For faceswap
-```bash
-docker build -t deepfakes-gpu -f Dockerfile.gpu .
-```
-4. Mount faceswap volume and Run it
-   a). without `gui.tools.py` gui not working.
-```bash
-nvidia-docker run --rm -it -p 8888:8888 \
-    --hostname faceswap-gpu --name faceswap-gpu \
-    -v /opt/faceswap:/srv \
-    deepfakes-gpu
-```
-   b). with gui. tools.py gui working.
-   Enable local access to X11 server
-```bash
-xhost +local:
-```
-   Enable nvidia device if working under bumblebee
-```bash
-echo ON > /proc/acpi/bbswitch
-```
-   Create container
-```bash
-nvidia-docker run -p 8888:8888 \
-    --hostname faceswap-gpu --name faceswap-gpu \
-    -v /opt/faceswap:/srv \
-    -v /tmp/.X11-unix:/tmp/.X11-unix \
-    -e DISPLAY=unix$DISPLAY \
-    -e AUDIO_GID=`getent group audio | cut -d: -f3` \
-    -e VIDEO_GID=`getent group video | cut -d: -f3` \
-    -e GID=`id -g` \
-    -e UID=`id -u` \
-    deepfakes-gpu
-```
-Open a new terminal to interact with the project
-```bash
-docker exec -it deepfakes-gpu /bin/bash
-```
-Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)
-```bash
-python3.8 /srv/faceswap.py gui
-```
-</details>
-
-## CUDA with Docker on Arch Linux
-<details>
-  <summary>Click to expand!</summary>
-
-### Install docker
-```bash
-sudo pacman -S docker
-```
-The steps are same but Arch linux doesn't use nvidia-docker
-create container
-```bash
-docker run -p 8888:8888 --gpus all --privileged -v /dev:/dev \
-    --hostname faceswap-gpu --name faceswap-gpu \
-    -v /mnt/hdd2/faceswap:/srv \
-    -v /tmp/.X11-unix:/tmp/.X11-unix \
-    -e DISPLAY=unix$DISPLAY \
-    -e AUDIO_GID=`getent group audio | cut -d: -f3` \
-    -e VIDEO_GID=`getent group video | cut -d: -f3` \
-    -e GID=`id -g` \
-    -e UID=`id -u` \
-    deepfakes-gpu
-```
-Open a new terminal to interact with the project
-```bash
-docker exec -it deepfakes-gpu /bin/bash
-```
-Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)
-**With `gui.tools.py` gui working.**
-Enable local access to X11 server
-```bash
-xhost +local:
-```
-```bash
-python3.8 /srv/faceswap.py gui
-```
-</details>
-
----
-
-## A successful setup log, without docker.
-```
-INFO    The tool provides tips for installation
-        and installs required python packages
-INFO    Setup in Linux 4.14.39-1-MANJARO
-INFO    Installed Python: 3.7.5 64bit
-INFO    Installed PIP: 10.0.1
-Enable  Docker? [Y/n] n
-INFO    Docker Disabled
-Enable  CUDA? [Y/n]
-INFO    CUDA Enabled
-INFO    CUDA version: 9.1
-INFO    cuDNN version: 7
-WARNING Tensorflow has no official prebuild for CUDA 9.1 currently.
-        To continue, You have to build your own tensorflow-gpu.
-        Help: https://www.tensorflow.org/install/install_sources
-Are System Dependencies met? [y/N] y
-INFO    Installing Missing Python Packages...
-INFO    Installing tensorflow-gpu
-......
-INFO    Installing tqdm
-INFO    Installing matplotlib
-INFO    All python3 dependencies are met.
-You are good to go.
-```
-
-## Run the project
+This Faceswap repo contains Docker build scripts for CPU and Nvidia backends. The scripts will set up a Docker container for you and install the latest version of the Faceswap software.
+
+You must first ensure that Docker is installed and running on your system. Follow the guide for downloading and installing Docker from their website:
+
+- https://www.docker.com/get-started
+
+Once Docker is installed and running, follow the relevant steps for your chosen backend
+
+## Docker CPU
+To run the CPU version of Faceswap follow these steps:
+
+1. Build the Docker image for faceswap:
+    ```
+    docker build \
+        -t faceswap-cpu \
+        https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.cpu
+    ```
+2. Launch and enter the Faceswap container:
+    a. For the **headless/command line** version of Faceswap run:
+    ```
+    docker run --rm -it faceswap-cpu
+    ```
+    You can then execute faceswap the standard way:
+    ```
+    python faceswap.py --help
+    ```
+    b. For the **GUI** version of Faceswap run:
+    ```
+    xhost +local: && \
+    docker run --rm -it \
+        -v /tmp/.X11-unix:/tmp/.X11-unix \
+        -e DISPLAY=${DISPLAY} \
+        faceswap-cpu
+    ```
+    You can then launch the GUI with
+    ```
+    python faceswap.py gui
+    ```
+
+## Docker Nvidia
+To build the NVIDIA GPU version of Faceswap, follow these steps:
+
+1. Nvidia Docker builds need extra resources to provide the Docker container with access to your GPU.
+    a. Follow the instructions to install and apply the `Nvidia Container Toolkit` for your distribution from:
+    - https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+    b. If Docker is already running, restart it to pick up the changes made by the Nvidia Container Toolkit.
+2. Build the Docker image for faceswap:
+    ```
+    docker build \
+        -t faceswap-gpu \
+        https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.gpu
+    ```
+3. Launch and enter the Faceswap container:
+    a. For the **headless/command line** version of Faceswap run:
+    ```
+    docker run --runtime=nvidia --rm -it faceswap-gpu
+    ```
+    You can then execute faceswap the standard way:
+    ```
+    python faceswap.py --help
+    ```
+    b. For the **GUI** version of Faceswap run:
+    ```
+    xhost +local: && \
+    docker run --runtime=nvidia --rm -it \
+        -v /tmp/.X11-unix:/tmp/.X11-unix \
+        -e DISPLAY=${DISPLAY} \
+        faceswap-gpu
+    ```
+    You can then launch the GUI with
+    ```
+    python faceswap.py gui
+    ```
+
+# Run the project
 Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options.
 ```bash


@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath('../'))
 sys.setrecursionlimit(1500)

-MOCK_MODULES = ["plaidml", "pynvx", "ctypes.windll", "comtypes"]
+MOCK_MODULES = ["pynvx", "ctypes.windll", "comtypes"]
 for mod_name in MOCK_MODULES:
     sys.modules[mod_name] = mock.Mock()
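Dropping "plaidml" from the mock list tracks the removal of the PlaidML backend elsewhere in this commit. The mocking trick itself: any attribute access on a `Mock` succeeds, which is enough for Sphinx autodoc to import modules that depend on packages absent from the docs build. A minimal sketch of the technique:

```python
import sys
from unittest import mock

sys.modules["pynvx"] = mock.Mock()  # stand-in for a GPU-only dependency

import pynvx  # import now succeeds even though the real package is absent
print(pynvx.anything)  # attribute access simply returns another Mock
```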


@@ -1,25 +1,21 @@
 # NB Do not install from this requirements file
 # It is for documentation purposes only
-sphinx==5.0.2
-sphinx_rtd_theme==1.0.0
-tqdm==4.64
-psutil==5.8.0
-numexpr>=2.8.3
-numpy>=1.18.0
-opencv-python>=4.5.5.0
-pillow==8.3.1
-scikit-learn>=1.0.2
-fastcluster>=1.2.4
-matplotlib==3.5.1
-numexpr
-imageio==2.9.0
-imageio-ffmpeg==0.4.7
-ffmpy==0.2.3
-nvidia-ml-py<11.515
-plaidml==0.7.0
+sphinx==7.0.1
+sphinx_rtd_theme==1.2.2
+tqdm==4.65
+psutil==5.9.0
+numexpr>=2.8.4
+numpy>=1.25.0
+opencv-python>=4.7.0.0
+pillow==9.4.0
+scikit-learn>=1.2.2
+fastcluster>=1.2.6
+matplotlib==3.7.1
+imageio==2.31.1
+imageio-ffmpeg==0.4.8
+ffmpy==0.3.0
+nvidia-ml-py==11.525
 pytest==7.2.0
 pytest-mock==3.10.0
-tensorflow>=2.8.0,<2.9.0
-tensorflow_probability<0.17
-typing-extensions>=4.0.0
+tensorflow>=2.10.0,<2.11.0


@@ -16,9 +16,8 @@ from lib.config import generate_configs  # pylint:disable=wrong-import-position
 _LANG = gettext.translation("faceswap", localedir="locales", fallback=True)
 _ = _LANG.gettext

-if sys.version_info < (3, 7):
-    raise ValueError("This program requires at least python3.7")
+if sys.version_info < (3, 10):
+    raise ValueError("This program requires at least python 3.10")

 _PARSER = cli_args.FullHelpArgumentParser()
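The gate works because `sys.version_info` compares element-wise as a tuple. An illustrative check:

```python
import sys

print(sys.version_info[:2])   # e.g. (3, 10)
print((3, 9, 13) < (3, 10))   # True  -> a 3.9 interpreter trips the ValueError
print((3, 10, 0) < (3, 10))   # False -> 3.10.0 passes the check
```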

View file

@@ -3,22 +3,14 @@
 from dataclasses import dataclass, field
 import logging
-import sys
+import typing as T
 from threading import Lock
-from typing import cast, Dict, Optional, Tuple

 import cv2
 import numpy as np

-if sys.version_info < (3, 8):
-    from typing_extensions import get_args, Literal
-else:
-    from typing import get_args, Literal

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

-CenteringType = Literal["face", "head", "legacy"]
+CenteringType = T.Literal["face", "head", "legacy"]

 _MEAN_FACE = np.array([[0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748],
                        [0.300643, 0.034489], [0.403270, 0.077391], [0.596729, 0.077391],
@@ -65,10 +57,10 @@ _MEAN_FACE_3D = np.array([[4.056931, -11.432347, 1.636229],  # 8 chin LL
                           [0.0, -8.601736, 6.097667],         # 45 mouth bottom C
                           [0.589441, -8.443925, 6.109526]])   # 44 mouth bottom L

-_EXTRACT_RATIOS = dict(legacy=0.375, face=0.5, head=0.625)
+_EXTRACT_RATIOS = {"legacy": 0.375, "face": 0.5, "head": 0.625}

-def get_matrix_scaling(matrix: np.ndarray) -> Tuple[int, int]:
+def get_matrix_scaling(matrix: np.ndarray) -> tuple[int, int]:
     """ Given a matrix, return the cv2 Interpolation method and inverse interpolation method for
     applying the matrix on an image.
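The `_EXTRACT_RATIOS` mapping above, together with the `CenteringType` literal, drives the padding arithmetic used later in `_padding_from_coverage`. A minimal standalone sketch (the wrapper function is an illustrative condensation, not the class method itself):

```python
import typing as T

CenteringType = T.Literal["face", "head", "legacy"]
_EXTRACT_RATIOS = {"legacy": 0.375, "face": 0.5, "head": 0.625}

def padding_from_coverage(size: int, coverage_ratio: float) -> dict[str, int]:
    # T.get_args iterates the literal's members, so adding a centering type
    # to CenteringType automatically includes it here
    return {ctype: round((size * (coverage_ratio - (1 - _EXTRACT_RATIOS[ctype]))) / 2)
            for ctype in T.get_args(CenteringType)}

print(padding_from_coverage(512, 1.0))  # {'face': 128, 'head': 160, 'legacy': 96}
```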
@@ -213,6 +205,149 @@ def get_centered_size(source_centering: CenteringType,
     return retval

+class PoseEstimate():
+    """ Estimates pose from a generic 3D head model for the given 2D face landmarks.
+
+    Parameters
+    ----------
+    landmarks: :class:`numpy.ndarry`
+        The original 68 point landmarks aligned to 0.0 - 1.0 range
+
+    References
+    ----------
+    Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
+    3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
+    """
+    def __init__(self, landmarks: np.ndarray) -> None:
+        self._distortion_coefficients = np.zeros((4, 1))  # Assuming no lens distortion
+        self._xyz_2d: np.ndarray | None = None
+
+        self._camera_matrix = self._get_camera_matrix()
+        self._rotation, self._translation = self._solve_pnp(landmarks)
+        self._offset = self._get_offset()
+        self._pitch_yaw_roll: tuple[float, float, float] = (0, 0, 0)
+
+    @property
+    def xyz_2d(self) -> np.ndarray:
+        """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
+        constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
+        if self._xyz_2d is None:
+            xyz = cv2.projectPoints(np.array([[6., 0., -2.3],
+                                              [0., 6., -2.3],
+                                              [0., 0., 3.7]]).astype("float32"),
+                                    self._rotation,
+                                    self._translation,
+                                    self._camera_matrix,
+                                    self._distortion_coefficients)[0].squeeze()
+            self._xyz_2d = xyz - self._offset["head"]
+        return self._xyz_2d
+
+    @property
+    def offset(self) -> dict[CenteringType, np.ndarray]:
+        """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a
+        from the center of the face (between the eyes) or center of the head (middle of skull)
+        rather than the nose area. """
+        return self._offset
+
+    @property
+    def pitch(self) -> float:
+        """ float: The pitch of the aligned face in eular angles """
+        if not any(self._pitch_yaw_roll):
+            self._get_pitch_yaw_roll()
+        return self._pitch_yaw_roll[0]
+
+    @property
+    def yaw(self) -> float:
+        """ float: The yaw of the aligned face in eular angles """
+        if not any(self._pitch_yaw_roll):
+            self._get_pitch_yaw_roll()
+        return self._pitch_yaw_roll[1]
+
+    @property
+    def roll(self) -> float:
+        """ float: The roll of the aligned face in eular angles """
+        if not any(self._pitch_yaw_roll):
+            self._get_pitch_yaw_roll()
+        return self._pitch_yaw_roll[2]
+
+    def _get_pitch_yaw_roll(self) -> None:
+        """ Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """
+        proj_matrix = np.zeros((3, 4), dtype="float32")
+        proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
+        euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
+        self._pitch_yaw_roll = T.cast(tuple[float, float, float], tuple(euler.squeeze()))
+        logger.trace("yaw_pitch: %s", self._pitch_yaw_roll)  # type: ignore
+
+    @classmethod
+    def _get_camera_matrix(cls) -> np.ndarray:
+        """ Obtain an estimate of the camera matrix based off the original frame dimensions.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            An estimated camera matrix
+        """
+        focal_length = 4
+        camera_matrix = np.array([[focal_length, 0, 0.5],
+                                  [0, focal_length, 0.5],
+                                  [0, 0, 1]], dtype="double")
+        logger.trace("camera_matrix: %s", camera_matrix)  # type: ignore
+        return camera_matrix
+
+    def _solve_pnp(self, landmarks: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """ Solve the Perspective-n-Point for the given landmarks.
+
+        Takes 2D landmarks in world space and estimates the rotation and translation vectors
+        in 3D space.
+
+        Parameters
+        ----------
+        landmarks: :class:`numpy.ndarry`
+            The original 68 point landmark co-ordinates relating to the original frame
+
+        Returns
+        -------
+        rotation: :class:`numpy.ndarray`
+            The solved rotation vector
+        translation: :class:`numpy.ndarray`
+            The solved translation vector
+        """
+        points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
+                            35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
+        _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
+                                                points,
+                                                self._camera_matrix,
+                                                self._distortion_coefficients,
+                                                flags=cv2.SOLVEPNP_ITERATIVE)
+        logger.trace("points: %s, rotation: %s, translation: %s",  # type: ignore
+                     points, rotation, translation)
+        return rotation, translation
+
+    def _get_offset(self) -> dict[CenteringType, np.ndarray]:
+        """ Obtain the offset between the original center of the extracted face to the new center
+        of the head in 2D space.
+
+        Returns
+        -------
+        :class:`numpy.ndarray`
+            The x, y offset of the new center from the old center.
+        """
+        offset: dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])}
+        points: dict[T.Literal["face", "head"], tuple[float, ...]] = {"head": (0.0, 0.0, -2.3),
+                                                                      "face": (0.0, -1.5, 4.2)}
+        for key, pnts in points.items():
+            center = cv2.projectPoints(np.array([pnts]).astype("float32"),
+                                       self._rotation,
+                                       self._translation,
+                                       self._camera_matrix,
+                                       self._distortion_coefficients)[0].squeeze()
+            logger.trace("center %s: %s", key, center)  # type: ignore
+            offset[key] = center - (0.5, 0.5)
+        logger.trace("offset: %s", offset)  # type: ignore
+        return offset
+
 @dataclass
 class _FaceCache:  # pylint:disable=too-many-instance-attributes
     """ Cache for storing items related to a single aligned face.
@@ -251,19 +386,19 @@ class _FaceCache:  # pylint:disable=too-many-instance-attributes
     cropped_slices: dict, optional
         The slices for an input full head image and output cropped image. Default: `{}`
     """
-    pose: Optional["PoseEstimate"] = None
-    original_roi: Optional[np.ndarray] = None
-    landmarks: Optional[np.ndarray] = None
-    landmarks_normalized: Optional[np.ndarray] = None
+    pose: PoseEstimate | None = None
+    original_roi: np.ndarray | None = None
+    landmarks: np.ndarray | None = None
+    landmarks_normalized: np.ndarray | None = None
     average_distance: float = 0.0
     relative_eye_mouth_position: float = 0.0
-    adjusted_matrix: Optional[np.ndarray] = None
-    interpolators: Tuple[int, int] = (0, 0)
-    cropped_roi: Dict[CenteringType, np.ndarray] = field(default_factory=dict)
-    cropped_slices: Dict[CenteringType, Dict[Literal["in", "out"],
-                                             Tuple[slice, slice]]] = field(default_factory=dict)
-    _locks: Dict[str, Lock] = field(default_factory=dict)
+    adjusted_matrix: np.ndarray | None = None
+    interpolators: tuple[int, int] = (0, 0)
+    cropped_roi: dict[CenteringType, np.ndarray] = field(default_factory=dict)
+    cropped_slices: dict[CenteringType, dict[T.Literal["in", "out"],
+                                             tuple[slice, slice]]] = field(default_factory=dict)
+    _locks: dict[str, Lock] = field(default_factory=dict)

     def __post_init__(self):
         """ Initialize the locks for the class parameters """
@@ -322,11 +457,11 @@ class AlignedFace():
     """
     def __init__(self,
                  landmarks: np.ndarray,
-                 image: Optional[np.ndarray] = None,
+                 image: np.ndarray | None = None,
                  centering: CenteringType = "face",
                  size: int = 64,
                  coverage_ratio: float = 1.0,
-                 dtype: Optional[str] = None,
+                 dtype: str | None = None,
                  is_aligned: bool = False,
                  is_legacy: bool = False) -> None:
         logger.trace("Initializing: %s (image shape: %s, centering: '%s', "  # type: ignore
@@ -340,9 +475,9 @@ class AlignedFace():
         self._dtype = dtype
         self._is_aligned = is_aligned
         self._source_centering: CenteringType = "legacy" if is_legacy and is_aligned else "head"
-        self._matrices = dict(legacy=_umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
-                              face=np.array([]),
-                              head=np.array([]))
+        self._matrices = {"legacy": _umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
+                          "face": np.array([]),
+                          "head": np.array([])}
         self._padding = self._padding_from_coverage(size, coverage_ratio)
         self._cache = _FaceCache()
@@ -353,7 +488,7 @@ class AlignedFace():
                      self._face if self._face is None else self._face.shape)

     @property
-    def centering(self) -> Literal["legacy", "head", "face"]:
+    def centering(self) -> T.Literal["legacy", "head", "face"]:
         """ str: The centering of the Aligned Face. One of `"legacy"`, `"head"`, `"face"`. """
         return self._centering
@@ -382,7 +517,7 @@ class AlignedFace():
         return self._matrices[self._centering]

     @property
-    def pose(self) -> "PoseEstimate":
+    def pose(self) -> PoseEstimate:
         """ :class:`lib.align.PoseEstimate`: The estimated pose in 3D space. """
         with self._cache.lock("pose"):
             if self._cache.pose is None:
@@ -405,7 +540,7 @@ class AlignedFace():
         return self._cache.adjusted_matrix

     @property
-    def face(self) -> Optional[np.ndarray]:
+    def face(self) -> np.ndarray | None:
         """ :class:`numpy.ndarray`: The aligned face at the given :attr:`size` at the specified
         :attr:`coverage` in the given :attr:`dtype`. If an :attr:`image` has not been provided
         then an the attribute will return ``None``. """
@@ -450,7 +585,7 @@ class AlignedFace():
         return self._cache.landmarks_normalized

     @property
-    def interpolators(self) -> Tuple[int, int]:
+    def interpolators(self) -> tuple[int, int]:
         """ tuple: (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. """
         with self._cache.lock("interpolators"):
             if not any(self._cache.interpolators):
@@ -487,7 +622,7 @@ class AlignedFace():
         return self._cache.relative_eye_mouth_position

     @classmethod
-    def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> Dict[CenteringType, int]:
+    def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> dict[CenteringType, int]:
         """ Return the image padding for a face from coverage_ratio set against a
         pre-padded training image.
@@ -504,7 +639,7 @@ class AlignedFace():
             The padding required, in pixels for 'head', 'face' and 'legacy' face types
         """
         retval = {_type: round((size * (coverage_ratio - (1 - _EXTRACT_RATIOS[_type]))) / 2)
-                  for _type in get_args(Literal["legacy", "face", "head"])}
+                  for _type in T.get_args(T.Literal["legacy", "face", "head"])}
         logger.trace(retval)  # type: ignore
         return retval
@@ -532,7 +667,7 @@ class AlignedFace():
                      invert, points, retval)
         return retval

-    def extract_face(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]:
+    def extract_face(self, image: np.ndarray | None) -> np.ndarray | None:
         """ Extract the face from a source image and populate :attr:`face`. If an image is not
         provided then ``None`` is returned.
@@ -605,7 +740,7 @@ class AlignedFace():
     def _get_cropped_slices(self,
                             image_size: int,
                             target_size: int,
-                            ) -> Dict[Literal["in", "out"], Tuple[slice, slice]]:
+                            ) -> dict[T.Literal["in", "out"], tuple[slice, slice]]:
         """ Obtain the slices to turn a full head extract into an alternatively centered extract.

         Parameters
@@ -676,149 +811,6 @@ class AlignedFace():
         return self._cache.cropped_roi[centering]

-class PoseEstimate():
-    """ Estimates pose from a generic 3D head model for the given 2D face landmarks.
-
-    Parameters
-    ----------
-    landmarks: :class:`numpy.ndarry`
-        The original 68 point landmarks aligned to 0.0 - 1.0 range
-
-    References
-    ----------
-    Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
-    3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
-    """
-    def __init__(self, landmarks: np.ndarray) -> None:
-        self._distortion_coefficients = np.zeros((4, 1))  # Assuming no lens distortion
-        self._xyz_2d: Optional[np.ndarray] = None
-
-        self._camera_matrix = self._get_camera_matrix()
-        self._rotation, self._translation = self._solve_pnp(landmarks)
-        self._offset = self._get_offset()
-        self._pitch_yaw_roll: Tuple[float, float, float] = (0, 0, 0)
-
-    @property
-    def xyz_2d(self) -> np.ndarray:
-        """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
-        constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
-        if self._xyz_2d is None:
-            xyz = cv2.projectPoints(np.array([[6., 0., -2.3],
-                                              [0., 6., -2.3],
-                                              [0., 0., 3.7]]).astype("float32"),
-                                    self._rotation,
-                                    self._translation,
-                                    self._camera_matrix,
-                                    self._distortion_coefficients)[0].squeeze()
-            self._xyz_2d = xyz - self._offset["head"]
-        return self._xyz_2d
-
-    @property
-    def offset(self) -> Dict[CenteringType, np.ndarray]:
-        """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a
-        from the center of the face (between the eyes) or center of the head (middle of skull)
-        rather than the nose area. """
-        return self._offset
-
-    @property
-    def pitch(self) -> float:
-        """ float: The pitch of the aligned face in eular angles """
-        if not any(self._pitch_yaw_roll):
-            self._get_pitch_yaw_roll()
-        return self._pitch_yaw_roll[0]
-
-    @property
-    def yaw(self) -> float:
-        """ float: The yaw of the aligned face in eular angles """
-        if not any(self._pitch_yaw_roll):
-            self._get_pitch_yaw_roll()
-        return self._pitch_yaw_roll[1]
-
-    @property
-    def roll(self) -> float:
-        """ float: The roll of the aligned face in eular angles """
-        if not any(self._pitch_yaw_roll):
-            self._get_pitch_yaw_roll()
-        return self._pitch_yaw_roll[2]
-
-    def _get_pitch_yaw_roll(self) -> None:
-        """ Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """
-        proj_matrix = np.zeros((3, 4), dtype="float32")
-        proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
-        euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
-        self._pitch_yaw_roll = cast(Tuple[float, float, float], tuple(euler.squeeze()))
-        logger.trace("yaw_pitch: %s", self._pitch_yaw_roll)  # type: ignore
-
-    @classmethod
-    def _get_camera_matrix(cls) -> np.ndarray:
-        """ Obtain an estimate of the camera matrix based off the original frame dimensions.
-
-        Returns
-        -------
-        :class:`numpy.ndarray`
-            An estimated camera matrix
-        """
-        focal_length = 4
-        camera_matrix = np.array([[focal_length, 0, 0.5],
-                                  [0, focal_length, 0.5],
-                                  [0, 0, 1]], dtype="double")
-        logger.trace("camera_matrix: %s", camera_matrix)  # type: ignore
-        return camera_matrix
-
-    def _solve_pnp(self, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-        """ Solve the Perspective-n-Point for the given landmarks.
-
-        Takes 2D landmarks in world space and estimates the rotation and translation vectors
-        in 3D space.
-
-        Parameters
-        ----------
-        landmarks: :class:`numpy.ndarry`
-            The original 68 point landmark co-ordinates relating to the original frame
-
-        Returns
-        -------
-        rotation: :class:`numpy.ndarray`
-            The solved rotation vector
-        translation: :class:`numpy.ndarray`
-            The solved translation vector
-        """
-        points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
-                            35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
-        _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
-                                                points,
-                                                self._camera_matrix,
-                                                self._distortion_coefficients,
-                                                flags=cv2.SOLVEPNP_ITERATIVE)
-        logger.trace("points: %s, rotation: %s, translation: %s",  # type: ignore
-                     points, rotation, translation)
-        return rotation, translation
-
-    def _get_offset(self) -> Dict[CenteringType, np.ndarray]:
-        """ Obtain the offset between the original center of the extracted face to the new center
-        of the head in 2D space.
-
-        Returns
-        -------
-        :class:`numpy.ndarray`
-            The x, y offset of the new center from the old center.
-        """
-        offset: Dict[CenteringType, np.ndarray] = dict(legacy=np.array([0.0, 0.0]))
-        points: Dict[Literal["face", "head"], Tuple[float, ...]] = dict(head=(0.0, 0.0, -2.3),
-                                                                        face=(0.0, -1.5, 4.2))
-        for key, pnts in points.items():
-            center = cv2.projectPoints(np.array([pnts]).astype("float32"),
-                                       self._rotation,
-                                       self._translation,
-                                       self._camera_matrix,
-                                       self._distortion_coefficients)[0].squeeze()
-            logger.trace("center %s: %s", key, center)  # type: ignore
-            offset[key] = center - (0.5, 0.5)
-        logger.trace("offset: %s", offset)  # type: ignore
-        return offset
-
 def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray:
     """Estimate N-D similarity transformation with or without scaling.
@@ -866,24 +858,24 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
     if np.linalg.det(A) < 0:
         d[dim - 1] = -1

-    T = np.eye(dim + 1, dtype=np.double)
+    retval = np.eye(dim + 1, dtype=np.double)

     U, S, V = np.linalg.svd(A)

     # Eq. (40) and (43).
     rank = np.linalg.matrix_rank(A)
     if rank == 0:
-        return np.nan * T
+        return np.nan * retval
     if rank == dim - 1:
         if np.linalg.det(U) * np.linalg.det(V) > 0:
-            T[:dim, :dim] = U @ V
+            retval[:dim, :dim] = U @ V
         else:
             s = d[dim - 1]
             d[dim - 1] = -1
-            T[:dim, :dim] = U @ np.diag(d) @ V
+            retval[:dim, :dim] = U @ np.diag(d) @ V
             d[dim - 1] = s
     else:
-        T[:dim, :dim] = U @ np.diag(d) @ V
+        retval[:dim, :dim] = U @ np.diag(d) @ V

     if estimate_scale:
         # Eq. (41) and (42).
@@ -891,7 +883,7 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
     else:
         scale = 1.0

-    T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
-    T[:dim, :dim] *= scale
+    retval[:dim, dim] = dst_mean - scale * (retval[:dim, :dim] @ src_mean.T)
+    retval[:dim, :dim] *= scale

-    return T
+    return retval
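The rename here is purely mechanical: the homogeneous transform matrix was called `T`, which now clashes with the `typing as T` alias used elsewhere in the module. A quick usage sketch of the behaviourally unchanged Umeyama estimator, with made-up points:

```python
import numpy as np

src = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
dst = src * 2.0 + 1.0  # exact similarity: uniform scale 2, translation (1, 1)

matrix = _umeyama(src, dst, estimate_scale=True)  # 3x3 homogeneous transform

# Apply the transform in homogeneous coordinates and confirm it maps src onto dst
ones = np.ones((len(src), 1))
mapped = (matrix @ np.hstack([src, ones]).T).T[:, :2]
assert np.allclose(mapped, dst)
```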


@@ -1,24 +1,19 @@
 #!/usr/bin/env python3
 """ Alignments file functions for reading, writing and manipulating the data stored in a
 serialized alignments file. """
+from __future__ import annotations
 import logging
 import os
-import sys
+import typing as T
 from datetime import datetime
-from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union

 import numpy as np

 from lib.serializer import get_serializer, get_serializer_from_filename
 from lib.utils import FaceswapError

-if sys.version_info < (3, 8):
-    from typing_extensions import TypedDict
-else:
-    from typing import TypedDict
-
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
     from .aligned_face import CenteringType

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -35,49 +30,49 @@ _VERSION = 2.3

 # TODO Convert these to Dataclasses
-class MaskAlignmentsFileDict(TypedDict):
+class MaskAlignmentsFileDict(T.TypedDict):
     """ Typed Dictionary for storing Masks. """
     mask: bytes
-    affine_matrix: Union[List[float], np.ndarray]
+    affine_matrix: list[float] | np.ndarray
     interpolator: int
     stored_size: int
-    stored_centering: "CenteringType"
+    stored_centering: CenteringType

-class PNGHeaderAlignmentsDict(TypedDict):
+class PNGHeaderAlignmentsDict(T.TypedDict):
     """ Base Dictionary for storing a single faces' Alignment Information in Alignments files and
     PNG Headers. """
     x: int
     y: int
     w: int
     h: int
-    landmarks_xy: Union[List[float], np.ndarray]
-    mask: Dict[str, MaskAlignmentsFileDict]
-    identity: Dict[str, List[float]]
+    landmarks_xy: list[float] | np.ndarray
+    mask: dict[str, MaskAlignmentsFileDict]
+    identity: dict[str, list[float]]

 class AlignmentFileDict(PNGHeaderAlignmentsDict):
     """ Typed Dictionary for storing a single faces' Alignment Information in alignments files. """
-    thumb: Optional[np.ndarray]
+    thumb: np.ndarray | None

-class PNGHeaderSourceDict(TypedDict):
+class PNGHeaderSourceDict(T.TypedDict):
     """ Dictionary for storing additional meta information in PNG headers """
     alignments_version: float
     original_filename: str
     face_index: int
     source_filename: str
     source_is_video: bool
-    source_frame_dims: Optional[Tuple[int, int]]
+    source_frame_dims: tuple[int, int] | None

-class AlignmentDict(TypedDict):
+class AlignmentDict(T.TypedDict):
     """ Dictionary for holding all of the alignment information within a single alignment file """
-    faces: List[AlignmentFileDict]
-    video_meta: Dict[str, Union[float, int]]
+    faces: list[AlignmentFileDict]
+    video_meta: dict[str, float | int]

-class PNGHeaderDict(TypedDict):
+class PNGHeaderDict(T.TypedDict):
     """ Dictionary for storing all alignment and meta information in PNG Headers """
     alignments: PNGHeaderAlignmentsDict
     source: PNGHeaderSourceDict
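With `from __future__ import annotations` in place, the `|` unions inside these TypedDicts are legal annotation syntax regardless of how the interpreter evaluates them; at runtime the structures remain plain dicts that only the type checker constrains. A minimal sketch with invented sample values:

```python
from __future__ import annotations
import typing as T

class SourceDict(T.TypedDict):  # simplified stand-in for PNGHeaderSourceDict
    original_filename: str
    face_index: int
    source_frame_dims: tuple[int, int] | None

entry: SourceDict = {"original_filename": "face_000001_0.png",  # sample values only
                     "face_index": 0,
                     "source_frame_dims": (1920, 1080)}
print(type(entry))  # <class 'dict'> - TypedDict adds no runtime wrapper
```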
@@ -135,7 +130,7 @@ class Alignments():
         return self._io.file

     @property
-    def data(self) -> Dict[str, AlignmentDict]:
+    def data(self) -> dict[str, AlignmentDict]:
         """ dict: The loaded alignments :attr:`file` in dictionary form. """
         return self._data
@@ -146,7 +141,7 @@ class Alignments():
         return self._io.have_alignments_file

     @property
-    def hashes_to_frame(self) -> Dict[str, Dict[str, int]]:
+    def hashes_to_frame(self) -> dict[str, dict[str, int]]:
         """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
         that the hash corresponds to.
@@ -158,7 +153,7 @@ class Alignments():
         return self._legacy.hashes_to_frame

     @property
-    def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]:
+    def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
         """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
         corresponds to. The structure of the dictionary is:
@@ -170,10 +165,10 @@ class Alignments():
         return self._legacy.hashes_to_alignment

     @property
-    def mask_summary(self) -> Dict[str, int]:
+    def mask_summary(self) -> dict[str, int]:
         """ dict: The mask type names stored in the alignments :attr:`data` as key with the number
         of faces which possess the mask type as value. """
-        masks: Dict[str, int] = {}
+        masks: dict[str, int] = {}
         for val in self._data.values():
             for face in val["faces"]:
                 if face.get("mask", None) is None:
@@ -183,21 +178,20 @@ class Alignments():
         return masks

     @property
-    def video_meta_data(self) -> Dict[str, Optional[Union[List[int], List[float]]]]:
+    def video_meta_data(self) -> dict[str, list[int] | list[float] | None]:
         """ dict: The frame meta data stored in the alignments file. If data does not exist in the
         alignments file then ``None`` is returned for each Key """
-        retval: Dict[str, Optional[Union[List[int],
-                                         List[float]]]] = dict(pts_time=None, keyframes=None)
-        pts_time: List[float] = []
-        keyframes: List[int] = []
+        retval: dict[str, list[int] | list[float] | None] = {"pts_time": None, "keyframes": None}
+        pts_time: list[float] = []
+        keyframes: list[int] = []
         for idx, key in enumerate(sorted(self.data)):
             if not self.data[key].get("video_meta", {}):
                 return retval
             meta = self.data[key]["video_meta"]
-            pts_time.append(cast(float, meta["pts_time"]))
+            pts_time.append(T.cast(float, meta["pts_time"]))
             if meta["keyframe"]:
                 keyframes.append(idx)
-        retval = dict(pts_time=pts_time, keyframes=keyframes)
+        retval = {"pts_time": pts_time, "keyframes": keyframes}
         return retval

     @property
@@ -211,7 +205,7 @@ class Alignments():
         """ float: The alignments file version number. """
         return self._io.version

-    def _load(self) -> Dict[str, AlignmentDict]:
+    def _load(self) -> dict[str, AlignmentDict]:
         """ Load the alignments data from the serialized alignments :attr:`file`.

         Populates :attr:`_version` with the alignment file's loaded version as well as returning
@@ -238,7 +232,7 @@ class Alignments():
         """
         return self._io.backup()

-    def save_video_meta_data(self, pts_time: List[float], keyframes: List[int]) -> None:
+    def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None:
         """ Save video meta data to the alignments file.

         If the alignments file does not have an entry for every frame (e.g. if Extract Every N
@@ -262,10 +256,10 @@ class Alignments():
         logger.info("Saving video meta information to Alignments file")

         for idx, pts in enumerate(pts_time):
-            meta: Dict[str, Union[float, int]] = dict(pts_time=pts, keyframe=idx in keyframes)
+            meta: dict[str, float | int] = {"pts_time": pts, "keyframe": idx in keyframes}
             key = f"{basename}_{idx + 1:06d}.png"
             if key not in self.data:
-                self.data[key] = dict(video_meta=meta, faces=[])
+                self.data[key] = {"video_meta": meta, "faces": []}
             else:
                 self.data[key]["video_meta"] = meta
@@ -285,8 +279,8 @@ class Alignments():
         self._io.save()

     @classmethod
-    def _pad_leading_frames(cls, pts_time: List[float], keyframes: List[int]) -> Tuple[List[float],
-                                                                                       List[int]]:
+    def _pad_leading_frames(cls, pts_time: list[float], keyframes: list[int]) -> tuple[list[float],
+                                                                                       list[int]]:
         """ Calculate the number of frames to pad the video by when the first frame is not
         a key frame.
@@ -310,7 +304,7 @@ class Alignments():
         """
         start_pts = pts_time[0]
         logger.debug("Video not cut on keyframe. Start pts: %s", start_pts)
-        gaps: List[float] = []
+        gaps: list[float] = []
         prev_time = None
         for item in pts_time:
             if prev_time is not None:
@@ -360,7 +354,7 @@ class Alignments():
             ``True`` if the given frame_name exists within the alignments :attr:`data` and has at
             least 1 face associated with it, otherwise ``False``
         """
-        frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
+        frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
         retval = bool(frame_data.get("faces", []))
         logger.trace("'%s': %s", frame_name, retval)  # type:ignore
         return retval
@@ -384,7 +378,7 @@ class Alignments():
         if not frame_name:
             retval = False
         else:
-            frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
+            frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
             retval = bool(len(frame_data.get("faces", [])) > 1)
         logger.trace("'%s': %s", frame_name, retval)  # type:ignore
         return retval
@@ -414,7 +408,7 @@ class Alignments():
         return retval

     # << DATA >> #
-    def get_faces_in_frame(self, frame_name: str) -> List[AlignmentFileDict]:
+    def get_faces_in_frame(self, frame_name: str) -> list[AlignmentFileDict]:
         """ Obtain the faces from :attr:`data` associated with a given frame_name.

         Parameters
@@ -429,8 +423,8 @@ class Alignments():
             The list of face dictionaries that appear within the requested frame_name
         """
         logger.trace("Getting faces for frame_name: '%s'", frame_name)  # type:ignore
-        frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
-        return frame_data.get("faces", cast(List[AlignmentFileDict], []))
+        frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
+        return frame_data.get("faces", T.cast(list[AlignmentFileDict], []))

     def _count_faces_in_frame(self, frame_name: str) -> int:
         """ Return number of faces that appear within :attr:`data` for the given frame_name.
@@ -446,7 +440,7 @@ class Alignments():
         int
             The number of faces that appear in the given frame_name
         """
-        frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
+        frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
         retval = len(frame_data.get("faces", []))
         logger.trace(retval)  # type:ignore
         return retval
@@ -497,7 +491,7 @@ class Alignments():
         """
         logger.debug("Adding face to frame_name: '%s'", frame_name)
         if frame_name not in self._data:
-            self._data[frame_name] = dict(faces=[], video_meta={})
+            self._data[frame_name] = {"faces": [], "video_meta": {}}
         self._data[frame_name]["faces"].append(face)
         retval = self._count_faces_in_frame(frame_name) - 1
         logger.debug("Returning new face index: %s", retval)
@@ -520,7 +514,7 @@ class Alignments():
         logger.debug("Updating face %s for frame_name '%s'", face_index, frame_name)
         self._data[frame_name]["faces"][face_index] = face

-    def filter_faces(self, filter_dict: Dict[str, List[int]], filter_out: bool = False) -> None:
+    def filter_faces(self, filter_dict: dict[str, list[int]], filter_out: bool = False) -> None:
         """ Remove faces from :attr:`data` based on a given filter list.

         Parameters
@@ -549,7 +543,7 @@ class Alignments():
                 del frame_data["faces"][face_idx]

     # << GENERATORS >> #
-    def yield_faces(self) -> Generator[Tuple[str, List[AlignmentFileDict], int, str], None, None]:
+    def yield_faces(self) -> Generator[tuple[str, list[AlignmentFileDict], int, str], None, None]:
         """ Generator to obtain all faces with meta information from :attr:`data`. The results
         are yielded by frame.
logger.info("Updating alignments file to version %s", self._version) logger.info("Updating alignments file to version %s", self._version)
self.save() self.save()
def load(self) -> Dict[str, AlignmentDict]: def load(self) -> dict[str, AlignmentDict]:
""" Load the alignments data from the serialized alignments :attr:`file`. """ Load the alignments data from the serialized alignments :attr:`file`.
Populates :attr:`_version` with the alignment file's loaded version as well as returning Populates :attr:`_version` with the alignment file's loaded version as well as returning
@ -732,7 +726,7 @@ class _IO():
logger.info("Reading alignments from: '%s'", self._file) logger.info("Reading alignments from: '%s'", self._file)
data = self._serializer.load(self._file) data = self._serializer.load(self._file)
meta = data.get("__meta__", dict(version=1.0)) meta = data.get("__meta__", {"version": 1.0})
self._version = meta["version"] self._version = meta["version"]
data = data.get("__data__", data) data = data.get("__data__", data)
logger.debug("Loaded alignments") logger.debug("Loaded alignments")
@ -743,8 +737,8 @@ class _IO():
the location :attr:`file`. """ the location :attr:`file`. """
logger.debug("Saving alignments") logger.debug("Saving alignments")
logger.info("Writing alignments to: '%s'", self._file) logger.info("Writing alignments to: '%s'", self._file)
data = dict(__meta__=dict(version=self._version), data = {"__meta__": {"version": self._version},
__data__=self._alignments.data) "__data__": self._alignments.data}
self._serializer.save(self._file, data) self._serializer.save(self._file, data)
logger.debug("Saved alignments") logger.debug("Saved alignments")
@ -928,7 +922,7 @@ class _FileStructure(_Updater):
for key, val in self._alignments.data.items(): for key, val in self._alignments.data.items():
if not isinstance(val, list): if not isinstance(val, list):
continue continue
self._alignments.data[key] = dict(faces=val) self._alignments.data[key] = {"faces": val}
updated += 1 updated += 1
return updated return updated
@ -1078,11 +1072,11 @@ class _Legacy():
""" """
def __init__(self, alignments: Alignments) -> None: def __init__(self, alignments: Alignments) -> None:
self._alignments = alignments self._alignments = alignments
self._hashes_to_frame: Dict[str, Dict[str, int]] = {} self._hashes_to_frame: dict[str, dict[str, int]] = {}
self._hashes_to_alignment: Dict[str, AlignmentFileDict] = {} self._hashes_to_alignment: dict[str, AlignmentFileDict] = {}
@property @property
def hashes_to_frame(self) -> Dict[str, Dict[str, int]]: def hashes_to_frame(self) -> dict[str, dict[str, int]]:
""" dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
that the hash corresponds to. The structure of the dictionary is: that the hash corresponds to. The structure of the dictionary is:
@ -1105,7 +1099,7 @@ class _Legacy():
return self._hashes_to_frame return self._hashes_to_frame
@property @property
def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]: def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
""" dict: The SHA1 hash of the face mapped to the alignment for the face that the hash """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
corresponds to. The structure of the dictionary is: corresponds to. The structure of the dictionary is:
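The `T.cast(AlignmentDict, {})` default that recurs through these hunks exists purely for the type checker: an empty dict literal does not satisfy a TypedDict-valued mapping, so it is cast before being handed to `dict.get`. A minimal sketch of the pattern, assuming a simplified stand-in for the real `AlignmentDict`:

import typing as T


class AlignmentDict(T.TypedDict):
    """ Simplified stand-in for faceswap's AlignmentDict """
    faces: list[dict]
    video_meta: dict


_data: dict[str, AlignmentDict] = {}

# cast() has no runtime cost; it only tells the type checker the empty
# dict is an acceptable AlignmentDict default
frame_data = _data.get("frame_0001.png", T.cast(AlignmentDict, {}))
print(bool(frame_data.get("faces", [])))  # False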
View file
@@ -1,12 +1,11 @@
 #!/usr/bin python3
 """ Face and landmarks detection for faceswap.py """
+from __future__ import annotations
 import logging
-import sys
 import os
+import typing as T
 from hashlib import sha1
-from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 from zlib import compress, decompress

 import cv2
@@ -18,14 +17,10 @@ from .alignments import (Alignments, AlignmentFileDict, MaskAlignmentsFileDict,
                          PNGHeaderAlignmentsDict, PNGHeaderDict, PNGHeaderSourceDict)
 from . import AlignedFace, get_adjusted_center, get_centered_size

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+    from collections.abc import Callable
     from .aligned_face import CenteringType

-if sys.version_info < (3, 8):
-    from typing_extensions import Literal
-else:
-    from typing import Literal
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
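The rewritten header above is the template applied across the whole rebase: `from __future__ import annotations` plus a single `import typing as T`, with built-in generics (PEP 585) and `X | None` unions (PEP 604) replacing `List`/`Dict`/`Optional`/`Union`, and type-only imports fenced behind `T.TYPE_CHECKING`. A minimal sketch of the style, using only the standard library:

from __future__ import annotations  # annotations are no longer evaluated at runtime

import typing as T

if T.TYPE_CHECKING:  # imported for annotations only; skipped at runtime
    from collections.abc import Callable


def apply(func: Callable[[int], int], values: list[int] | None) -> dict[int, int]:
    """ Old spelling: Callable[[int], int], Optional[List[int]], Dict[int, int] """
    return {val: func(val) for val in (values or [])}


print(apply(lambda x: x * 2, [1, 2]))  # {1: 2, 2: 4}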
@@ -85,14 +80,14 @@ class DetectedFace():
         dict of {**name** (`str`): :class:`Mask`}.
     """
     def __init__(self,
-                 image: Optional[np.ndarray] = None,
-                 left: Optional[int] = None,
-                 width: Optional[int] = None,
-                 top: Optional[int] = None,
-                 height: Optional[int] = None,
-                 landmarks_xy: Optional[np.ndarray] = None,
-                 mask: Optional[Dict[str, "Mask"]] = None,
-                 filename: Optional[str] = None) -> None:
+                 image: np.ndarray | None = None,
+                 left: int | None = None,
+                 width: int | None = None,
+                 top: int | None = None,
+                 height: int | None = None,
+                 landmarks_xy: np.ndarray | None = None,
+                 mask: dict[str, "Mask"] | None = None,
+                 filename: str | None = None) -> None:
         logger.trace("Initializing %s: (image: %s, left: %s, width: %s, top: %s, "  # type: ignore
                      "height: %s, landmarks_xy: %s, mask: %s, filename: %s)",
                      self.__class__.__name__,
@@ -104,12 +99,12 @@ class DetectedFace():
         self.top = top
         self.height = height
         self._landmarks_xy = landmarks_xy
-        self._identity: Dict[str, np.ndarray] = {}
-        self.thumbnail: Optional[np.ndarray] = None
+        self._identity: dict[str, np.ndarray] = {}
+        self.thumbnail: np.ndarray | None = None
         self.mask = {} if mask is None else mask
-        self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None
-        self._aligned: Optional[AlignedFace] = None
+        self._training_masks: tuple[bytes, tuple[int, int, int]] | None = None
+        self._aligned: AlignedFace | None = None
         logger.trace("Initialized %s", self.__class__.__name__)  # type: ignore

     @property
@@ -137,7 +132,7 @@ class DetectedFace():
         return self.top + self.height

     @property
-    def identity(self) -> Dict[str, np.ndarray]:
+    def identity(self) -> dict[str, np.ndarray]:
         """ dict: Identity mechanism as key, identity embedding as value. """
         return self._identity
@@ -147,7 +142,7 @@ class DetectedFace():
                  affine_matrix: np.ndarray,
                  interpolator: int,
                  storage_size: int = 128,
-                 storage_centering: "CenteringType" = "face") -> None:
+                 storage_centering: CenteringType = "face") -> None:
         """ Add a :class:`Mask` to this detected face

         The mask should be the original output from :mod:`plugins.extract.mask`
@@ -209,7 +204,7 @@ class DetectedFace():
         self._identity[name] = embedding

     def get_landmark_mask(self,
-                          area: Literal["eye", "face", "mouth"],
+                          area: T.Literal["eye", "face", "mouth"],
                           blur_kernel: int,
                           dilation: int) -> np.ndarray:
         """ Add a :class:`LandmarksMask` to this detected face
@@ -235,7 +230,7 @@ class DetectedFace():
         """
         # TODO Face mask generation from landmarks
         logger.trace("area: %s, dilation: %s", area, dilation)  # type: ignore
-        areas = dict(mouth=[slice(48, 60)], eye=[slice(36, 42), slice(42, 48)])
+        areas = {"mouth": [slice(48, 60)], "eye": [slice(36, 42), slice(42, 48)]}
         points = [self.aligned.landmarks[zone]
                   for zone in areas[area]]
@@ -250,7 +245,7 @@ class DetectedFace():
         return lmmask.mask

     def store_training_masks(self,
-                             masks: List[Optional[np.ndarray]],
+                             masks: list[np.ndarray | None],
                              delete_masks: bool = False) -> None:
         """ Concatenate and compress the given training masks and store for retrieval.
@@ -273,7 +268,7 @@ class DetectedFace():
         combined = np.concatenate(valid, axis=-1)
         self._training_masks = (compress(combined), combined.shape)

-    def get_training_masks(self) -> Optional[np.ndarray]:
+    def get_training_masks(self) -> np.ndarray | None:
         """ Obtain the decompressed combined training masks.

         Returns
@@ -312,7 +307,7 @@ class DetectedFace():
         return alignment

     def from_alignment(self, alignment: AlignmentFileDict,
-                       image: Optional[np.ndarray] = None, with_thumb: bool = False) -> None:
+                       image: np.ndarray | None = None, with_thumb: bool = False) -> None:
         """ Set the attributes of this class from an alignments file and optionally load the face
         into the ``image`` attribute.
@@ -342,7 +337,7 @@ class DetectedFace():
         landmarks = alignment["landmarks_xy"]
         if not isinstance(landmarks, np.ndarray):
             landmarks = np.array(landmarks, dtype="float32")
-        self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32")
+        self._identity = {T.cast(T.Literal["vggface2"], k): np.array(v, dtype="float32")
                           for k, v in alignment.get("identity", {}).items()}
         self._landmarks_xy = landmarks.copy()
@@ -403,7 +398,7 @@ class DetectedFace():
         self._identity = {}
         for key, val in alignment.get("identity", {}).items():
             assert key in ["vggface2"]
-            self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32")
+            self._identity[T.cast(T.Literal["vggface2"], key)] = np.array(val, dtype="float32")
         logger.trace("Created from png exif header: (left: %s, width: %s, top: %s "  # type: ignore
                      " height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width,
                      self.top, self.height, self.landmarks_xy, self.mask,
@@ -417,10 +412,10 @@ class DetectedFace():
     # <<< Aligned Face methods and properties >>> #
     def load_aligned(self,
-                     image: Optional[np.ndarray],
+                     image: np.ndarray | None,
                      size: int = 256,
-                     dtype: Optional[str] = None,
-                     centering: "CenteringType" = "head",
+                     dtype: str | None = None,
+                     centering: CenteringType = "head",
                      coverage_ratio: float = 1.0,
                      force: bool = False,
                      is_aligned: bool = False,
@@ -507,22 +502,22 @@ class Mask():
     """
     def __init__(self,
                  storage_size: int = 128,
-                 storage_centering: "CenteringType" = "face") -> None:
+                 storage_centering: CenteringType = "face") -> None:
         logger.trace("Initializing: %s (storage_size: %s, storage_centering: %s)",  # type: ignore
                      self.__class__.__name__, storage_size, storage_centering)
         self.stored_size = storage_size
         self.stored_centering = storage_centering

-        self._mask: Optional[bytes] = None
-        self._affine_matrix: Optional[np.ndarray] = None
-        self._interpolator: Optional[int] = None
+        self._mask: bytes | None = None
+        self._affine_matrix: np.ndarray | None = None
+        self._interpolator: int | None = None

-        self._blur_type: Optional[Literal["gaussian", "normalized"]] = None
+        self._blur_type: T.Literal["gaussian", "normalized"] | None = None
         self._blur_passes: int = 0
-        self._blur_kernel: Union[float, int] = 0
+        self._blur_kernel: float | int = 0
         self._threshold = 0.0

         self._sub_crop_size = 0
-        self._sub_crop_slices: Dict[Literal["in", "out"], List[slice]] = {}
+        self._sub_crop_slices: dict[T.Literal["in", "out"], list[slice]] = {}

         self.set_blur_and_threshold()
         logger.trace("Initialized: %s", self.__class__.__name__)  # type: ignore
@@ -648,7 +643,7 @@ class Mask():
     def set_blur_and_threshold(self,
                                blur_kernel: int = 0,
-                               blur_type: Optional[Literal["gaussian", "normalized"]] = "gaussian",
+                               blur_type: T.Literal["gaussian", "normalized"] | None = "gaussian",
                                blur_passes: int = 1,
                                threshold: int = 0) -> None:
         """ Set the internal blur kernel and threshold amount for returned masks
@@ -679,7 +674,7 @@ class Mask():
     def set_sub_crop(self,
                      source_offset: np.ndarray,
                      target_offset: np.ndarray,
-                     centering: "CenteringType",
+                     centering: CenteringType,
                      coverage_ratio: float = 1.0) -> None:
         """ Set the internal crop area of the mask to be returned.
@@ -831,9 +826,9 @@ class LandmarksMask(Mask):
         The amount of dilation to apply to the mask. `0` for none. Default: `0`
     """
     def __init__(self,
-                 points: List[np.ndarray],
+                 points: list[np.ndarray],
                  storage_size: int = 128,
-                 storage_centering: "CenteringType" = "face",
+                 storage_centering: CenteringType = "face",
                  dilation: int = 0) -> None:
         super().__init__(storage_size=storage_size, storage_centering=storage_centering)
         self._points = points
@@ -907,9 +902,9 @@ class BlurMask():  # pylint:disable=too-few-public-methods
     (128, 128, 1)
     """
     def __init__(self,
-                 blur_type: Literal["gaussian", "normalized"],
+                 blur_type: T.Literal["gaussian", "normalized"],
                  mask: np.ndarray,
-                 kernel: Union[int, float],
+                 kernel: int | float,
                  is_ratio: bool = False,
                  passes: int = 1) -> None:
         logger.trace("Initializing %s: (blur_type: '%s', mask_shape: %s, "  # type: ignore
@@ -943,33 +938,30 @@ class BlurMask():  # pylint:disable=too-few-public-methods
     def _multipass_factor(self) -> float:
         """ For multiple passes the kernel must be scaled down. This value is
         different for box filter and gaussian """
-        factor = dict(gaussian=0.8, normalized=0.5)
+        factor = {"gaussian": 0.8, "normalized": 0.5}
         return factor[self._blur_type]

     @property
-    def _sigma(self) -> Literal[0]:
+    def _sigma(self) -> T.Literal[0]:
         """ int: The Sigma for Gaussian Blur. Returns 0 to force calculation from kernel size. """
         return 0

     @property
-    def _func_mapping(self) -> Dict[Literal["gaussian", "normalized"], Callable]:
+    def _func_mapping(self) -> dict[T.Literal["gaussian", "normalized"], Callable]:
         """ dict: :attr:`_blur_type` mapped to cv2 Function name. """
-        return dict(gaussian=cv2.GaussianBlur,  # pylint: disable = no-member
-                    normalized=cv2.blur)  # pylint: disable = no-member
+        return {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur}

     @property
-    def _kwarg_requirements(self) -> Dict[Literal["gaussian", "normalized"], List[str]]:
+    def _kwarg_requirements(self) -> dict[T.Literal["gaussian", "normalized"], list[str]]:
         """ dict: :attr:`_blur_type` mapped to cv2 Function required keyword arguments. """
-        return dict(gaussian=["ksize", "sigmaX"],
-                    normalized=["ksize"])
+        return {"gaussian": ['ksize', 'sigmaX'], "normalized": ['ksize']}

     @property
-    def _kwarg_mapping(self) -> Dict[str, Union[int, Tuple[int, int]]]:
+    def _kwarg_mapping(self) -> dict[str, int | tuple[int, int]]:
         """ dict: cv2 function keyword arguments mapped to their parameters. """
-        return dict(ksize=self._kernel_size,
-                    sigmaX=self._sigma)
+        return {"ksize": self._kernel_size, "sigmaX": self._sigma}

-    def _get_kernel_size(self, kernel: Union[int, float], is_ratio: bool) -> int:
+    def _get_kernel_size(self, kernel: int | float, is_ratio: bool) -> int:
         """ Set the kernel size to absolute value.

         If :attr:`is_ratio` is ``True`` then the kernel size is calculated from the given ratio and
@@ -999,7 +991,7 @@ class BlurMask():  # pylint:disable=too-few-public-methods
         return kernel_size

     @staticmethod
-    def _get_kernel_tuple(kernel_size: int) -> Tuple[int, int]:
+    def _get_kernel_tuple(kernel_size: int) -> tuple[int, int]:
         """ Make sure kernel_size is odd and return it as a tuple.

         Parameters
@@ -1017,7 +1009,7 @@ class BlurMask():  # pylint:disable=too-few-public-methods
         logger.trace(retval)  # type: ignore
         return retval

-    def _get_kwargs(self) -> Dict[str, Union[int, Tuple[int, int]]]:
+    def _get_kwargs(self) -> dict[str, int | tuple[int, int]]:
         """ dict: the valid keyword arguments for the requested :attr:`_blur_type` """
         retval = {kword: self._kwarg_mapping[kword]
                   for kword in self._kwarg_requirements[self._blur_type]}
@@ -1025,11 +1017,11 @@ class BlurMask():  # pylint:disable=too-few-public-methods
         return retval


-_HASHES_SEEN: Dict[str, Dict[str, int]] = {}
+_HASHES_SEEN: dict[str, dict[str, int]] = {}


 def update_legacy_png_header(filename: str, alignments: Alignments
-                             ) -> Optional[PNGHeaderDict]:
+                             ) -> PNGHeaderDict | None:
     """ Update a legacy extracted face from pre v2.1 alignments by placing the alignment data for
     the face in the png exif header for the given filename with the given alignment data.
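`BlurMask` above is driven by three small lookup tables: `_func_mapping` selects the cv2 function, `_kwarg_requirements` lists the keywords that function needs, and `_kwarg_mapping` supplies their values. A stripped-down sketch of that dispatch (fixed kernel, 2D mask; cv2 and numpy assumed installed):

import typing as T

import cv2
import numpy as np

_FUNCS: dict[str, T.Callable] = {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur}
_REQUIRED = {"gaussian": ["ksize", "sigmaX"], "normalized": ["ksize"]}
_VALUES: dict[str, int | tuple[int, int]] = {"ksize": (5, 5), "sigmaX": 0}


def blur(mask: np.ndarray, blur_type: T.Literal["gaussian", "normalized"]) -> np.ndarray:
    """ Pick the cv2 function and pass only the keyword arguments it requires """
    kwargs = {key: _VALUES[key] for key in _REQUIRED[blur_type]}
    return _FUNCS[blur_type](mask, **kwargs)


mask = np.zeros((128, 128), dtype="float32")
mask[32:96, 32:96] = 1.0
print(blur(mask, "normalized").mean())  # edges soften, interior stays at 1.0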
View file
@@ -7,7 +7,7 @@ as well as adding a mechanism for indicating to the GUI how specific options sho
 import argparse
 import os
-from typing import Any, List, Optional, Tuple, Union
+import typing as T

 # << FILE HANDLING >>
@@ -69,7 +69,7 @@ class FileFullPaths(_FullPaths):
     >>> filetypes="video))"
     """
     # pylint: disable=too-few-public-methods
-    def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
+    def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.filetypes = filetypes
@@ -111,7 +111,7 @@ class FilesFullPaths(FileFullPaths):  # pylint: disable=too-few-public-methods
     >>> filetypes="image",
     >>> nargs="+"))
     """
-    def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
+    def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
         if kwargs.get("nargs", None) is None:
             opt = kwargs["option_strings"]
             raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}")
@@ -250,8 +250,8 @@ class ContextFullPaths(FileFullPaths):
     # pylint: disable=too-few-public-methods, too-many-arguments
     def __init__(self,
                  *args,
-                 filetypes: Optional[str] = None,
-                 action_option: Optional[str] = None,
+                 filetypes: str | None = None,
+                 action_option: str | None = None,
                  **kwargs) -> None:
         opt = kwargs["option_strings"]
         if kwargs.get("nargs", None) is not None:
@@ -263,7 +263,7 @@ class ContextFullPaths(FileFullPaths):
         super().__init__(*args, filetypes=filetypes, **kwargs)
         self.action_option = action_option

-    def _get_kwargs(self) -> List[Tuple[str, Any]]:
+    def _get_kwargs(self) -> list[tuple[str, T.Any]]:
         names = ["option_strings",
                  "dest",
                  "nargs",
@@ -382,8 +382,8 @@ class Slider(argparse.Action):  # pylint: disable=too-few-public-methods
     """
     def __init__(self,
                  *args,
-                 min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
-                 rounding: Optional[int] = None,
+                 min_max: tuple[int, int] | tuple[float, float] | None = None,
+                 rounding: int | None = None,
                  **kwargs) -> None:
         opt = kwargs["option_strings"]
         if kwargs.get("nargs", None) is not None:
@@ -401,7 +401,7 @@ class Slider(argparse.Action):  # pylint: disable=too-few-public-methods
         self.min_max = min_max
         self.rounding = rounding

-    def _get_kwargs(self) -> List[Tuple[str, Any]]:
+    def _get_kwargs(self) -> list[tuple[str, T.Any]]:
         names = ["option_strings",
                  "dest",
                  "nargs",
View file
@@ -8,8 +8,7 @@ import logging
 import re
 import sys
 import textwrap
-
-from typing import Any, Dict, List, NoReturn, Optional
+import typing as T

 from lib.utils import get_backend
 from lib.gpu_stats import GPUStats
@@ -30,7 +29,7 @@ _ = _LANG.gettext
 class FullHelpArgumentParser(argparse.ArgumentParser):
     """ Extends :class:`argparse.ArgumentParser` to output full help on bad arguments. """
-    def error(self, message: str) -> NoReturn:
+    def error(self, message: str) -> T.NoReturn:
         self.print_help(sys.stderr)
         self.exit(2, f"{self.prog}: error: {message}\n")
@@ -51,11 +50,11 @@ class SmartFormatter(argparse.HelpFormatter):
                  prog: str,
                  indent_increment: int = 2,
                  max_help_position: int = 24,
-                 width: Optional[int] = None) -> None:
+                 width: int | None = None) -> None:
         super().__init__(prog, indent_increment, max_help_position, width)
         self._whitespace_matcher_limited = re.compile(r'[ \r\f\v]+', re.ASCII)

-    def _split_lines(self, text: str, width: int) -> List[str]:
+    def _split_lines(self, text: str, width: int) -> list[str]:
         """ Split the given text by the given display width.

         If the text is not prefixed with "R|" then the standard
@@ -138,7 +137,7 @@ class FaceSwapArgs():
         return ""

     @staticmethod
-    def get_argument_list() -> List[Dict[str, Any]]:
+    def get_argument_list() -> list[dict[str, T.Any]]:
         """ Returns the argument list for the current command.

         The argument list should be a list of dictionaries pertaining to each option for a command.
@@ -152,11 +151,11 @@ class FaceSwapArgs():
         list
             The list of command line options for the given command
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         return argument_list

     @staticmethod
-    def get_optional_arguments() -> List[Dict[str, Any]]:
+    def get_optional_arguments() -> list[dict[str, T.Any]]:
         """ Returns the optional argument list for the current command.

         The optional arguments list is not always required, but is used when there are shared
@@ -167,11 +166,11 @@ class FaceSwapArgs():
         list
             The list of optional command line options for the given command
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         return argument_list

     @staticmethod
-    def _get_global_arguments() -> List[Dict[str, Any]]:
+    def _get_global_arguments() -> list[dict[str, T.Any]]:
         """ Returns the global Arguments list that are required for ALL commands in Faceswap.

         This method should NOT be overridden.
@@ -181,7 +180,7 @@ class FaceSwapArgs():
         list
             The list of global command line options for all Faceswap commands.
         """
-        global_args: List[Dict[str, Any]] = []
+        global_args: list[dict[str, T.Any]] = []
         if _GPUS:
             global_args.append(dict(
                 opts=("-X", "--exclude-gpus"),
@@ -302,7 +301,7 @@ class ExtractConvertArgs(FaceSwapArgs):
     """

     @staticmethod
-    def get_argument_list() -> List[Dict[str, Any]]:
+    def get_argument_list() -> list[dict[str, T.Any]]:
         """ Returns the argument list for shared Extract and Convert arguments.

         Returns
@@ -310,7 +309,7 @@ class ExtractConvertArgs(FaceSwapArgs):
         list
             The list of command line options for the given Extract and Convert
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         argument_list.append(dict(
             opts=("-i", "--input-dir"),
             action=DirOrFileFullPaths,
@@ -362,7 +361,7 @@ class ExtractArgs(ExtractConvertArgs):
                 "Extraction plugins can be configured in the 'Settings' Menu")

     @staticmethod
-    def get_optional_arguments() -> List[Dict[str, Any]]:
+    def get_optional_arguments() -> list[dict[str, T.Any]]:
         """ Returns the argument list unique to the Extract command.

         Returns
@@ -377,7 +376,7 @@ class ExtractArgs(ExtractConvertArgs):
         default_detector = "s3fd"
         default_aligner = "fan"

-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         argument_list.append(dict(
             opts=("-b", "--batch-mode"),
             action="store_true",
@@ -658,7 +657,7 @@ class ConvertArgs(ExtractConvertArgs):
                 "Conversion plugins can be configured in the 'Settings' Menu")

     @staticmethod
-    def get_optional_arguments() -> List[Dict[str, Any]]:
+    def get_optional_arguments() -> list[dict[str, T.Any]]:
         """ Returns the argument list unique to the Convert command.

         Returns
@@ -667,7 +666,7 @@ class ConvertArgs(ExtractConvertArgs):
             The list of optional command line options for the Convert command
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         argument_list.append(dict(
             opts=("-ref", "--reference-video"),
             action=FileFullPaths,
@@ -915,7 +914,7 @@ class TrainArgs(FaceSwapArgs):
                 "Model plugins can be configured in the 'Settings' Menu")

     @staticmethod
-    def get_argument_list() -> List[Dict[str, Any]]:
+    def get_argument_list() -> list[dict[str, T.Any]]:
         """ Returns the argument list for Train arguments.

         Returns
@@ -923,7 +922,7 @@ class TrainArgs(FaceSwapArgs):
         list
             The list of command line options for training
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         argument_list.append(dict(
             opts=("-A", "--input-A"),
             action=DirFullPaths,
@@ -1180,7 +1179,7 @@ class GuiArgs(FaceSwapArgs):
     """ Creates the command line arguments for the GUI. """

     @staticmethod
-    def get_argument_list() -> List[Dict[str, Any]]:
+    def get_argument_list() -> list[dict[str, T.Any]]:
         """ Returns the argument list for GUI arguments.

         Returns
@@ -1188,7 +1187,7 @@ class GuiArgs(FaceSwapArgs):
         list
             The list of command line options for the GUI
         """
-        argument_list: List[Dict[str, Any]] = []
+        argument_list: list[dict[str, T.Any]] = []
         argument_list.append(dict(
             opts=("-d", "--debug"),
             action="store_true",
View file
@@ -1,20 +1,22 @@
 #!/usr/bin/env python3
 """ Launches the correct script with the given Command Line Arguments """
+from __future__ import annotations
 import logging
 import os
 import platform
 import sys
+import typing as T
 from importlib import import_module
-from typing import Callable, TYPE_CHECKING

 from lib.gpu_stats import set_exclude_devices, GPUStats
 from lib.logger import crash_log, log_setup
 from lib.utils import (FaceswapError, get_backend, get_tf_version,
                        safe_shutdown, set_backend, set_system_verbosity)

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     import argparse
+    from collections.abc import Callable

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -99,8 +101,7 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
         FaceswapError
             If Tensorflow is not found, or is not between versions 2.4 and 2.9
         """
-        directml_ver = rocm_ver = (2, 10)
-        min_ver = (2, 7)
+        min_ver = (2, 10)
         max_ver = (2, 10)
         try:
             import tensorflow as tf  # noqa pylint:disable=import-outside-toplevel,unused-import
@@ -120,7 +121,6 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
             self._handle_import_error(msg)

         tf_ver = get_tf_version()
-        backend = get_backend()
         if tf_ver < min_ver:
             msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
                    f"{tf_ver} installed. Please upgrade Tensorflow.")
@@ -129,14 +129,6 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
             msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
                    f"{tf_ver} installed. Please downgrade Tensorflow.")
             self._handle_import_error(msg)
-        if backend == "directml" and tf_ver != directml_ver:
-            msg = (f"The supported Tensorflow version for DirectML cards is {directml_ver} but "
-                   f"you have version {tf_ver} installed. Please install the correct version.")
-            self._handle_import_error(msg)
-        if backend == "rocm" and tf_ver != rocm_ver:
-            msg = (f"The supported Tensorflow version for ROCm cards is {rocm_ver} but "
-                   f"you have version {tf_ver} installed. Please install the correct version.")
-            self._handle_import_error(msg)
         logger.debug("Installed Tensorflow Version: %s", tf_ver)

     @classmethod
@@ -209,7 +201,7 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
                 "See https://support.apple.com/en-gb/HT201341")
             raise FaceswapError("No display detected. GUI mode has been disabled.")

-    def execute_script(self, arguments: "argparse.Namespace") -> None:
+    def execute_script(self, arguments: argparse.Namespace) -> None:
         """ Performs final set up and launches the requested :attr:`_command` with the given
         command line arguments.
@@ -250,7 +242,7 @@ class ScriptExecutor():  # pylint:disable=too-few-public-methods
         finally:
             safe_shutdown(got_error=not success)

-    def _configure_backend(self, arguments: "argparse.Namespace") -> None:
+    def _configure_backend(self, arguments: argparse.Namespace) -> None:
         """ Configure the backend.

         Exclude any GPUs for use by Faceswap when requested.
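With every backend now pinned to the same Tensorflow release, the launcher's check collapses to a single min/max gate on a version tuple, relying on Python's element-wise tuple comparison. A minimal sketch of that comparison, with the faceswap helper stubbed out:

def get_tf_version() -> tuple[int, int]:
    """ Stub: the real helper parses tensorflow.__version__ """
    return (2, 10)


min_ver = (2, 10)
max_ver = (2, 10)
tf_ver = get_tf_version()

if tf_ver < min_ver:  # tuples compare element-wise: (2, 9) < (2, 10) is True
    raise RuntimeError(f"Minimum supported Tensorflow is {min_ver}, found {tf_ver}")
if tf_ver > max_ver:
    raise RuntimeError(f"Maximum supported Tensorflow is {max_ver}, found {tf_ver}")
print(f"Tensorflow {tf_ver} accepted")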
View file
@@ -13,7 +13,6 @@ from collections import OrderedDict
 from configparser import ConfigParser
 from dataclasses import dataclass
 from importlib import import_module
-from typing import Dict, List, Optional, Tuple, Union

 from lib.utils import full_path_split
@@ -21,16 +20,11 @@ from lib.utils import full_path_split
 _LANG = gettext.translation("lib.config", localedir="locales", fallback=True)
 _ = _LANG.gettext

-# Can't type OrderedDict fully on Python 3.8 or lower
-if sys.version_info < (3, 9):
-    OrderedDictSectionType = OrderedDict
-    OrderedDictItemType = OrderedDict
-else:
-    OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
-    OrderedDictItemType = OrderedDict[str, "ConfigItem"]
+OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
+OrderedDictItemType = OrderedDict[str, "ConfigItem"]

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

-ConfigValueType = Union[bool, int, float, List[str], str, None]
+ConfigValueType = bool | int | float | list[str] | str | None

 @dataclass
@@ -60,11 +54,11 @@ class ConfigItem:
     helptext: str
     datatype: type
     rounding: int
-    min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]]
-    choices: Union[str, List[str]]
+    min_max: tuple[int, int] | tuple[float, float] | None
+    choices: str | list[str]
     gui_radio: bool
     fixed: bool
-    group: Optional[str]
+    group: str | None

 @dataclass
@@ -84,7 +78,7 @@ class ConfigSection:
 class FaceswapConfig():
     """ Config Items """
-    def __init__(self, section: Optional[str], configfile: Optional[str] = None) -> None:
+    def __init__(self, section: str | None, configfile: str | None = None) -> None:
         """ Init Configuration

         Parameters
@@ -106,11 +100,11 @@ class FaceswapConfig():
         logger.debug("Initialized: %s", self.__class__.__name__)

     @property
-    def changeable_items(self) -> Dict[str, ConfigValueType]:
+    def changeable_items(self) -> dict[str, ConfigValueType]:
         """ Training only.

         Return a dict of config items with their set values for items
         that can be altered after the model has been created """
-        retval: Dict[str, ConfigValueType] = {}
+        retval: dict[str, ConfigValueType] = {}
         sections = [sect for sect in self.config.sections() if sect.startswith("global")]
         all_sections = sections if self.section is None else sections + [self.section]
         for sect in all_sections:
@@ -189,10 +183,10 @@ class FaceswapConfig():
         logger.debug("Added defaults: %s", section)

     @property
-    def config_dict(self) -> Dict[str, ConfigValueType]:
+    def config_dict(self) -> dict[str, ConfigValueType]:
         """ dict: Collate global options and requested section into a dictionary with the correct
         data types """
-        conf: Dict[str, ConfigValueType] = {}
+        conf: dict[str, ConfigValueType] = {}
         sections = [sect for sect in self.config.sections() if sect.startswith("global")]
         if self.section is not None:
             sections.append(self.section)
@@ -240,7 +234,7 @@ class FaceswapConfig():
         logger.debug("Returning item: (type: %s, value: %s)", datatype, retval)
         return retval

-    def _parse_list(self, section: str, option: str) -> List[str]:
+    def _parse_list(self, section: str, option: str) -> list[str]:
         """ Parse options that are stored as lists in the config file. These can be space or
         comma-separated items in the config file. They will be returned as a list of strings,
         regardless of what the final data type should be, so conversion from strings to other
@@ -268,7 +262,7 @@ class FaceswapConfig():
                      raw_option, retval, section, option)
         return retval

-    def _get_config_file(self, configfile: Optional[str]) -> str:
+    def _get_config_file(self, configfile: str | None) -> str:
         """ Return the config file from the calling folder or the provided file

         Parameters
@@ -309,17 +303,17 @@ class FaceswapConfig():
         self.defaults[title] = ConfigSection(helptext=info, items=OrderedDict())

     def add_item(self,
-                 section: Optional[str] = None,
-                 title: Optional[str] = None,
+                 section: str | None = None,
+                 title: str | None = None,
                  datatype: type = str,
                  default: ConfigValueType = None,
-                 info: Optional[str] = None,
-                 rounding: Optional[int] = None,
-                 min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
-                 choices: Optional[Union[str, List[str]]] = None,
+                 info: str | None = None,
+                 rounding: int | None = None,
+                 min_max: tuple[int, int] | tuple[float, float] | None = None,
+                 choices: str | list[str] | None = None,
                  gui_radio: bool = False,
                  fixed: bool = True,
-                 group: Optional[str] = None) -> None:
+                 group: str | None = None) -> None:
         """ Add a default item to a config section

         For int or float values, rounding and min_max must be set
@@ -382,10 +376,10 @@ class FaceswapConfig():
     @classmethod
     def _expand_helptext(cls,
                          helptext: str,
-                         choices: Union[str, List[str]],
+                         choices: str | list[str],
                          default: ConfigValueType,
                          datatype: type,
-                         min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]],
+                         min_max: tuple[int, int] | tuple[float, float] | None,
                          fixed: bool) -> str:
         """ Add extra helptext info from parameters """
         helptext += "\n"
@@ -437,7 +431,7 @@ class FaceswapConfig():
     def insert_config_section(self,
                               section: str,
                               helptext: str,
-                              config: Optional[ConfigParser] = None) -> None:
+                              config: ConfigParser | None = None) -> None:
         """ Insert a section into the config

         Parameters
@@ -464,7 +458,7 @@ class FaceswapConfig():
                             item: str,
                             default: ConfigValueType,
                             option: ConfigItem,
-                            config: Optional[ConfigParser] = None) -> None:
+                            config: ConfigParser | None = None) -> None:
         """ Insert an item into a config section

         Parameters
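Dropping the pre-3.9 branch lets the module parametrise `OrderedDict` unconditionally, and the new 3.10 floor allows `ConfigValueType` to be a runtime `|` union. A small sketch of both, with the dataclass cut down from the real `ConfigItem`:

from collections import OrderedDict
from dataclasses import dataclass

# PEP 604 unions of runtime types require Python >= 3.10
ConfigValueType = bool | int | float | list[str] | str | None


@dataclass
class Item:
    """ Cut-down stand-in for faceswap's ConfigItem """
    default: ConfigValueType
    helptext: str


# Subscripting OrderedDict at runtime requires Python >= 3.9
ItemMap = OrderedDict[str, Item]

items: ItemMap = OrderedDict()
items["mask_type"] = Item(default="extended", helptext="The mask to train with")
print(items["mask_type"].default)  # extended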
View file
@@ -1,23 +1,18 @@
 #!/usr/bin/env python3
 """ Converter for Faceswap """
+from __future__ import annotations
 import logging
-import sys
+import typing as T
 from dataclasses import dataclass
-from typing import Callable, cast, List, Optional, Tuple, TYPE_CHECKING, Union

 import cv2
 import numpy as np

 from plugins.plugin_loader import PluginLoader

-if sys.version_info < (3, 8):
-    from typing_extensions import Literal
-else:
-    from typing import Literal
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from argparse import Namespace
+    from collections.abc import Callable

     from lib.align.aligned_face import AlignedFace, CenteringType
     from lib.align.detected_face import DetectedFace
     from lib.config import FaceswapConfig
@@ -46,10 +41,10 @@ class Adjustments:
     sharpening: :class:`~plugins.scaling._base.Adjustment`, Optional
         The selected mask processing plugin. Default: `None`
     """
-    color: Optional["ColorAdjust"] = None
-    mask: Optional["MaskAdjust"] = None
-    seamless: Optional["SeamlessAdjust"] = None
-    sharpening: Optional["ScalingAdjust"] = None
+    color: ColorAdjust | None = None
+    mask: MaskAdjust | None = None
+    seamless: SeamlessAdjust | None = None
+    sharpening: ScalingAdjust | None = None


 class Converter():
@@ -81,11 +76,11 @@ class Converter():
     def __init__(self,
                  output_size: int,
                  coverage_ratio: float,
-                 centering: "CenteringType",
+                 centering: CenteringType,
                  draw_transparent: bool,
-                 pre_encode: Optional[Callable[[np.ndarray], List[bytes]]],
-                 arguments: "Namespace",
-                 configfile: Optional[str] = None) -> None:
+                 pre_encode: Callable[[np.ndarray], list[bytes]] | None,
+                 arguments: Namespace,
+                 configfile: str | None = None) -> None:
         logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, "
                      "draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)",
                      self.__class__.__name__, output_size, coverage_ratio, centering,
@@ -105,12 +100,12 @@ class Converter():
         logger.debug("Initialized %s", self.__class__.__name__)

     @property
-    def cli_arguments(self) -> "Namespace":
+    def cli_arguments(self) -> Namespace:
         """:class:`argparse.Namespace`: The command line arguments passed to the convert
         process """
         return self._args

-    def reinitialize(self, config: "FaceswapConfig") -> None:
+    def reinitialize(self, config: FaceswapConfig) -> None:
         """ Reinitialize this :class:`Converter`.

         Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the
@@ -127,7 +122,7 @@ class Converter():
         logger.debug("Reinitialized converter")

     def _load_plugins(self,
-                      config: Optional["FaceswapConfig"] = None,
+                      config: FaceswapConfig | None = None,
                       disable_logging: bool = False) -> None:
         """ Load the requested adjustment plugins.
@@ -169,7 +164,7 @@ class Converter():
         self._adjustments.sharpening = sharpening
         logger.debug("Loaded plugins: %s", self._adjustments)

-    def process(self, in_queue: "EventQueue", out_queue: "EventQueue"):
+    def process(self, in_queue: EventQueue, out_queue: EventQueue):
         """ Main convert process.

         Takes items from the in queue, runs the relevant adjustments, patches faces to final frame
@@ -188,7 +183,7 @@ class Converter():
                      in_queue, out_queue)
         log_once = False
         while True:
-            inbound: Union[Literal["EOF"], "ConvertItem", List["ConvertItem"]] = in_queue.get()
+            inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get()
             if inbound == "EOF":
                 logger.debug("EOF Received")
                 logger.debug("Patch queue finished")
@@ -218,7 +213,7 @@ class Converter():
             out_queue.put((item.inbound.filename, image))
         logger.debug("Completed convert process")

-    def _patch_image(self, predicted: "ConvertItem") -> Union[np.ndarray, List[bytes]]:
+    def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]:
         """ Patch a swapped face onto a frame.

         Run selected adjustments and swap the faces in a frame.
@@ -246,15 +241,15 @@ class Converter():
                                  out=np.empty(patched_face.shape, dtype="uint8"),
                                  casting='unsafe')
         if self._writer_pre_encode is None:
-            retval: Union[np.ndarray, List[bytes]] = patched_face
+            retval: np.ndarray | list[bytes] = patched_face
         else:
             retval = self._writer_pre_encode(patched_face)
         logger.trace("Patched image: '%s'", predicted.inbound.filename)  # type: ignore
         return retval

     def _get_new_image(self,
-                       predicted: "ConvertItem",
-                       frame_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
+                       predicted: ConvertItem,
+                       frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]:
         """ Get the new face from the predictor and apply pre-warp manipulations.

         Applies any requested adjustments to the raw output of the Faceswap model
@@ -308,9 +303,9 @@ class Converter():
     def _pre_warp_adjustments(self,
                               new_face: np.ndarray,
-                              detected_face: "DetectedFace",
-                              reference_face: "AlignedFace",
-                              predicted_mask: Optional[np.ndarray]) -> np.ndarray:
+                              detected_face: DetectedFace,
+                              reference_face: AlignedFace,
+                              predicted_mask: np.ndarray | None) -> np.ndarray:
         """ Run any requested adjustments that can be performed on the raw output from the Faceswap
         model.
@@ -337,7 +332,7 @@ class Converter():
         """
         logger.trace("new_face shape: %s, predicted_mask shape: %s",  # type: ignore
                      new_face.shape, predicted_mask.shape if predicted_mask is not None else None)
-        old_face = cast(np.ndarray, reference_face.face)[..., :3] / 255.0
+        old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0
         new_face, raw_mask = self._get_image_mask(new_face,
                                                   detected_face,
                                                   predicted_mask,
@@ -351,9 +346,9 @@ class Converter():
     def _get_image_mask(self,
                         new_face: np.ndarray,
-                        detected_face: "DetectedFace",
-                        predicted_mask: Optional[np.ndarray],
-                        reference_face: "AlignedFace") -> Tuple[np.ndarray, np.ndarray]:
+                        detected_face: DetectedFace,
+                        predicted_mask: np.ndarray | None,
+                        reference_face: AlignedFace) -> tuple[np.ndarray, np.ndarray]:
         """ Return any selected image mask

         Places the requested mask into the new face's Alpha channel.
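The `process` loop above types its queue items as `T.Literal["EOF"] | ConvertItem | list[ConvertItem]`: a plain string sentinel terminates the stream, and everything else is convert work. A reduced sketch of the same consume-until-sentinel pattern on a standard queue (item type simplified to `str`):

from __future__ import annotations

import queue
import typing as T


def process(in_queue: queue.Queue) -> None:
    while True:
        inbound: T.Literal["EOF"] | list[str] = in_queue.get()
        if inbound == "EOF":  # string sentinel ends the stream
            print("EOF received, patch queue finished")
            break
        for item in inbound:
            print(f"patched: {item}")


work: queue.Queue = queue.Queue()
work.put(["frame_0001.png", "frame_0002.png"])
work.put("EOF")
process(work)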
View file
@ -5,11 +5,10 @@ from the :class:`_GPUStats` class contained here. """
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional
from lib.utils import get_backend from lib.utils import get_backend
_EXCLUDE_DEVICES: List[int] = [] _EXCLUDE_DEVICES: list[int] = []
@dataclass @dataclass
@ -29,11 +28,11 @@ class GPUInfo():
devices_active: list[int] devices_active: list[int]
List of integers representing the indices of the active GPU devices. List of integers representing the indices of the active GPU devices.
""" """
vram: List[int] vram: list[int]
vram_free: List[int] vram_free: list[int]
driver: str driver: str
devices: List[str] devices: list[str]
devices_active: List[int] devices_active: list[int]
@dataclass @dataclass
@ -57,7 +56,7 @@ class BiggestGPUInfo():
total: float total: float
def set_exclude_devices(devices: List[int]) -> None: def set_exclude_devices(devices: list[int]) -> None:
""" Add any explicitly selected GPU devices to the global list of devices to be excluded """ Add any explicitly selected GPU devices to the global list of devices to be excluded
from use by Faceswap. from use by Faceswap.
@ -89,19 +88,19 @@ class _GPUStats():
def __init__(self, log: bool = True) -> None: def __init__(self, log: bool = True) -> None:
# Logger is held internally, as we don't want to log when obtaining system stats on crash # Logger is held internally, as we don't want to log when obtaining system stats on crash
# or when querying the backend for command line options # or when querying the backend for command line options
self._logger: Optional[logging.Logger] = logging.getLogger(__name__) if log else None self._logger: logging.Logger | None = logging.getLogger(__name__) if log else None
self._log("debug", f"Initializing {self.__class__.__name__}") self._log("debug", f"Initializing {self.__class__.__name__}")
self._is_initialized = False self._is_initialized = False
self._initialize() self._initialize()
self._device_count: int = self._get_device_count() self._device_count: int = self._get_device_count()
self._active_devices: List[int] = self._get_active_devices() self._active_devices: list[int] = self._get_active_devices()
self._handles: list = self._get_handles() self._handles: list = self._get_handles()
self._driver: str = self._get_driver() self._driver: str = self._get_driver()
self._device_names: List[str] = self._get_device_names() self._device_names: list[str] = self._get_device_names()
self._vram: List[int] = self._get_vram() self._vram: list[int] = self._get_vram()
self._vram_free: List[int] = self._get_free_vram() self._vram_free: list[int] = self._get_free_vram()
if get_backend() != "cpu" and not self._active_devices: if get_backend() != "cpu" and not self._active_devices:
self._log("warning", "No GPU detected") self._log("warning", "No GPU detected")
@ -115,7 +114,7 @@ class _GPUStats():
return self._device_count return self._device_count
@property @property
def cli_devices(self) -> List[str]: def cli_devices(self) -> list[str]:
""" list[str]: Formatted index: name text string for each GPU """ """ list[str]: Formatted index: name text string for each GPU """
return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)] return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)]
@ -167,7 +166,7 @@ class _GPUStats():
""" """
raise NotImplementedError() raise NotImplementedError()
def _get_active_devices(self) -> List[int]: def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded in """ Obtain the indices of active GPUs (those that have not been explicitly excluded in
the command line arguments). the command line arguments).
@ -204,7 +203,7 @@ class _GPUStats():
""" """
raise NotImplementedError() raise NotImplementedError()
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Override to obtain the names of all connected GPUs. The quality of this information """ Override to obtain the names of all connected GPUs. The quality of this information
depends on the backend and OS being used, but it should be sufficient for identifying depends on the backend and OS being used, but it should be sufficient for identifying
cards. cards.
@ -217,7 +216,7 @@ class _GPUStats():
""" """
raise NotImplementedError() raise NotImplementedError()
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Override to obtain the total VRAM in Megabytes for each connected GPU. """ Override to obtain the total VRAM in Megabytes for each connected GPU.
Returns Returns
@ -228,7 +227,7 @@ class _GPUStats():
""" """
raise NotImplementedError() raise NotImplementedError()
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Override to obtain the amount of VRAM that is available, in Megabytes, for each """ Override to obtain the amount of VRAM that is available, in Megabytes, for each
connected GPU. connected GPU.
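The backend implementations that follow all fill in the same overrides. As a hedged, condensed sketch (method names and signatures are taken from the constructor and override stubs above; the exact base-class contract, e.g. what `_initialize` must set, is assumed), a minimal backend looks like:

    from ._base import _GPUStats

    class DummyStats(_GPUStats):
        """ Illustrative only: reports a single fake device. """

        def _initialize(self) -> None:
            self._is_initialized = True  # assumption: subclasses flag completion here

        def _get_device_count(self) -> int:
            return 1

        def _get_handles(self) -> list:
            return [0]

        def _get_driver(self) -> str:
            return "dummy-1.0"

        def _get_device_names(self) -> list[str]:
            return ["Dummy GPU"]

        def _get_vram(self) -> list[int]:
            return [8192]   # Megabytes

        def _get_free_vram(self) -> list[int]:
            return [4096]   # Megabytes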
View file
@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """ """ Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """
from typing import Any, List import typing as T
import os import os
import psutil import psutil
@ -35,7 +35,7 @@ class AppleSiliconStats(_GPUStats):
""" """
def __init__(self, log: bool = True) -> None: def __init__(self, log: bool = True) -> None:
# Following attribute set in :func:``_initialize`` # Following attribute set in :func:``_initialize``
self._tf_devices: List[Any] = [] self._tf_devices: list[T.Any] = []
super().__init__(log=log) super().__init__(log=log)
@ -142,7 +142,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}") self._log("debug", f"GPU Driver: {driver}")
return driver return driver
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of available Apple Silicon SoC(s) as identified in """ Obtain the list of names of available Apple Silicon SoC(s) as identified in
:attr:`_handles`. :attr:`_handles`.
@ -155,7 +155,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}") self._log("debug", f"GPU Devices: {names}")
return names return names
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in """ Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in
:attr:`_handles`. :attr:`_handles`.
@ -175,7 +175,7 @@ class AppleSiliconStats(_GPUStats):
self._log("debug", f"SoC RAM: {vram}") self._log("debug", f"SoC RAM: {vram}")
return vram return vram
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each available Apple """ Obtain the amount of VRAM that is available, in Megabytes, for each available Apple
Silicon SoC. Silicon SoC.
View file
@ -1,9 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Dummy functions for running faceswap on CPU. """ """ Dummy functions for running faceswap on CPU. """
from typing import List
from ._base import _GPUStats from ._base import _GPUStats
@ -65,7 +61,7 @@ class CPUStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}") self._log("debug", f"GPU Driver: {driver}")
return driver return driver
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`. """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns Returns
@ -73,11 +69,11 @@ class CPUStats(_GPUStats):
list list
An empty list for CPU backends An empty list for CPU backends
""" """
names: List[str] = [] names: list[str] = []
self._log("debug", f"GPU Devices: {names}") self._log("debug", f"GPU Devices: {names}")
return names return names
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the RAM in Megabytes for the running system. """ Obtain the RAM in Megabytes for the running system.
Returns Returns
@ -85,11 +81,11 @@ class CPUStats(_GPUStats):
list list
An empty list for CPU backends An empty list for CPU backends
""" """
vram: List[int] = [] vram: list[int] = []
self._log("debug", f"GPU VRAM: {vram}") self._log("debug", f"GPU VRAM: {vram}")
return vram return vram
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of RAM that is available, in Megabytes, for the running system. """ Obtain the amount of RAM that is available, in Megabytes, for the running system.
Returns Returns
@ -97,6 +93,6 @@ class CPUStats(_GPUStats):
list list
An empty list for CPU backends An empty list for CPU backends
""" """
vram: List[int] = [] vram: list[int] = []
self._log("debug", f"GPU VRAM free: {vram}") self._log("debug", f"GPU VRAM free: {vram}")
return vram return vram
View file
@ -1,19 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Collects and returns Information on DirectX 12 hardware devices for DirectML. """ """ Collects and returns Information on DirectX 12 hardware devices for DirectML. """
from __future__ import annotations
import os import os
import sys import sys
import typing as T
assert sys.platform == "win32" assert sys.platform == "win32"
import ctypes import ctypes
from ctypes import POINTER, Structure, windll from ctypes import POINTER, Structure, windll
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, IntEnum from enum import Enum, IntEnum
from typing import Any, Callable, cast, List
from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error
from ._base import _GPUStats from ._base import _GPUStats
if T.TYPE_CHECKING:
from collections.abc import Callable
# Monkey patch default ctypes.c_uint32 value to Enum ctypes property for easier tracking of types # Monkey patch default ctypes.c_uint32 value to Enum ctypes property for easier tracking of types
# We can't just subclass as the attribute will be assumed to be part of the Enumeration, so we # We can't just subclass as the attribute will be assumed to be part of the Enumeration, so we
# attach it directly and suck up the typing errors. # attach it directly and suck up the typing errors.
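The import changes above follow a pattern used across this commit: `from __future__ import annotations` makes every annotation lazily evaluated, so names needed only for type checking can live behind the `T.TYPE_CHECKING` guard and never be imported at runtime. A self-contained sketch:

    from __future__ import annotations

    import typing as T

    if T.TYPE_CHECKING:              # True for mypy/pyright, False when the program runs
        from collections.abc import Callable

    def apply(func: Callable[[int], int], value: int) -> int:
        """ 'Callable' is only resolved by the type checker, never at runtime. """
        return func(value)

    print(apply(lambda x: x + 1, 41))   # 42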
@ -314,7 +318,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._adapters = self._get_adapters() self._adapters = self._get_adapters()
self._devices = self._process_adapters() self._devices = self._process_adapters()
self._valid_adaptors: List[Device] = [] self._valid_adaptors: list[Device] = []
self._log("debug", f"Initialized {self.__class__.__name__}") self._log("debug", f"Initialized {self.__class__.__name__}")
def _get_factory(self) -> ctypes._Pointer: def _get_factory(self) -> ctypes._Pointer:
@ -334,12 +338,12 @@ class Adapters(): # pylint:disable=too-few-public-methods
factory_func.restype = HRESULT factory_func.restype = HRESULT
handle = ctypes.c_void_p(0) handle = ctypes.c_void_p(0)
factory_func(IDXGIFactory6._iid_, ctypes.byref(handle)) # pylint:disable=protected-access factory_func(IDXGIFactory6._iid_, ctypes.byref(handle)) # pylint:disable=protected-access
retval = ctypes.POINTER(IDXGIFactory6)(cast(IDXGIFactory6, handle.value)) retval = ctypes.POINTER(IDXGIFactory6)(T.cast(IDXGIFactory6, handle.value))
self._log("debug", f"factory: {retval}") self._log("debug", f"factory: {retval}")
return retval return retval
@property @property
def valid_adapters(self) -> List[Device]: def valid_adapters(self) -> list[Device]:
""" list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """ """ list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """
if self._valid_adaptors: if self._valid_adaptors:
return self._valid_adaptors return self._valid_adaptors
@ -354,7 +358,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._log("debug", f"valid_adaptors: {self._valid_adaptors}") self._log("debug", f"valid_adaptors: {self._valid_adaptors}")
return self._valid_adaptors return self._valid_adaptors
def _get_adapters(self) -> List[ctypes._Pointer]: def _get_adapters(self) -> list[ctypes._Pointer]:
""" Obtain DirectX 12 supporting hardware adapter objects and add a Device class for """ Obtain DirectX 12 supporting hardware adapter objects and add a Device class for
obtaining details obtaining details
@ -376,7 +380,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
if success != 0: if success != 0:
raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: " raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: "
f"{hex(ctypes.c_ulong(success).value)}") f"{hex(ctypes.c_ulong(success).value)}")
adapter = POINTER(IDXGIAdapter3)(cast(IDXGIAdapter3, handle.value)) adapter = POINTER(IDXGIAdapter3)(T.cast(IDXGIAdapter3, handle.value))
self._log("debug", f"found adapter: {adapter}") self._log("debug", f"found adapter: {adapter}")
retval.append(adapter) retval.append(adapter)
except COMError as err: except COMError as err:
@ -392,7 +396,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
self._log("debug", f"adapters: {retval}") self._log("debug", f"adapters: {retval}")
return retval return retval
def _query_adapter(self, func: Callable[[Any], Any], *args: Any) -> None: def _query_adapter(self, func: Callable[[T.Any], T.Any], *args: T.Any) -> None:
""" Query an adapter function, logging if the HRESULT is not a success """ Query an adapter function, logging if the HRESULT is not a success
Parameters Parameters
@ -430,7 +434,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
LookupGUID.ID3D12Device) LookupGUID.ID3D12Device)
return success in (0, 1) return success in (0, 1)
def _process_adapters(self) -> List[Device]: def _process_adapters(self) -> list[Device]:
""" Process the adapters to add discovered information. """ Process the adapters to add discovered information.
Returns Returns
@ -485,21 +489,21 @@ class DirectML(_GPUStats):
Default: ``True`` Default: ``True``
""" """
def __init__(self, log: bool = True) -> None: def __init__(self, log: bool = True) -> None:
self._devices: List[Device] = [] self._devices: list[Device] = []
super().__init__(log=log) super().__init__(log=log)
@property @property
def _all_vram(self) -> List[int]: def _all_vram(self) -> list[int]:
""" list: The VRAM of each GPU device that the DX API has discovered. """ """ list: The VRAM of each GPU device that the DX API has discovered. """
return [int(device.description.DedicatedVideoMemory / (1024 * 1024)) return [int(device.description.DedicatedVideoMemory / (1024 * 1024))
for device in self._devices] for device in self._devices]
@property @property
def names(self) -> List[str]: def names(self) -> list[str]:
""" list: The name of each GPU device that the DX API has discovered. """ """ list: The name of each GPU device that the DX API has discovered. """
return [device.description.Description for device in self._devices] return [device.description.Description for device in self._devices]
def _get_active_devices(self) -> List[int]: def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by """ Obtain the indices of active GPUs (those that have not been explicitly excluded by
DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments). arguments).
@ -517,7 +521,7 @@ class DirectML(_GPUStats):
self._log("debug", f"Active GPU Devices: {devices}") self._log("debug", f"Active GPU Devices: {devices}")
return devices return devices
def _get_devices(self) -> List[Device]: def _get_devices(self) -> list[Device]:
""" Obtain all detected DX API devices. """ Obtain all detected DX API devices.
Returns Returns
@ -582,7 +586,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU Drivers: {drivers}") self._log("debug", f"GPU Drivers: {drivers}")
return drivers return drivers
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`. """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns Returns
@ -594,7 +598,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU Devices: {names}") self._log("debug", f"GPU Devices: {names}")
return names return names
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in """ Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in
:attr:`_handles`. :attr:`_handles`.
@ -607,7 +611,7 @@ class DirectML(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}") self._log("debug", f"GPU VRAM: {vram}")
return vram return vram
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX """ Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX
12 supporting GPU. 12 supporting GPU.
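`T.cast`, used above when materialising the COM pointers, performs no conversion or validation at runtime; it exists purely to narrow the static type. In isolation:

    import typing as T

    def first_int(values: list[object]) -> int:
        # cast() is an identity function at runtime; it only informs the type checker
        return T.cast(int, values[0])

    print(first_int([5, "a"]) + 1)   # 6 -- correct because the element really is an int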
View file
@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs. """ """ Collects and returns Information on available Nvidia GPUs. """
import os import os
from typing import List
import pynvml import pynvml
@ -83,7 +82,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Device count: {retval}") self._log("debug", f"GPU Device count: {retval}")
return retval return retval
def _get_active_devices(self) -> List[int]: def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by """ Obtain the indices of active GPUs (those that have not been explicitly excluded by
CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments). arguments).
@ -130,7 +129,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}") self._log("debug", f"GPU Driver: {driver}")
return driver return driver
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`. """ Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.
Returns Returns
@ -143,7 +142,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}") self._log("debug", f"GPU Devices: {names}")
return names return names
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in """ Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`. :attr:`_handles`.
@ -157,7 +156,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}") self._log("debug", f"GPU VRAM: {vram}")
return vram return vram
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU. GPU.
View file
@ -1,7 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs connected to Apple Macs. """ """ Collects and returns Information on available Nvidia GPUs connected to Apple Macs. """
from typing import List
import pynvx import pynvx
from lib.utils import FaceswapError from lib.utils import FaceswapError
@ -92,7 +90,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}") self._log("debug", f"GPU Driver: {driver}")
return driver return driver
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`. """ Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.
Returns Returns
@ -105,7 +103,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}") self._log("debug", f"GPU Devices: {names}")
return names return names
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in """ Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`. :attr:`_handles`.
@ -120,7 +118,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}") self._log("debug", f"GPU VRAM: {vram}")
return vram return vram
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU. GPU.
View file
@ -10,7 +10,6 @@ It is a good starting point but may need to be refined over time
import os import os
import re import re
from subprocess import run from subprocess import run
from typing import List
from ._base import _GPUStats from ._base import _GPUStats
@ -221,7 +220,7 @@ class ROCm(_GPUStats):
""" """
def __init__(self, log: bool = True) -> None: def __init__(self, log: bool = True) -> None:
self._vendor_id = "0x1002" # AMD VendorID self._vendor_id = "0x1002" # AMD VendorID
self._sysfs_paths: List[str] = [] self._sysfs_paths: list[str] = []
super().__init__(log=log) super().__init__(log=log)
def _from_sysfs_file(self, path: str) -> str: def _from_sysfs_file(self, path: str) -> str:
@ -249,7 +248,7 @@ class ROCm(_GPUStats):
val = "" val = ""
return val return val
def _get_sysfs_paths(self) -> List[str]: def _get_sysfs_paths(self) -> list[str]:
""" Obtain a list of sysfs paths to AMD branded GPUs connected to the system """ Obtain a list of sysfs paths to AMD branded GPUs connected to the system
Returns Returns
@ -259,7 +258,7 @@ class ROCm(_GPUStats):
""" """
base_dir = "/sys/class/drm/" base_dir = "/sys/class/drm/"
retval: List[str] = [] retval: list[str] = []
if not os.path.exists(base_dir): if not os.path.exists(base_dir):
self._log("warning", f"sysfs not found at '{base_dir}'") self._log("warning", f"sysfs not found at '{base_dir}'")
return retval return retval
@ -347,7 +346,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU Drivers: {retval}") self._log("debug", f"GPU Drivers: {retval}")
return retval return retval
def _get_device_names(self) -> List[str]: def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`. """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
Returns Returns
@ -383,7 +382,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Device names: {retval}") self._log("debug", f"Device names: {retval}")
return retval return retval
def _get_active_devices(self) -> List[int]: def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by """ Obtain the indices of active GPUs (those that have not been explicitly excluded by
HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments). arguments).
@ -401,7 +400,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Active GPU Devices: {devices}") self._log("debug", f"Active GPU Devices: {devices}")
return devices return devices
def _get_vram(self) -> List[int]: def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected AMD GPU as identified in """ Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
:attr:`_handles`. :attr:`_handles`.
@ -423,7 +422,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU VRAM: {retval}") self._log("debug", f"GPU VRAM: {retval}")
return retval return retval
def _get_free_vram(self) -> List[int]: def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD """ Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
GPU. GPU.
View file
@ -1,13 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Handles the loading and collation of events from Tensorflow event log files. """ """ Handles the loading and collation of events from Tensorflow event log files. """
from __future__ import annotations
import logging import logging
import os import os
import sys import typing as T
import zlib import zlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, cast, Dict, Iterator, Generator, List, Optional, Tuple, Union
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -17,11 +16,8 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
from lib.serializer import get_serializer from lib.serializer import get_serializer
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal from collections.abc import Generator, Iterator
else:
from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -38,7 +34,7 @@ class EventData:
The loss values collected for A and B sides for the event step The loss values collected for A and B sides for the event step
""" """
timestamp: float = 0.0 timestamp: float = 0.0
loss: List[float] = field(default_factory=list) loss: list[float] = field(default_factory=list)
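`field(default_factory=list)` gives every instance its own list rather than one shared mutable default; demonstrated with the same dataclass shape:

    from dataclasses import dataclass, field

    @dataclass
    class EventData:
        timestamp: float = 0.0
        loss: list[float] = field(default_factory=list)

    first, second = EventData(), EventData()
    first.loss.append(0.5)
    print(second.loss)   # [] -- the default list is not shared between instances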
class _LogFiles(): class _LogFiles():
@ -56,11 +52,11 @@ class _LogFiles():
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@property @property
def session_ids(self) -> List[int]: def session_ids(self) -> list[int]:
""" list[int]: Sorted list of `ints` of available session ids. """ """ list[int]: Sorted list of `ints` of available session ids. """
return list(sorted(self._filenames)) return list(sorted(self._filenames))
def _get_log_filenames(self) -> Dict[int, str]: def _get_log_filenames(self) -> dict[int, str]:
""" Get the Tensorflow event filenames for all existing sessions. """ Get the Tensorflow event filenames for all existing sessions.
Returns Returns
@ -69,7 +65,7 @@ class _LogFiles():
The full path of each log file for each training session id that has been run The full path of each log file for each training session id that has been run
""" """
logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder) logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
retval: Dict[int, str] = {} retval: dict[int, str] = {}
for dirpath, _, filenames in os.walk(self._logs_folder): for dirpath, _, filenames in os.walk(self._logs_folder):
if not any(filename.startswith("events.out.tfevents") for filename in filenames): if not any(filename.startswith("events.out.tfevents") for filename in filenames):
continue continue
@ -82,7 +78,7 @@ class _LogFiles():
return retval return retval
@classmethod @classmethod
def _get_session_id(cls, folder: str) -> Optional[int]: def _get_session_id(cls, folder: str) -> int | None:
""" Obtain the session id for the given folder. """ Obtain the session id for the given folder.
Parameters Parameters
@ -103,7 +99,7 @@ class _LogFiles():
return retval return retval
@classmethod @classmethod
def _get_log_filename(cls, folder: str, filenames: List[str]) -> str: def _get_log_filename(cls, folder: str, filenames: list[str]) -> str:
""" Obtain the session log file for the given folder. If multiple log files exist for the """ Obtain the session log file for the given folder. If multiple log files exist for the
given folder, then the most recent log file is used, as earlier files are assumed to be given folder, then the most recent log file is used, as earlier files are assumed to be
obsolete. obsolete.
@ -161,10 +157,10 @@ class _CacheData():
loss: :class:`np.ndarray` loss: :class:`np.ndarray`
The loss values collected for A and B sides for the session The loss values collected for A and B sides for the session
""" """
def __init__(self, labels: List[str], timestamps: np.ndarray, loss: np.ndarray) -> None: def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
self.labels = labels self.labels = labels
self._loss = zlib.compress(cast(bytes, loss)) self._loss = zlib.compress(T.cast(bytes, loss))
self._timestamps = zlib.compress(cast(bytes, timestamps)) self._timestamps = zlib.compress(T.cast(bytes, timestamps))
self._timestamps_shape = timestamps.shape self._timestamps_shape = timestamps.shape
self._loss_shape = loss.shape self._loss_shape = loss.shape
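The cache stores only the compressed bytes plus the original array shapes, which is enough to rebuild the arrays later. A hedged sketch of the round trip (the decompression side is inferred from the stored shapes; it is not part of this hunk):

    import zlib
    import numpy as np

    loss = np.random.rand(100, 2).astype("float32")
    packed = zlib.compress(loss.tobytes())              # raw array bytes -> compressed bytes
    restored = np.frombuffer(zlib.decompress(packed),
                             dtype="float32").reshape(loss.shape)
    assert np.array_equal(loss, restored)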
@ -192,8 +188,8 @@ class _CacheData():
timestamps: :class:`numpy.ndarray` timestamps: :class:`numpy.ndarray`
The latest timestamps to add to the cache The latest timestamps to add to the cache
""" """
new_buffer: List[bytes] = [] new_buffer: list[bytes] = []
new_shapes: List[Tuple[int, ...]] = [] new_shapes: list[tuple[int, ...]] = []
for data, buffer, dtype, shape in zip([timestamps, loss], for data, buffer, dtype, shape in zip([timestamps, loss],
[self._timestamps, self._loss], [self._timestamps, self._loss],
["float64", "float32"], ["float64", "float32"],
@ -220,9 +216,9 @@ class _Cache():
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """ """ Holds parsed Tensorflow log event data in a compressed cache in memory. """
def __init__(self) -> None: def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__) logger.debug("Initializing: %s", self.__class__.__name__)
self._data: Dict[int, _CacheData] = {} self._data: dict[int, _CacheData] = {}
self._carry_over: Dict[int, EventData] = {} self._carry_over: dict[int, EventData] = {}
self._loss_labels: List[str] = [] self._loss_labels: list[str] = []
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
def is_cached(self, session_id: int) -> bool: def is_cached(self, session_id: int) -> bool:
@ -242,8 +238,8 @@ class _Cache():
def cache_data(self, def cache_data(self,
session_id: int, session_id: int,
data: Dict[int, EventData], data: dict[int, EventData],
labels: List[str], labels: list[str],
is_live: bool = False) -> None: is_live: bool = False) -> None:
""" Add a full session's worth of event data to :attr:`_data`. """ Add a full session's worth of event data to :attr:`_data`.
@ -278,8 +274,8 @@ class _Cache():
self._add_latest_live(session_id, loss, timestamps) self._add_latest_live(session_id, loss, timestamps)
def _to_numpy(self, def _to_numpy(self,
data: Dict[int, EventData], data: dict[int, EventData],
is_live: bool) -> Tuple[np.ndarray, np.ndarray]: is_live: bool) -> tuple[np.ndarray, np.ndarray]:
""" Extract each individual step data into separate numpy arrays for loss and timestamps. """ Extract each individual step data into separate numpy arrays for loss and timestamps.
Timestamps are stored float64 as the extra accuracy is needed for correct timings. Arrays Timestamps are stored float64 as the extra accuracy is needed for correct timings. Arrays
@ -333,7 +329,7 @@ class _Cache():
return n_times, n_loss return n_times, n_loss
def _collect_carry_over(self, data: Dict[int, EventData]) -> None: def _collect_carry_over(self, data: dict[int, EventData]) -> None:
""" For live data, collect carried over data from the previous update and merge into the """ For live data, collect carried over data from the previous update and merge into the
current data dictionary. current data dictionary.
@ -357,8 +353,8 @@ class _Cache():
logger.debug("Merged carry over data: %s", update) logger.debug("Merged carry over data: %s", update)
def _process_data(self, def _process_data(self,
data: Dict[int, EventData], data: dict[int, EventData],
is_live: bool) -> Tuple[List[float], List[List[float]]]: is_live: bool) -> tuple[list[float], list[list[float]]]:
""" Process live update data. """ Process live update data.
Live data requires different processing as often we will only have partial data for the Live data requires different processing as often we will only have partial data for the
@ -383,8 +379,8 @@ class _Cache():
timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss) timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss)
for idx in sorted(data)]) for idx in sorted(data)])
l_loss: List[List[float]] = list(loss) l_loss: list[list[float]] = list(loss)
l_timestamps: List[float] = list(timestamps) l_timestamps: list[float] = list(timestamps)
if len(l_loss[-1]) != len(self._loss_labels): if len(l_loss[-1]) != len(self._loss_labels):
logger.debug("Truncated loss found. loss count: %s", len(l_loss)) logger.debug("Truncated loss found. loss count: %s", len(l_loss))
@ -418,8 +414,8 @@ class _Cache():
self._data[session_id].add_live_data(timestamps, loss) self._data[session_id].add_live_data(timestamps, loss)
def get_data(self, session_id: int, metric: Literal["loss", "timestamps"] def get_data(self, session_id: int, metric: T.Literal["loss", "timestamps"]
) -> Optional[Dict[int, Dict[str, Union[np.ndarray, List[str]]]]]: ) -> dict[int, dict[str, np.ndarray | list[str]]] | None:
""" Retrieve the decompressed cached data from the cache for the given session id. """ Retrieve the decompressed cached data from the cache for the given session id.
Parameters Parameters
@ -445,10 +441,10 @@ class _Cache():
return None return None
raw = {session_id: data} raw = {session_id: data}
retval: Dict[int, Dict[str, Union[np.ndarray, List[str]]]] = {} retval: dict[int, dict[str, np.ndarray | list[str]]] = {}
for idx, data in raw.items(): for idx, data in raw.items():
array = data.loss if metric == "loss" else data.timestamps array = data.loss if metric == "loss" else data.timestamps
val: Dict[str, Union[np.ndarray, List[str]]] = {str(metric): array} val: dict[str, np.ndarray | list[str]] = {str(metric): array}
if metric == "loss": if metric == "loss":
val["labels"] = data.labels val["labels"] = data.labels
retval[idx] = val retval[idx] = val
@ -488,7 +484,7 @@ class TensorBoardLogs():
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@property @property
def session_ids(self) -> List[int]: def session_ids(self) -> list[int]:
""" list[int]: Sorted list of integers of available session ids. """ """ list[int]: Sorted list of integers of available session ids. """
return self._log_files.session_ids return self._log_files.session_ids
@ -539,7 +535,7 @@ class TensorBoardLogs():
parser = _EventParser(iterator, self._cache, live_data) parser = _EventParser(iterator, self._cache, live_data)
parser.cache_events(session_id) parser.cache_events(session_id)
def _check_cache(self, session_id: Optional[int] = None) -> None: def _check_cache(self, session_id: int | None = None) -> None:
""" Check if the given session_id has been cached and if not, cache it. """ Check if the given session_id has been cached and if not, cache it.
Parameters Parameters
@ -557,7 +553,7 @@ class TensorBoardLogs():
if not self._cache.is_cached(idx): if not self._cache.is_cached(idx):
self._cache_data(idx) self._cache_data(idx)
def get_loss(self, session_id: Optional[int] = None) -> Dict[int, Dict[str, np.ndarray]]: def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]:
""" Read the loss from the TensorBoard event logs """ Read the loss from the TensorBoard event logs
Parameters Parameters
@ -573,7 +569,7 @@ class TensorBoardLogs():
and list of loss values for each step and list of loss values for each step
""" """
logger.debug("Getting loss: (session_id: %s)", session_id) logger.debug("Getting loss: (session_id: %s)", session_id)
retval: Dict[int, Dict[str, np.ndarray]] = {} retval: dict[int, dict[str, np.ndarray]] = {}
for idx in [session_id] if session_id else self.session_ids: for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx) self._check_cache(idx)
full_data = self._cache.get_data(idx, "loss") full_data = self._cache.get_data(idx, "loss")
@ -588,7 +584,7 @@ class TensorBoardLogs():
for key, val in retval.items()}) for key, val in retval.items()})
return retval return retval
def get_timestamps(self, session_id: Optional[int] = None) -> Dict[int, np.ndarray]: def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]:
""" Read the timestamps from the TensorBoard logs. """ Read the timestamps from the TensorBoard logs.
As loss timestamps are slightly different for each loss, we collect the timestamp from the As loss timestamps are slightly different for each loss, we collect the timestamp from the
@ -608,7 +604,7 @@ class TensorBoardLogs():
logger.debug("Getting timestamps: (session_id: %s, is_training: %s)", logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
session_id, self._is_training) session_id, self._is_training)
retval: Dict[int, np.ndarray] = {} retval: dict[int, np.ndarray] = {}
for idx in [session_id] if session_id else self.session_ids: for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx) self._check_cache(idx)
data = self._cache.get_data(idx, "timestamps") data = self._cache.get_data(idx, "timestamps")
@ -640,7 +636,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
self._live_data = live_data self._live_data = live_data
self._cache = cache self._cache = cache
self._iterator = self._get_latest_live(iterator) if live_data else iterator self._iterator = self._get_latest_live(iterator) if live_data else iterator
self._loss_labels: List[str] = [] self._loss_labels: list[str] = []
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@classmethod @classmethod
@ -683,7 +679,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
The session id that the data is being cached for The session id that the data is being cached for
""" """
assert self._iterator is not None assert self._iterator is not None
data: Dict[int, EventData] = {} data: dict[int, EventData] = {}
try: try:
for record in self._iterator: for record in self._iterator:
event = event_pb2.Event.FromString(record) # pylint:disable=no-member event = event_pb2.Event.FromString(record) # pylint:disable=no-member
@ -743,7 +739,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
logger.debug("Collated loss labels: %s", self._loss_labels) logger.debug("Collated loss labels: %s", self._loss_labels)
@classmethod @classmethod
def _get_outputs(cls, model_config: Dict[str, Any]) -> np.ndarray: def _get_outputs(cls, model_config: dict[str, T.Any]) -> np.ndarray:
""" Obtain the output names, instance index and output index for the given model. """ Obtain the output names, instance index and output index for the given model.
If there is only a single output, the shape of the array is expanded to remain consistent If there is only a single output, the shape of the array is expanded to remain consistent
View file
@ -5,15 +5,15 @@ Holds the globally loaded training session. This will either be a user selected
the analysis tab) or the currently training session. the analysis tab) or the currently training session.
""" """
from __future__ import annotations
import logging import logging
import time
import os import os
import time
import typing as T
import warnings import warnings
from math import ceil from math import ceil
from threading import Event from threading import Event
from typing import Any, cast, Dict, List, Optional, overload, Tuple, Union
import numpy as np import numpy as np
@ -31,12 +31,12 @@ class GlobalSession():
""" """
def __init__(self) -> None: def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__) logger.debug("Initializing %s", self.__class__.__name__)
self._state: Dict[str, Any] = {} self._state: dict[str, T.Any] = {}
self._model_dir = "" self._model_dir = ""
self._model_name = "" self._model_name = ""
self._tb_logs: Optional[TensorBoardLogs] = None self._tb_logs: TensorBoardLogs | None = None
self._summary: Optional[SessionsSummary] = None self._summary: SessionsSummary | None = None
self._is_training = False self._is_training = False
self._is_querying = Event() self._is_querying = Event()
@ -60,7 +60,7 @@ class GlobalSession():
return os.path.join(self._model_dir, self._model_name) return os.path.join(self._model_dir, self._model_name)
@property @property
def batch_sizes(self) -> Dict[int, int]: def batch_sizes(self) -> dict[int, int]:
""" dict: The batch sizes for each session_id for the model. """ """ dict: The batch sizes for each session_id for the model. """
if not self._state: if not self._state:
return {} return {}
@ -68,7 +68,7 @@ class GlobalSession():
for sess_id, sess in self._state.get("sessions", {}).items()} for sess_id, sess in self._state.get("sessions", {}).items()}
@property @property
def full_summary(self) -> List[dict]: def full_summary(self) -> list[dict]:
""" list: List of dictionaries containing summary statistics for each session id. """ """ list: List of dictionaries containing summary statistics for each session id. """
assert self._summary is not None assert self._summary is not None
return self._summary.get_summary_stats() return self._summary.get_summary_stats()
@ -83,7 +83,7 @@ class GlobalSession():
return self._state["sessions"][max_id]["no_logs"] return self._state["sessions"][max_id]["no_logs"]
@property @property
def session_ids(self) -> List[int]: def session_ids(self) -> list[int]:
""" list: The sorted list of all existing session ids in the state file """ """ list: The sorted list of all existing session ids in the state file """
if self._tb_logs is None: if self._tb_logs is None:
return [] return []
@ -164,7 +164,7 @@ class GlobalSession():
self._is_training = False self._is_training = False
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]: def get_loss(self, session_id: int | None) -> dict[str, np.ndarray]:
""" Obtain the loss values for the given session_id. """ Obtain the loss values for the given session_id.
Parameters Parameters
@ -186,11 +186,11 @@ class GlobalSession():
assert self._tb_logs is not None assert self._tb_logs is not None
loss_dict = self._tb_logs.get_loss(session_id=session_id) loss_dict = self._tb_logs.get_loss(session_id=session_id)
if session_id is None: if session_id is None:
all_loss: Dict[str, List[float]] = {} all_loss: dict[str, list[float]] = {}
for key in sorted(loss_dict): for key in sorted(loss_dict):
for loss_key, loss in loss_dict[key].items(): for loss_key, loss in loss_dict[key].items():
all_loss.setdefault(loss_key, []).extend(loss) all_loss.setdefault(loss_key, []).extend(loss)
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32") retval: dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
for key, val in all_loss.items()} for key, val in all_loss.items()}
else: else:
retval = loss_dict.get(session_id, {}) retval = loss_dict.get(session_id, {})
@ -199,11 +199,11 @@ class GlobalSession():
self._is_querying.clear() self._is_querying.clear()
return retval return retval
@overload @T.overload
def get_timestamps(self, session_id: None) -> Dict[int, np.ndarray]: def get_timestamps(self, session_id: None) -> dict[int, np.ndarray]:
... ...
@overload @T.overload
def get_timestamps(self, session_id: int) -> np.ndarray: def get_timestamps(self, session_id: int) -> np.ndarray:
... ...
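`typing.overload` (via the `T` alias) lets the checker report a different return type per call signature while a single implementation serves both; a minimal sketch:

    import typing as T

    @T.overload
    def fetch(key: None) -> dict[int, str]: ...
    @T.overload
    def fetch(key: int) -> str: ...

    def fetch(key: int | None) -> dict[int, str] | str:
        data = {1: "one", 2: "two"}
        return data if key is None else data[key]

    print(fetch(1))      # the checker knows this call returns 'str'
    print(fetch(None))   # ...and this one 'dict[int, str]'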
@ -247,7 +247,7 @@ class GlobalSession():
continue continue
break break
def get_loss_keys(self, session_id: Optional[int]) -> List[str]: def get_loss_keys(self, session_id: int | None) -> list[str]:
""" Obtain the loss keys for the given session_id. """ Obtain the loss keys for the given session_id.
Parameters Parameters
@ -268,7 +268,7 @@ class GlobalSession():
in self._tb_logs.get_loss(session_id=session_id).items()} in self._tb_logs.get_loss(session_id=session_id).items()}
if session_id is None: if session_id is None:
retval: List[str] = list(set(loss_key retval: list[str] = list(set(loss_key
for session in loss_keys.values() for session in loss_keys.values()
for loss_key in session)) for loss_key in session))
else: else:
@ -293,11 +293,11 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
self._session = session self._session = session
self._state = session._state self._state = session._state
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {} self._time_stats: dict[int, dict[str, float | int]] = {}
self._per_session_stats: List[Dict[str, Any]] = [] self._per_session_stats: list[dict[str, T.Any]] = []
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def get_summary_stats(self) -> List[dict]: def get_summary_stats(self) -> list[dict]:
""" Compile the individual session statistics and calculate the total. """ Compile the individual session statistics and calculate the total.
Format the stats for display Format the stats for display
@ -336,14 +336,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0, sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0,
"end_time": np.max(timestamps) if np.any(timestamps) else 0, "end_time": np.max(timestamps) if np.any(timestamps) else 0,
"iterations": timestamps.shape[0] if np.any(timestamps) else 0} "iterations": timestamps.shape[0] if np.any(timestamps) else 0}
for sess_id, timestamps in cast(Dict[int, np.ndarray], for sess_id, timestamps in T.cast(dict[int, np.ndarray],
self._session.get_timestamps(None)).items()} self._session.get_timestamps(None)).items()}
elif _SESSION.is_training: elif _SESSION.is_training:
logger.debug("Updating summary time stamps for training session") logger.debug("Updating summary time stamps for training session")
session_id = _SESSION.session_ids[-1] session_id = _SESSION.session_ids[-1]
latest = cast(np.ndarray, self._session.get_timestamps(session_id)) latest = T.cast(np.ndarray, self._session.get_timestamps(session_id))
self._time_stats[session_id] = { self._time_stats[session_id] = {
"start_time": np.min(latest) if np.any(latest) else 0, "start_time": np.min(latest) if np.any(latest) else 0,
@ -392,7 +392,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
/ stats["elapsed"] if stats["elapsed"] > 0 else 0) / stats["elapsed"] if stats["elapsed"] > 0 else 0)
logger.debug("per_session_stats: %s", self._per_session_stats) logger.debug("per_session_stats: %s", self._per_session_stats)
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]: def _collate_stats(self, session_id: int) -> dict[str, int | float]:
""" Collate the session summary statistics for the given session ID. """ Collate the session summary statistics for the given session ID.
Parameters Parameters
@ -422,7 +422,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
logger.debug(retval) logger.debug(retval)
return retval return retval
def _total_stats(self) -> Dict[str, Union[str, int, float]]: def _total_stats(self) -> dict[str, str | int | float]:
""" Compile the Totals stats. """ Compile the Totals stats.
Totals are fully calculated each time as they will change on the basis of the training Totals are fully calculated each time as they will change on the basis of the training
session. session.
@ -459,7 +459,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
logger.debug(totals) logger.debug(totals)
return totals return totals
def _format_stats(self, compiled_stats: List[dict]) -> List[dict]: def _format_stats(self, compiled_stats: list[dict]) -> list[dict]:
""" Format for the incoming list of statistics for display. """ Format for the incoming list of statistics for display.
Parameters Parameters
@ -489,7 +489,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
return retval return retval
@classmethod @classmethod
def _convert_time(cls, timestamp: float) -> Tuple[str, str, str]: def _convert_time(cls, timestamp: float) -> tuple[str, str, str]:
""" Convert time stamp to total hours, minutes and seconds. """ Convert time stamp to total hours, minutes and seconds.
Parameters Parameters
@ -534,8 +534,8 @@ class Calculations():
""" """
def __init__(self, session_id, def __init__(self, session_id,
display: str = "loss", display: str = "loss",
loss_keys: Union[List[str], str] = "loss", loss_keys: list[str] | str = "loss",
selections: Union[List[str], str] = "raw", selections: list[str] | str = "raw",
avg_samples: int = 500, avg_samples: int = 500,
smooth_amount: float = 0.90, smooth_amount: float = 0.90,
flatten_outliers: bool = False) -> None: flatten_outliers: bool = False) -> None:
@ -552,13 +552,13 @@ class Calculations():
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys] self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
self._selections = selections if isinstance(selections, list) else [selections] self._selections = selections if isinstance(selections, list) else [selections]
self._is_totals = session_id is None self._is_totals = session_id is None
self._args: Dict[str, Union[int, float]] = {"avg_samples": avg_samples, self._args: dict[str, int | float] = {"avg_samples": avg_samples,
"smooth_amount": smooth_amount, "smooth_amount": smooth_amount,
"flatten_outliers": flatten_outliers} "flatten_outliers": flatten_outliers}
self._iterations = 0 self._iterations = 0
self._limit = 0 self._limit = 0
self._start_iteration = 0 self._start_iteration = 0
self._stats: Dict[str, np.ndarray] = {} self._stats: dict[str, np.ndarray] = {}
self.refresh() self.refresh()
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@ -573,11 +573,11 @@ class Calculations():
return self._start_iteration return self._start_iteration
@property @property
def stats(self) -> Dict[str, np.ndarray]: def stats(self) -> dict[str, np.ndarray]:
""" dict: The final calculated statistics """ """ dict: The final calculated statistics """
return self._stats return self._stats
def refresh(self) -> Optional["Calculations"]: def refresh(self) -> Calculations | None:
""" Refresh the stats """ """ Refresh the stats """
logger.debug("Refreshing") logger.debug("Refreshing")
if not _SESSION.is_loaded: if not _SESSION.is_loaded:
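The `refresh` signature above can name its own class (`Calculations | None`) because this file now starts with `from __future__ import annotations`, which defers evaluation of every annotation; a minimal sketch:

    from __future__ import annotations

    class Calculations:
        def refresh(self) -> Calculations | None:
            """ The class can reference itself before its definition is complete. """
            return self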
@ -736,7 +736,8 @@ class Calculations():
""" """
logger.debug("Calculating rate") logger.debug("Calculating rate")
batch_size = _SESSION.batch_sizes[self._session_id] * 2 batch_size = _SESSION.batch_sizes[self._session_id] * 2
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id))) retval = batch_size / np.diff(T.cast(np.ndarray,
_SESSION.get_timestamps(self._session_id)))
logger.debug("Calculated rate: Item_count: %s", len(retval)) logger.debug("Calculated rate: Item_count: %s", len(retval))
return retval return retval
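The rate computed above is examples per second: the doubled batch size divided by the elapsed time between consecutive iteration timestamps. A standalone illustration (the x2 is taken from the line above and presumably accounts for both the A and B sides):

    import numpy as np

    batch_size = 16 * 2                              # per-iteration examples across both sides
    timestamps = np.array([0.0, 0.5, 1.1, 1.6])      # seconds, one entry per iteration
    print(batch_size / np.diff(timestamps))          # ~[64.0, 53.3, 64.0] examples/second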
@ -757,7 +758,7 @@ class Calculations():
logger.debug("Calculating totals rate") logger.debug("Calculating totals rate")
batchsizes = _SESSION.batch_sizes batchsizes = _SESSION.batch_sizes
total_timestamps = _SESSION.get_timestamps(None) total_timestamps = _SESSION.get_timestamps(None)
rate: List[float] = [] rate: list[float] = []
for sess_id in sorted(total_timestamps.keys()): for sess_id in sorted(total_timestamps.keys()):
batchsize = batchsizes[sess_id] batchsize = batchsizes[sess_id]
timestamps = total_timestamps[sess_id] timestamps = total_timestamps[sess_id]
@ -797,7 +798,7 @@ class Calculations():
The moving average for the given data The moving average for the given data
""" """
logger.debug("Calculating Average. Data points: %s", len(data)) logger.debug("Calculating Average. Data points: %s", len(data))
window = cast(int, self._args["avg_samples"]) window = T.cast(int, self._args["avg_samples"])
pad = ceil(window / 2) pad = ceil(window / 2)
datapoints = data.shape[0] datapoints = data.shape[0]
@ -953,7 +954,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
def _ewma_vectorized(self, def _ewma_vectorized(self,
data: np.ndarray, data: np.ndarray,
out: np.ndarray, out: np.ndarray,
offset: Optional[float] = None) -> None: offset: float | None = None) -> None:
""" Calculates the exponential moving average over a vector. Will fail for large inputs. """ Calculates the exponential moving average over a vector. Will fail for large inputs.
The result is processed in place into the array passed to the `out` parameter The result is processed in place into the array passed to the `out` parameter
View file
@ -5,10 +5,10 @@ import logging
import re import re
import tkinter as tk import tkinter as tk
import typing as T
from tkinter import colorchooser, ttk from tkinter import colorchooser, ttk
from itertools import zip_longest from itertools import zip_longest
from functools import partial from functools import partial
from typing import Any, Dict
from _tkinter import Tcl_Obj, TclError from _tkinter import Tcl_Obj, TclError
@ -24,7 +24,9 @@ _ = _LANG.gettext
# We store Tooltips, ContextMenus and Commands globally when they are created # We store Tooltips, ContextMenus and Commands globally when they are created
# Because we need to add them back to newly cloned widgets (they are not easily accessible from # Because we need to add them back to newly cloned widgets (they are not easily accessible from
# original config or are prone to getting destroyed when the original widget is destroyed) # original config or are prone to getting destroyed when the original widget is destroyed)
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={}) _RECREATE_OBJECTS: dict[str, dict[str, T.Any]] = {"tooltips": {},
"commands": {},
"contextmenus": {}}
def _get_tooltip(widget, text=None, text_variable=None): def _get_tooltip(widget, text=None, text_variable=None):
@ -154,17 +156,17 @@ class ControlPanelOption():
self.dtype = dtype self.dtype = dtype
self.sysbrowser = sysbrowser self.sysbrowser = sysbrowser
self._command = command self._command = command
self._options = dict(title=title, self._options = {"title": title,
subgroup=subgroup, "subgroup": subgroup,
group=group, "group": group,
default=default, "default": default,
initial_value=initial_value, "initial_value": initial_value,
choices=choices, "choices": choices,
is_radio=is_radio, "is_radio": is_radio,
is_multi_option=is_multi_option, "is_multi_option": is_multi_option,
rounding=rounding, "rounding": rounding,
min_max=min_max, "min_max": min_max,
helptext=helptext) "helptext": helptext}
self.control = self.get_control() self.control = self.get_control()
self.tk_var = self.get_tk_var(initial_value, track_modified) self.tk_var = self.get_tk_var(initial_value, track_modified)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@ -421,7 +423,7 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
self.group_frames = {} self.group_frames = {}
self._sub_group_frames = {} self._sub_group_frames = {}
canvas_kwargs = dict(bd=0, highlightthickness=0, bg=self._theme["panel_background"]) canvas_kwargs = {"bd": 0, "highlightthickness": 0, "bg": self._theme["panel_background"]}
self._canvas = tk.Canvas(self, **canvas_kwargs) self._canvas = tk.Canvas(self, **canvas_kwargs)
self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
@ -525,8 +527,8 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
group_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5, anchor=tk.NW)
-self.group_frames[group] = dict(frame=retval,
-chkbtns=self.checkbuttons_frame(retval))
+self.group_frames[group] = {"frame": retval,
+"chkbtns": self.checkbuttons_frame(retval)}
group_frame = self.group_frames[group]
return group_frame
@@ -720,12 +722,12 @@ class AutoFillContainer():
"""
retval = {}
if widget.__class__.__name__ == "MultiOption":
-retval = dict(value=widget._value, # pylint:disable=protected-access
-variable=widget._master_variable) # pylint:disable=protected-access
+retval = {"value": widget._value, # pylint:disable=protected-access
+"variable": widget._master_variable} # pylint:disable=protected-access
elif widget.__class__.__name__ == "ToggledFrame":
# Toggled Frames need to have their variable tracked
-retval = dict(text=widget._text, # pylint:disable=protected-access
-toggle_var=widget._toggle_var) # pylint:disable=protected-access
+retval = {"text": widget._text, # pylint:disable=protected-access
+"toggle_var": widget._toggle_var} # pylint:disable=protected-access
return retval
def get_all_children_config(self, widget, child_list):
@@ -988,7 +990,7 @@ class ControlBuilder():
if self.option.control != ttk.Checkbutton:
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
if self.option.helptext is not None and not self.helpset:
-tooltip_kwargs = dict(text=self.option.helptext)
+tooltip_kwargs = {"text": self.option.helptext}
if self.option.sysbrowser is not None:
tooltip_kwargs["text_variable"] = self.option.tk_var
_get_tooltip(ctl, **tooltip_kwargs)
@@ -1071,7 +1073,7 @@ class ControlBuilder():
"rounding: %s, min_max: %s)", self.option.name, self.option.dtype,
self.option.rounding, self.option.min_max)
validate = self.slider_check_int if self.option.dtype == int else self.slider_check_float
-vcmd = (self.frame.register(validate))
+vcmd = self.frame.register(validate)
tbox = tk.Entry(self.frame,
width=8,
textvariable=self.option.tk_var,
@@ -1246,15 +1248,15 @@ class FileBrowser():
@property
def helptext(self):
""" Dict containing tooltip text for buttons """
-retval = dict(folder=_("Select a folder..."),
-load=_("Select a file..."),
-load2=_("Select a file..."),
-picture=_("Select a folder of images..."),
-video=_("Select a video..."),
-model=_("Select a model folder..."),
-multi_load=_("Select one or more files..."),
-context=_("Select a file or folder..."),
-save_as=_("Select a save location..."))
+retval = {"folder": _("Select a folder..."),
+"load": _("Select a file..."),
+"load2": _("Select a file..."),
+"picture": _("Select a folder of images..."),
+"video": _("Select a video..."),
+"model": _("Select a model folder..."),
+"multi_load": _("Select one or more files..."),
+"context": _("Select a file or folder..."),
+"save_as": _("Select a save location...")}
return retval
@staticmethod
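An aside on the pattern these hunks apply, with a minimal sketch that is not taken from the repository (names invented): keyword-argument dict() calls become dict literals. Both spellings build equal dictionaries, but the literal skips a name lookup and a call, and it can express keys that are not valid identifiers.

# Old and new spellings build equal dictionaries
assert dict(text="Some help text") == {"text": "Some help text"}

# Literals also allow keys that dict() keywords cannot express
filetypes = {"multi load": ["*.fsw", "*.fst"]}  # illustrative key only
print(filetypes)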

View file

@@ -4,11 +4,10 @@ import datetime
import gettext
import logging
import os
-import sys
import tkinter as tk
+import typing as T
from tkinter import ttk
-from typing import Dict, Optional, Tuple
from lib.training.preview_tk import PreviewTk
@@ -19,11 +18,6 @@ from .analysis import Calculations, Session
from .control_helper import set_slider_rounding
from .utils import FileHandler, get_config, get_images, preview_trigger
-if sys.version_info < (3, 8):
-from typing_extensions import get_args, Literal
-else:
-from typing import get_args, Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# LOCALES
@@ -92,7 +86,7 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
logger.debug("Initializing %s (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
self._preview = get_images().preview_train
-self._display: Optional[PreviewTk] = None
+self._display: PreviewTk | None = None
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
@@ -177,9 +171,9 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
tab_name: str,
helptext: str,
wait_time: int,
-command: Optional[str] = None) -> None:
-self._trace_vars: Dict[Literal["smoothgraph", "display_iterations"],
-Tuple[tk.BooleanVar, str]] = {}
+command: str | None = None) -> None:
+self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
+tuple[tk.BooleanVar, str]] = {}
super().__init__(parent, tab_name, helptext, wait_time, command)
def set_vars(self) -> None:
@@ -446,7 +440,7 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
def _add_trace_variables(self) -> None:
""" Add tracing for when the option sliders are updated, for updating the graph. """
-for name, action in zip(get_args(Literal["smoothgraph", "display_iterations"]),
+for name, action in zip(T.get_args(T.Literal["smoothgraph", "display_iterations"]),
(self._smooth_amount_callback, self._iteration_limit_callback)):
var = self.vars[name]
if name not in self._trace_vars:
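For context on the T.get_args change above, a small self-contained sketch (the alias and function here are invented, not the repo's): typing.get_args unpacks a Literal alias into its runtime choices, so one alias can drive both the annotation and the loop.

import typing as T

GraphVars = T.Literal["smoothgraph", "display_iterations"]

def describe(var: GraphVars) -> str:
    return f"callback wired for {var}"

# get_args returns the Literal's members as a tuple of plain strings
for name in T.get_args(GraphVars):
    print(describe(name))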

View file

@@ -1,12 +1,13 @@
#!/usr/bin python3
""" Graph functions for Display Frame area of the Faceswap GUI """
+from __future__ import annotations
import datetime
import logging
import os
import tkinter as tk
+import typing as T
from tkinter import ttk
-from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
from math import ceil, floor
import numpy as np
@@ -20,7 +21,7 @@ from matplotlib.backend_bases import NavigationToolbar2
from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
from matplotlib.lines import Line2D
matplotlib.use("TkAgg")
@@ -49,8 +50,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._ylabel = ylabel
self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
"summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
-self._lines: List["Line2D"] = []
-self._toolbar: Optional["NavigationToolbar"] = None
+self._lines: list[Line2D] = []
+self._toolbar: "NavigationToolbar" | None = None
self._fig = Figure(figsize=(4, 4), dpi=75)
self._ax1 = self._fig.add_subplot(1, 1, 1)
@@ -129,7 +130,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._ax1.set_ylim(0.00, 100.0)
self._ax1.set_xlim(0, 1)
-def _axes_limits_set(self, data: List[float]) -> None:
+def _axes_limits_set(self, data: list[float]) -> None:
""" Set the axes limits.
Parameters
@@ -154,7 +155,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._axes_limits_set_default()
@staticmethod
-def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
+def _axes_data_get_min_max(data: list[float]) -> tuple[float, float]:
""" Obtain the minimum and maximum values for the y-axis from the given data points.
Parameters
@@ -188,7 +189,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
logger.debug("yscale: '%s'", scale)
self._ax1.set_yscale(scale)
-def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float]]]]:
+def _lines_sort(self, keys: list[str]) -> list[list[str | int | tuple[float]]]:
""" Sort the data keys into consistent order and set line color map and line width.
Parameters
@@ -202,8 +203,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
A list of loss keys with their corresponding line formatting and color information
"""
logger.trace("Sorting lines") # type:ignore[attr-defined]
-raw_lines: List[List[str]] = []
-sorted_lines: List[List[str]] = []
+raw_lines: list[list[str]] = []
+sorted_lines: list[list[str]] = []
for key in sorted(keys):
title = key.replace("_", " ").title()
if key.startswith("raw"):
@@ -217,7 +218,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return lines
@staticmethod
-def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
+def _lines_groupsize(raw_lines: list[list[str]], sorted_lines: list[list[str]]) -> int:
""" Get the number of items in each group.
If raw data isn't selected, then check the length of remaining groups until something is
@@ -246,8 +247,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
return groupsize
def _lines_style(self,
-lines: List[List[str]],
-groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
+lines: list[list[str]],
+groupsize: int) -> list[list[str | int | tuple[float]]]:
""" Obtain the color map and line width for each group.
Parameters
@@ -266,13 +267,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
groups = int(len(lines) / groupsize)
colours = self._lines_create_colors(groupsize, groups)
widths = list(range(1, groups + 1))
-retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
+retval = T.cast(list[list[str | int | tuple[float]]], lines)
for idx, item in enumerate(retval):
linewidth = widths[idx // groupsize]
item.extend((linewidth, colours[idx]))
return retval
-def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
+def _lines_create_colors(self, groupsize: int, groups: int) -> list[tuple[float]]:
""" Create the color maps.
Parameters
@@ -336,8 +337,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
super().__init__(parent, data, ylabel)
-self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask
-self._displayed_keys: List[str] = []
+self._thread: LongRunningTask | None = None # Thread for LongRunningTask
+self._displayed_keys: list[str] = []
self._add_callback()
def _add_callback(self) -> None:
@@ -352,7 +353,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
def refresh(self, *args) -> None: # pylint: disable=unused-argument
""" Read the latest loss data and apply to current graph """
-refresh_var = cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
+refresh_var = T.cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
if not refresh_var.get() and self._thread is None:
return
@@ -533,7 +534,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
text: str,
image_file: str,
toggle: bool,
-command) -> Union[ttk.Button, ttk.Checkbutton]:
+command) -> ttk.Button | ttk.Checkbutton:
""" Override the default button method to use our icons and ttk widgets for
consistent GUI layout.
@@ -563,10 +564,10 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
img = get_images().icons[icon]
if not toggle:
-btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame,
-text=text,
-image=img,
-command=command)
+btn: ttk.Button | ttk.Checkbutton = ttk.Button(frame,
+text=text,
+image=img,
+command=command)
else:
var = tk.IntVar(master=frame)
btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)
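A note on why the X | None spellings above are safe even for the quoted forward reference "NavigationToolbar": with from __future__ import annotations (added at the top of this file) annotations are stored as strings and never evaluated at runtime, so PEP 604 syntax and forward references cost nothing. A minimal sketch, not from the repository:

from __future__ import annotations

class Graph:
    def __init__(self) -> None:
        # The annotation is never evaluated, so Toolbar may be defined later
        self._toolbar: Toolbar | None = None

class Toolbar:
    pass

print(Graph()._toolbar)  # None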

View file

@@ -1,6 +1,6 @@
#!/usr/bin python3
""" The Menu Bars for faceswap GUI """
+from __future__ import annotations
import gettext
import locale
import logging
@@ -33,7 +33,7 @@ _ = _LANG.gettext
_WORKING_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
-_RESOURCES: T.List[T.Tuple[str, str]] = [
+_RESOURCES: list[tuple[str, str]] = [
(_("faceswap.dev - Guides and Forum"), "https://www.faceswap.dev"),
(_("Patreon - Support this project"), "https://www.patreon.com/faceswap"),
(_("Discord - The FaceSwap Discord server"), "https://discord.gg/VasFUAy"),
@@ -48,7 +48,7 @@ class MainMenuBar(tk.Menu): # pylint:disable=too-many-ancestors
master: :class:`tkinter.Tk`
The root tkinter object
"""
-def __init__(self, master: "FaceswapGui") -> None:
+def __init__(self, master: FaceswapGui) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(master)
self.root = master
@@ -431,7 +431,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
return True
@classmethod
-def _get_branches(cls) -> T.Optional[str]:
+def _get_branches(cls) -> str | None:
""" Get the available github branches
Returns
@@ -453,7 +453,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
return stdout.decode(locale.getpreferredencoding(), errors="replace")
@classmethod
-def _filter_branches(cls, stdout: str) -> T.List[str]:
+def _filter_branches(cls, stdout: str) -> list[str]:
""" Filter the branches, remove duplicates and the current branch and return a sorted
list.
@@ -548,7 +548,7 @@ class TaskBar(ttk.Frame): # pylint: disable=too-many-ancestors
self._section_separator()
@classmethod
-def _loader_and_kwargs(cls, btntype: str) -> T.Tuple[str, T.Dict[str, bool]]:
+def _loader_and_kwargs(cls, btntype: str) -> tuple[str, dict[str, bool]]:
""" Get the loader name and key word arguments for the given button type
Parameters
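The menu hunks are pure PEP 585/604 modernisation. As a rough illustration (data invented, not the repo's resource list), the builtin collection types now accept subscripts directly, so the typing.List/Tuple aliases and most of the import typing surface can go:

resources: list[tuple[str, str]] = [("Example forum", "https://example.com")]

def first_url(entries: list[tuple[str, str]]) -> str | None:
    # Returns None rather than raising when the list is empty
    return entries[0][1] if entries else None

print(first_url(resources))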

View file

@@ -1,6 +1,6 @@
#!/usr/bin python3
""" The pop-up window of the Faceswap GUI for the setting of configuration options. """
+from __future__ import annotations
from collections import OrderedDict
from configparser import ConfigParser
import gettext
@@ -9,7 +9,8 @@ import os
import sys
import tkinter as tk
from tkinter import ttk
-from typing import Dict, TYPE_CHECKING
+import typing as T
from importlib import import_module
from lib.serializer import get_serializer
@@ -18,7 +19,7 @@ from .control_helper import ControlPanel, ControlPanelOption
from .custom_widgets import Tooltip
from .utils import FileHandler, get_config, get_images, PATHCACHE
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
from lib.config import FaceswapConfig
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -124,7 +125,7 @@ class _ConfigurePlugins(tk.Toplevel):
super().__init__()
self._root = get_config().root
self._set_geometry()
-self._tk_vars = dict(header=tk.StringVar())
+self._tk_vars = {"header": tk.StringVar()}
theme = {**get_config().user_theme["group_panel"],
**get_config().user_theme["group_settings"]}
@@ -402,7 +403,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
"""
def __init__(self, top_level, parent, configurations, tree, theme):
super().__init__(parent)
-self._configs: Dict[str, "FaceswapConfig"] = configurations
+self._configs: dict[str, FaceswapConfig] = configurations
self._theme = theme
self._tree = tree
self._vars = {}
@@ -443,7 +444,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
sect = section.split(".")[-1]
# Elevate global to root
key = plugin if sect == "global" else f"{plugin}|{category}|{sect}"
-retval[key] = dict(helptext=None, options=OrderedDict())
+retval[key] = {"helptext": None, "options": OrderedDict()}
retval[key]["helptext"] = conf.defaults[section].helptext
for option, params in conf.defaults[section].items.items():
@@ -632,7 +633,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
def _get_new_config(self,
page_only: bool,
-config: "FaceswapConfig",
+config: FaceswapConfig,
category: str,
lookup: str) -> ConfigParser:
""" Obtain a new configuration file for saving
@@ -812,9 +813,9 @@ class _Presets():
return None
args = ("save_filename", "json") if action == "save" else ("filename", "json")
-kwargs = dict(title=f"{action.title()} Preset...",
-initial_folder=self._preset_path,
-parent=self._parent)
+kwargs = {"title": f"{action.title()} Preset...",
+"initial_folder": self._preset_path,
+"parent": self._parent}
if action == "save":
kwargs["initial_file"] = self._get_initial_filename()
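The preset hunk above builds its dialog arguments as a literal and then splats them into the call; a runnable sketch of that shape (the open_dialog function here is a stand-in, not the FileHandler API):

def open_dialog(title, initial_folder=None, initial_file=None, parent=None):
    return f"{title} in {initial_folder}"

action = "save"
kwargs = {"title": f"{action.title()} Preset...",
          "initial_folder": "~/presets"}
if action == "save":
    kwargs["initial_file"] = "preset.json"  # optional key added conditionally
print(open_dialog(**kwargs))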

View file

@@ -8,7 +8,6 @@ import tkinter as tk
from dataclasses import dataclass, field
from tkinter import ttk
-from typing import Dict, List, Optional, Tuple, Type, Union
from .control_helper import ControlBuilder, ControlPanelOption
from .custom_widgets import Tooltip
@@ -66,7 +65,7 @@ class SessionTKVars:
outliers: tk.BooleanVar
avgiterations: tk.IntVar
smoothamount: tk.DoubleVar
-loss_keys: Dict[str, tk.BooleanVar] = field(default_factory=dict)
+loss_keys: dict[str, tk.BooleanVar] = field(default_factory=dict)
class SessionPopUp(tk.Toplevel):
@@ -82,13 +81,13 @@ class SessionPopUp(tk.Toplevel):
logger.debug("Initializing: %s: (session_id: %s, data_points: %s)",
self.__class__.__name__, session_id, data_points)
super().__init__()
-self._thread: Optional[LongRunningTask] = None # Thread for loading data in background
+self._thread: LongRunningTask | None = None # Thread for loading data in background
self._default_view = "avg" if data_points > 1000 else "smoothed"
self._session_id = None if session_id == "Total" else int(session_id)
self._graph_frame = ttk.Frame(self)
-self._graph: Optional[SessionGraph] = None
-self._display_data: Optional[Calculations] = None
+self._graph: SessionGraph | None = None
+self._display_data: Calculations | None = None
self._vars = self._set_vars()
@@ -172,7 +171,7 @@ class SessionPopUp(tk.Toplevel):
The frame that the options reside in
"""
logger.debug("Building Combo boxes")
-choices = dict(Display=("Loss", "Rate"), Scale=("Linear", "Log"))
+choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")}
for item in ["Display", "Scale"]:
var: tk.StringVar = getattr(self._vars, item.lower())
@@ -273,11 +272,11 @@ class SessionPopUp(tk.Toplevel):
logger.debug("Building Slider Controls")
for item in ("avgiterations", "smoothamount"):
if item == "avgiterations":
-dtype: Union[Type[int], Type[float]] = int
+dtype: type[int] | type[float] = int
text = "Iterations to Average:"
-default: Union[int, float] = 500
+default: int | float = 500
rounding = 25
-min_max: Tuple[int, Union[int, float]] = (25, 2500)
+min_max: tuple[int, int | float] = (25, 2500)
elif item == "smoothamount":
dtype = float
text = "Smoothing Amount:"
@@ -404,20 +403,20 @@ class SessionPopUp(tk.Toplevel):
str
The help text for the given action
"""
-lookup = dict(
-reload=_("Refresh graph"),
-save=_("Save display data to csv"),
-avgiterations=_("Number of data points to sample for rolling average"),
-smoothamount=_("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
-"smoothing"),
-outliers=_("Flatten data points that fall more than 1 standard deviation from the "
-"mean to the mean value."),
-avg=_("Display rolling average of the data"),
-smoothed=_("Smooth the data"),
-raw=_("Display raw data"),
-trend=_("Display polynormal data trend"),
-display=_("Set the data to display"),
-scale=_("Change y-axis scale"))
+lookup = {
+"reload": _("Refresh graph"),
+"save": _("Save display data to csv"),
+"avgiterations": _("Number of data points to sample for rolling average"),
+"smoothamount": _("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
+"smoothing"),
+"outliers": _("Flatten data points that fall more than 1 standard deviation from the "
+"mean to the mean value."),
+"avg": _("Display rolling average of the data"),
+"smoothed": _("Smooth the data"),
+"raw": _("Display raw data"),
+"trend": _("Display polynormal data trend"),
+"display": _("Set the data to display"),
+"scale": _("Change y-axis scale")}
return lookup.get(action.lower(), "")
def _compile_display_data(self) -> bool:
@@ -446,13 +445,13 @@ class SessionPopUp(tk.Toplevel):
self._lbl_loading.pack(fill=tk.BOTH, expand=True)
self.update_idletasks()
-kwargs = dict(session_id=self._session_id,
-display=self._vars.display.get(),
-loss_keys=loss_keys,
-selections=selections,
-avg_samples=self._vars.avgiterations.get(),
-smooth_amount=self._vars.smoothamount.get(),
-flatten_outliers=self._vars.outliers.get())
+kwargs = {"session_id": self._session_id,
+"display": self._vars.display.get(),
+"loss_keys": loss_keys,
+"selections": selections,
+"avg_samples": self._vars.avgiterations.get(),
+"smooth_amount": self._vars.smoothamount.get(),
+"flatten_outliers": self._vars.outliers.get()}
self._thread = LongRunningTask(target=self._get_display_data,
kwargs=kwargs,
widget=self)
@@ -491,7 +490,7 @@ class SessionPopUp(tk.Toplevel):
"""
return Calculations(**kwargs)
-def _check_valid_selection(self, loss_keys: List[str], selections: List[str]) -> bool:
+def _check_valid_selection(self, loss_keys: list[str], selections: list[str]) -> bool:
""" Check that there will be data to display.
Parameters
@@ -530,7 +529,7 @@ class SessionPopUp(tk.Toplevel):
return False
return True
-def _selections_to_list(self) -> List[str]:
+def _selections_to_list(self) -> list[str]:
""" Compile checkbox selections to a list.
Returns
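The help-text rewrite above is the table-lookup idiom: a dict literal plus dict.get with a default replaces an if/elif ladder. A sketch with shortened strings (only two of the keys kept):

lookup = {"reload": "Refresh graph",
          "save": "Save display data to csv"}

def helptext(action: str) -> str:
    # Unknown actions fall through to the empty-string default
    return lookup.get(action.lower(), "")

assert helptext("RELOAD") == "Refresh graph"
assert helptext("unknown") == ""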

View file

@@ -1,19 +1,20 @@
#!/usr/bin python3
""" Global configuration optiopns for the Faceswap GUI """
+from __future__ import annotations
import logging
import os
import sys
import tkinter as tk
+import typing as T
from dataclasses import dataclass, field
-from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING
from lib.gui._config import Config as UserConfig
from lib.gui.project import Project, Tasks
from lib.gui.theme import Style
from .file_handler import FileHandler
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
from lib.gui.options import CliOptions
from lib.gui.custom_widgets import StatusBar
from lib.gui.command import CommandNotebook
@@ -22,12 +23,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PATHCACHE = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache")
-_CONFIG: Optional["Config"] = None
+_CONFIG: Config | None = None
def initialize_config(root: tk.Tk,
-cli_opts: Optional["CliOptions"],
-statusbar: Optional["StatusBar"]) -> Optional["Config"]:
+cli_opts: CliOptions | None,
+statusbar: StatusBar | None) -> Config | None:
""" Initialize the GUI Master :class:`Config` and add to global constant.
This should only be called once on first GUI startup. Future access to :class:`Config`
@@ -145,13 +146,13 @@ class GlobalVariables():
@dataclass
class _GuiObjects:
""" Data class for commonly accessed GUI Objects """
-cli_opts: Optional["CliOptions"]
+cli_opts: CliOptions | None
tk_vars: GlobalVariables
project: Project
tasks: Tasks
-status_bar: Optional["StatusBar"]
-default_options: Dict[str, Dict[str, Any]] = field(default_factory=dict)
-command_notebook: Optional["CommandNotebook"] = None
+status_bar: StatusBar | None
+default_options: dict[str, dict[str, T.Any]] = field(default_factory=dict)
+command_notebook: CommandNotebook | None = None
class Config():
@@ -172,15 +173,15 @@ class Config():
"""
def __init__(self,
root: tk.Tk,
-cli_opts: Optional["CliOptions"],
-statusbar: Optional["StatusBar"]) -> None:
+cli_opts: CliOptions | None,
+statusbar: StatusBar | None) -> None:
logger.debug("Initializing %s: (root %s, cli_opts: %s, statusbar: %s)",
self.__class__.__name__, root, cli_opts, statusbar)
-self._default_font = cast(dict, tk.font.nametofont("TkDefaultFont").configure())["family"]
-self._constants = dict(
-root=root,
-scaling_factor=self._get_scaling(root),
-default_font=self._default_font)
+self._default_font = T.cast(dict,
+tk.font.nametofont("TkDefaultFont").configure())["family"]
+self._constants = {"root": root,
+"scaling_factor": self._get_scaling(root),
+"default_font": self._default_font}
self._gui_objects = _GuiObjects(
cli_opts=cli_opts,
tk_vars=GlobalVariables(),
@@ -211,7 +212,7 @@ class Config():
# GUI Objects
@property
-def cli_opts(self) -> "CliOptions":
+def cli_opts(self) -> CliOptions:
""" :class:`lib.gui.options.CliOptions`: The command line options for this GUI Session. """
# This should only be None when a separate tool (not main GUI) is used, at which point
# cli_opts do not exist
@@ -234,12 +235,12 @@ class Config():
return self._gui_objects.tasks
@property
-def default_options(self) -> Dict[str, Dict[str, Any]]:
+def default_options(self) -> dict[str, dict[str, T.Any]]:
""" dict: The default options for all tabs """
return self._gui_objects.default_options
@property
-def statusbar(self) -> "StatusBar":
+def statusbar(self) -> StatusBar:
""" :class:`lib.gui.custom_widgets.StatusBar`: The GUI StatusBar
:class:`tkinter.ttk.Frame`. """
# This should only be None when a separate tool (not main GUI) is used, at which point
@@ -248,31 +249,31 @@ class Config():
return self._gui_objects.status_bar
@property
-def command_notebook(self) -> Optional["CommandNotebook"]:
+def command_notebook(self) -> CommandNotebook | None:
""" :class:`lib.gui.command.CommandNotebook`: The main Faceswap Command Notebook. """
return self._gui_objects.command_notebook
# Convenience GUI Objects
@property
-def tools_notebook(self) -> "ToolsNotebook":
+def tools_notebook(self) -> ToolsNotebook:
""" :class:`lib.gui.command.ToolsNotebook`: The Faceswap Tools sub-Notebook. """
assert self.command_notebook is not None
return self.command_notebook.tools_notebook
@property
-def modified_vars(self) -> Dict[str, "tk.BooleanVar"]:
+def modified_vars(self) -> dict[str, tk.BooleanVar]:
""" dict: The command notebook modified tkinter variables. """
assert self.command_notebook is not None
return self.command_notebook.modified_vars
@property
-def _command_tabs(self) -> Dict[str, int]:
+def _command_tabs(self) -> dict[str, int]:
""" dict: Command tab titles with their IDs. """
assert self.command_notebook is not None
return self.command_notebook.tab_names
@property
-def _tools_tabs(self) -> Dict[str, int]:
+def _tools_tabs(self) -> dict[str, int]:
""" dict: Tools command tab titles with their IDs. """
assert self.command_notebook is not None
return self.command_notebook.tools_tab_names
@@ -284,17 +285,17 @@ class Config():
return self._user_config
@property
-def user_config_dict(self) -> Dict[str, Any]: # TODO Dataclass
+def user_config_dict(self) -> dict[str, T.Any]: # TODO Dataclass
""" dict: The GUI config in dict form. """
return self._user_config.config_dict
@property
-def user_theme(self) -> Dict[str, Any]: # TODO Dataclass
+def user_theme(self) -> dict[str, T.Any]: # TODO Dataclass
""" dict: The GUI theme selection options. """
return self._user_theme
@property
-def default_font(self) -> Tuple[str, int]:
+def default_font(self) -> tuple[str, int]:
""" tuple: The selected font as configured in user settings. First item is the font (`str`)
second item the font size (`int`). """
font = self.user_config_dict["font"]
@@ -328,7 +329,7 @@ class Config():
self._gui_objects.default_options = default
self.project.set_default_options()
-def set_command_notebook(self, notebook: "CommandNotebook") -> None:
+def set_command_notebook(self, notebook: CommandNotebook) -> None:
""" Set the command notebook to the :attr:`command_notebook` attribute
and enable the modified callback for :attr:`project`.
@@ -385,7 +386,7 @@ class Config():
""" Reload the user config from file. """
self._user_config = UserConfig(None)
-def set_cursor_busy(self, widget: Optional[tk.Widget] = None) -> None:
+def set_cursor_busy(self, widget: tk.Widget | None = None) -> None:
""" Set the root or widget cursor to busy.
Parameters
@@ -399,7 +400,7 @@ class Config():
component.config(cursor="watch") # type: ignore
component.update_idletasks()
-def set_cursor_default(self, widget: Optional[tk.Widget] = None) -> None:
+def set_cursor_default(self, widget: tk.Widget | None = None) -> None:
""" Set the root or widget cursor to default.
Parameters
@@ -413,7 +414,7 @@ class Config():
component.config(cursor="") # type: ignore
component.update_idletasks()
-def set_root_title(self, text: Optional[str] = None) -> None:
+def set_root_title(self, text: str | None = None) -> None:
""" Set the main title text for Faceswap.
The title will always begin with 'Faceswap.py'. Additional text can be appended.
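The _CONFIG: Config | None = None line above is the module-level singleton the whole GUI shares. A stripped-down sketch of the pattern (simplified signatures, not the real class):

from __future__ import annotations

_CONFIG: Config | None = None

class Config:
    def __init__(self, root: str) -> None:
        self.root = root

def initialize_config(root: str) -> Config | None:
    global _CONFIG
    if _CONFIG is not None:
        return None  # already initialized: the first caller wins
    _CONFIG = Config(root)
    return _CONFIG

def get_config() -> Config:
    assert _CONFIG is not None  # initialize_config must run at startup
    return _CONFIG

initialize_config("/tmp")
print(get_config().root)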

View file

@@ -2,24 +2,15 @@
""" File browser utility functions for the Faceswap GUI. """
import logging
import platform
-import sys
import tkinter as tk
from tkinter import filedialog
+import typing as T
-from typing import cast, Dict, IO, List, Optional, Tuple, Union
-if sys.version_info < (3, 8):
-from typing_extensions import Literal
-else:
-from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
-_FILETYPE = Literal["default", "alignments", "config_project", "config_task",
-"config_all", "csv", "image", "ini", "state", "log", "video"]
-_HANDLETYPE = Literal["open", "save", "filename", "filename_multi", "save_filename",
-"context", "dir"]
+_FILETYPE = T.Literal["default", "alignments", "config_project", "config_task",
+"config_all", "csv", "image", "ini", "state", "log", "video"]
+_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename",
+"context", "dir"]
class FileHandler(): # pylint:disable=too-few-public-methods
@@ -72,14 +63,14 @@ class FileHandler(): # pylint:disable=too-few-public-methods
def __init__(self,
handle_type: _HANDLETYPE,
-file_type: Optional[_FILETYPE],
-title: Optional[str] = None,
-initial_folder: Optional[str] = None,
-initial_file: Optional[str] = None,
-command: Optional[str] = None,
-action: Optional[str] = None,
-variable: Optional[str] = None,
-parent: Optional[tk.Frame] = None) -> None:
+file_type: _FILETYPE | None,
+title: str | None = None,
+initial_folder: str | None = None,
+initial_file: str | None = None,
+command: str | None = None,
+action: str | None = None,
+variable: str | None = None,
+parent: tk.Frame | None = None) -> None:
logger.debug("Initializing %s: (handle_type: '%s', file_type: '%s', title: '%s', "
"initial_folder: '%s', initial_file: '%s', command: '%s', action: '%s', "
"variable: %s, parent: %s)", self.__class__.__name__, handle_type, file_type,
@@ -101,35 +92,35 @@ class FileHandler(): # pylint:disable=too-few-public-methods
logger.debug("Initialized %s", self.__class__.__name__)
@property
-def _filetypes(self) -> Dict[str, List[Tuple[str, str]]]:
+def _filetypes(self) -> dict[str, list[tuple[str, str]]]:
""" dict: The accepted extensions for each file type for opening/saving """
all_files = ("All files", "*.*")
-filetypes = dict(
-default=[all_files],
-alignments=[("Faceswap Alignments", "*.fsa"), all_files],
-config_project=[("Faceswap Project files", "*.fsw"), all_files],
-config_task=[("Faceswap Task files", "*.fst"), all_files],
-config_all=[("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
-csv=[("Comma separated values", "*.csv"), all_files],
-image=[("Bitmap", "*.bmp"),
-("JPG", "*.jpeg *.jpg"),
-("PNG", "*.png"),
-("TIFF", "*.tif *.tiff"),
-all_files],
-ini=[("Faceswap config files", "*.ini"), all_files],
-json=[("JSON file", "*.json"), all_files],
-model=[("Keras model files", "*.h5"), all_files],
-state=[("State files", "*.json"), all_files],
-log=[("Log files", "*.log"), all_files],
-video=[("Audio Video Interleave", "*.avi"),
-("Flash Video", "*.flv"),
-("Matroska", "*.mkv"),
-("MOV", "*.mov"),
-("MP4", "*.mp4"),
-("MPEG", "*.mpeg *.mpg *.ts *.vob"),
-("WebM", "*.webm"),
-("Windows Media Video", "*.wmv"),
-all_files])
+filetypes = {
+"default": [all_files],
+"alignments": [("Faceswap Alignments", "*.fsa"), all_files],
+"config_project": [("Faceswap Project files", "*.fsw"), all_files],
+"config_task": [("Faceswap Task files", "*.fst"), all_files],
+"config_all": [("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
+"csv": [("Comma separated values", "*.csv"), all_files],
+"image": [("Bitmap", "*.bmp"),
+("JPG", "*.jpeg *.jpg"),
+("PNG", "*.png"),
+("TIFF", "*.tif *.tiff"),
+all_files],
+"ini": [("Faceswap config files", "*.ini"), all_files],
+"json": [("JSON file", "*.json"), all_files],
+"model": [("Keras model files", "*.h5"), all_files],
+"state": [("State files", "*.json"), all_files],
+"log": [("Log files", "*.log"), all_files],
+"video": [("Audio Video Interleave", "*.avi"),
+("Flash Video", "*.flv"),
+("Matroska", "*.mkv"),
+("MOV", "*.mov"),
+("MP4", "*.mp4"),
+("MPEG", "*.mpeg *.mpg *.ts *.vob"),
+("WebM", "*.webm"),
+("Windows Media Video", "*.wmv"),
+all_files]}
# Add in multi-select options and upper case extensions for Linux
for key in filetypes:
@@ -142,32 +133,32 @@ class FileHandler(): # pylint:disable=too-few-public-methods
multi = [f"{key.title()} Files"]
multi.append(" ".join([ftype[1]
for ftype in filetypes[key] if ftype[0] != "All files"]))
-filetypes[key].insert(0, cast(Tuple[str, str], tuple(multi)))
+filetypes[key].insert(0, T.cast(tuple[str, str], tuple(multi)))
return filetypes
@property
-def _contexts(self) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
+def _contexts(self) -> dict[str, dict[str, str | dict[str, str]]]:
"""dict: Mapping of commands, actions and their corresponding file dialog for context
handle types. """
-return dict(effmpeg=dict(input={"extract": "filename",
-"gen-vid": "dir",
-"get-fps": "filename",
-"get-info": "filename",
-"mux-audio": "filename",
-"rescale": "filename",
-"rotate": "filename",
-"slice": "filename"},
-output={"extract": "dir",
-"gen-vid": "save_filename",
-"get-fps": "nothing",
-"get-info": "nothing",
-"mux-audio": "save_filename",
-"rescale": "save_filename",
-"rotate": "save_filename",
-"slice": "save_filename"}))
+return {"effmpeg": {"input": {"extract": "filename",
+"gen-vid": "dir",
+"get-fps": "filename",
+"get-info": "filename",
+"mux-audio": "filename",
+"rescale": "filename",
+"rotate": "filename",
+"slice": "filename"},
+"output": {"extract": "dir",
+"gen-vid": "save_filename",
+"get-fps": "nothing",
+"get-info": "nothing",
+"mux-audio": "save_filename",
+"rescale": "save_filename",
+"rotate": "save_filename",
+"slice": "save_filename"}}}
@classmethod
-def _set_dummy_master(cls) -> Optional[tk.Frame]:
+def _set_dummy_master(cls) -> tk.Frame | None:
""" Add an option to force black font on Linux file dialogs KDE issue that displays light
font on white background).
@@ -183,7 +174,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
if platform.system().lower() == "linux":
frame = tk.Frame()
frame.option_add("*foreground", "black")
-retval: Optional[tk.Frame] = frame
+retval: tk.Frame | None = frame
else:
retval = None
return retval
@@ -196,7 +187,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
del self._dummy_master
self._dummy_master = None
-def _set_defaults(self) -> Dict[str, Optional[str]]:
+def _set_defaults(self) -> dict[str, str | None]:
""" Set the default file type for the file dialog. Generally the first found file type
will be used, but this is overridden if it is not appropriate.
@@ -205,7 +196,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
dict:
The default file extension for each file type
"""
-defaults: Dict[str, Optional[str]] = {
+defaults: dict[str, str | None] = {
key: next(ext for ext in val[0][1].split(" ")).replace("*", "")
for key, val in self._filetypes.items()}
defaults["default"] = None
@@ -215,15 +206,15 @@ class FileHandler(): # pylint:disable=too-few-public-methods
return defaults
def _set_kwargs(self,
-title: Optional[str],
-initial_folder: Optional[str],
-initial_file: Optional[str],
-file_type: Optional[_FILETYPE],
-command: Optional[str],
-action: Optional[str],
-variable: Optional[str],
-parent: Optional[tk.Frame]
-) -> Dict[str, Union[None, tk.Frame, str, List[Tuple[str, str]]]]:
+title: str | None,
+initial_folder: str | None,
+initial_file: str | None,
+file_type: _FILETYPE | None,
+command: str | None,
+action: str | None,
+variable: str | None,
+parent: tk.Frame | None
+) -> dict[str, None | tk.Frame | str | list[tuple[str, str]]]:
""" Generate the required kwargs for the requested file dialog browser.
Parameters
@@ -259,8 +250,8 @@ class FileHandler(): # pylint:disable=too-few-public-methods
title, initial_folder, initial_file, file_type, command, action, variable,
parent)
-kwargs: Dict[str, Union[None, tk.Frame, str,
-List[Tuple[str, str]]]] = dict(master=self._dummy_master)
+kwargs: dict[str, None | tk.Frame | str | list[tuple[str, str]]] = {
+"master": self._dummy_master}
if self._handletype.lower() == "context":
assert command is not None and action is not None and variable is not None
@@ -304,20 +295,20 @@ class FileHandler(): # pylint:disable=too-few-public-methods
The variable associated with this file dialog
"""
if self._contexts[command].get(variable, None) is not None:
-handletype = cast(Dict[str, Dict[str, Dict[str, str]]],
-self._contexts)[command][variable][action]
+handletype = T.cast(dict[str, dict[str, dict[str, str]]],
+self._contexts)[command][variable][action]
else:
-handletype = cast(Dict[str, Dict[str, str]],
-self._contexts)[command][action]
+handletype = T.cast(dict[str, dict[str, str]],
+self._contexts)[command][action]
logger.debug(handletype)
-self._handletype = cast(_HANDLETYPE, handletype)
+self._handletype = T.cast(_HANDLETYPE, handletype)
-def _open(self) -> Optional[IO]:
+def _open(self) -> T.IO | None:
""" Open a file. """
logger.debug("Popping Open browser")
return filedialog.askopenfile(**self._kwargs) # type: ignore
-def _save(self) -> Optional[IO]:
+def _save(self) -> T.IO | None:
""" Save a file. """
logger.debug("Popping Save browser")
return filedialog.asksaveasfile(**self._kwargs) # type: ignore
@@ -337,7 +328,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
logger.debug("Popping Filename browser")
return filedialog.askopenfilename(**self._kwargs) # type: ignore
-def _filename_multi(self) -> Tuple[str, ...]:
+def _filename_multi(self) -> tuple[str, ...]:
""" Get multiple existing file locations. """
logger.debug("Popping Filename browser")
return filedialog.askopenfilenames(**self._kwargs) # type: ignore
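A short sketch of the Literal/cast combination the FileHandler hunks rely on (the alias and function here are invented for illustration): the contexts mapping yields plain strings, so T.cast re-attaches the Literal type for the checker. T.cast is a no-op at runtime, which is why the runtime guard is worth keeping.

import typing as T

HandleType = T.Literal["open", "save", "dir"]

def resolve(raw: str) -> HandleType:
    # cast alone checks nothing at runtime; assert does the real validation
    assert raw in T.get_args(HandleType)
    return T.cast(HandleType, raw)

print(resolve("save"))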

View file

@@ -1,10 +1,9 @@
#!/usr/bin python3
""" Utilities for handling images in the Faceswap GUI """
+from __future__ import annotations
import logging
import os
-import sys
+import typing as T
-from typing import cast, Dict, List, Optional, Sequence, Tuple
import cv2
import numpy as np
@@ -14,15 +13,12 @@ from lib.training.preview_cv import PreviewBuffer
from .config import get_config, PATHCACHE
-if sys.version_info < (3, 8):
-from typing_extensions import Literal
-else:
-from typing import Literal
+if T.TYPE_CHECKING:
+from collections.abc import Sequence
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
-_IMAGES: Optional["Images"] = None
-_PREVIEW_TRIGGER: Optional["PreviewTrigger"] = None
+_IMAGES: "Images" | None = None
+_PREVIEW_TRIGGER: "PreviewTrigger" | None = None
TRAININGPREVIEW = ".gui_training_preview.png"
@@ -51,7 +47,7 @@ def get_images() -> "Images":
return _IMAGES
-def _get_previews(image_path: str) -> List[str]:
+def _get_previews(image_path: str) -> list[str]:
""" Get the images stored within the given directory.
Parameters
@@ -164,12 +160,12 @@ class PreviewExtract():
self._output_path = ""
self._modified: float = 0.0
-self._filenames: List[str] = []
-self._images: Optional[np.ndarray] = None
-self._placeholder: Optional[np.ndarray] = None
-self._preview_image: Optional[Image.Image] = None
-self._preview_image_tk: Optional[ImageTk.PhotoImage] = None
+self._filenames: list[str] = []
+self._images: np.ndarray | None = None
+self._placeholder: np.ndarray | None = None
+self._preview_image: Image.Image | None = None
+self._preview_image_tk: ImageTk.PhotoImage | None = None
logger.debug("Initialized %s", self.__class__.__name__)
@@ -228,7 +224,7 @@ class PreviewExtract():
logger.debug("sorted folders: %s, return value: %s", folders, retval)
return retval
-def _get_newest_filenames(self, image_files: List[str]) -> List[str]:
+def _get_newest_filenames(self, image_files: list[str]) -> list[str]:
""" Return image filenames that have been modified since the last check.
Parameters
@@ -281,8 +277,8 @@ class PreviewExtract():
return retval
def _process_samples(self,
-samples: List[np.ndarray],
-filenames: List[str],
+samples: list[np.ndarray],
+filenames: list[str],
num_images: int) -> bool:
""" Process the latest sample images into a displayable image.
@@ -321,8 +317,8 @@ class PreviewExtract():
return True
def _load_images_to_cache(self,
-image_files: List[str],
-frame_dims: Tuple[int, int],
+image_files: list[str],
+frame_dims: tuple[int, int],
thumbnail_size: int) -> bool:
""" Load preview images to the image cache.
@@ -349,7 +345,7 @@ class PreviewExtract():
logger.debug("num_images: %s", num_images)
if num_images == 0:
return False
-samples: List[np.ndarray] = []
+samples: list[np.ndarray] = []
start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
dropped_files = []
@@ -405,7 +401,7 @@ class PreviewExtract():
self._placeholder = placeholder
logger.debug("Created placeholder. shape: %s", placeholder.shape)
-def _place_previews(self, frame_dims: Tuple[int, int]) -> Image.Image:
+def _place_previews(self, frame_dims: tuple[int, int]) -> Image.Image:
""" Format the preview thumbnails stored in the cache into a grid fitting the display
panel.
@@ -441,12 +437,12 @@ class PreviewExtract():
placeholder = np.concatenate([np.expand_dims(self._placeholder, 0)] * remainder)
samples = np.concatenate((samples, placeholder))
-display = np.vstack([np.hstack(cast(Sequence, samples[row * cols: (row + 1) * cols]))
+display = np.vstack([np.hstack(T.cast("Sequence", samples[row * cols: (row + 1) * cols]))
for row in range(rows)])
logger.debug("display shape: %s", display.shape)
return Image.fromarray(display)
-def load_latest_preview(self, thumbnail_size: int, frame_dims: Tuple[int, int]) -> bool:
+def load_latest_preview(self, thumbnail_size: int, frame_dims: tuple[int, int]) -> bool:
""" Load the latest preview image for extract and convert.
Retrieves the latest preview images from the faceswap output folder, resizes to thumbnails
@@ -524,7 +520,7 @@ class Images():
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._pathpreview = os.path.join(PATHCACHE, "preview")
-self._pathoutput: Optional[str] = None
+self._pathoutput: str | None = None
self._batch_mode = False
self._preview_train = PreviewTrain(self._pathpreview)
self._preview_extract = PreviewExtract(self._pathpreview)
@@ -542,7 +538,7 @@ class Images():
return self._preview_extract
@property
-def icons(self) -> Dict[str, ImageTk.PhotoImage]:
+def icons(self) -> dict[str, ImageTk.PhotoImage]:
""" dict: The faceswap icons for all parts of the GUI. The dictionary key is the icon
name (`str`) the value is the icon sized and formatted for display
(:class:`PIL.ImageTK.PhotoImage`).
@@ -557,7 +553,7 @@ class Images():
return self._icons
@staticmethod
-def _load_icons() -> Dict[str, ImageTk.PhotoImage]:
+def _load_icons() -> dict[str, ImageTk.PhotoImage]:
""" Scan the icons cache folder and load the icons into :attr:`icons` for retrieval
throughout the GUI.
@@ -569,7 +565,7 @@ class Images():
"""
size = get_config().user_config_dict.get("icon_size", 16)
size = int(round(size * get_config().scaling_factor))
-icons: Dict[str, ImageTk.PhotoImage] = {}
+icons: dict[str, ImageTk.PhotoImage] = {}
pathicons = os.path.join(PATHCACHE, "icons")
for fname in os.listdir(pathicons):
name, ext = os.path.splitext(fname)
@@ -609,12 +605,12 @@ class PreviewTrigger():
"""
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
-self._trigger_files = dict(update=os.path.join(PATHCACHE, ".preview_trigger"),
-mask_toggle=os.path.join(PATHCACHE, ".preview_mask_toggle"))
+self._trigger_files = {"update": os.path.join(PATHCACHE, ".preview_trigger"),
+"mask_toggle": os.path.join(PATHCACHE, ".preview_mask_toggle")}
logger.debug("Initialized: %s (trigger_files: %s)",
self.__class__.__name__, self._trigger_files)
-def set(self, trigger_type: Literal["update", "mask_toggle"]):
+def set(self, trigger_type: T.Literal["update", "mask_toggle"]):
""" Place the trigger file into the cache folder
Parameters
@@ -629,7 +625,7 @@ class PreviewTrigger():
pass
logger.debug("Set preview trigger: %s", trigger)
-def clear(self, trigger_type: Optional[Literal["update", "mask_toggle"]] = None) -> None:
+def clear(self, trigger_type: T.Literal["update", "mask_toggle"] | None = None) -> None:
""" Remove the trigger file from the cache folder.
Parameters
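PreviewTrigger above signals across processes through marker files in the cache folder: set() touches a file, clear() deletes it, and only the file's existence carries meaning. A self-contained sketch of that mechanism with an illustrative path:

import os
import tempfile

TRIGGER = os.path.join(tempfile.gettempdir(), ".preview_trigger")

def set_trigger() -> None:
    with open(TRIGGER, "a", encoding="utf-8"):
        pass  # an empty file is enough; only its existence matters

def clear_trigger() -> None:
    if os.path.exists(TRIGGER):
        os.remove(TRIGGER)

set_trigger()
assert os.path.exists(TRIGGER)
clear_trigger()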

View file

@@ -1,15 +1,17 @@
#!/usr/bin/env python3
""" Miscellaneous Utility functions for the GUI. Includes LongRunningTask object """
+from __future__ import annotations
import logging
import sys
+import typing as T
from threading import Event, Thread
-from typing import (Any, Callable, cast, Dict, Optional, Tuple, Type, TYPE_CHECKING)
from queue import Queue
from .config import get_config
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+from collections.abc import Callable
from types import TracebackType
from lib.multithreading import _ErrorType
@@ -31,15 +33,15 @@ class LongRunningTask(Thread):
cursor in the correct location. Default: ``None``.
"""
_target: Callable
-_args: Tuple
-_kwargs: Dict[str, Any]
+_args: tuple
+_kwargs: dict[str, T.Any]
_name: str
def __init__(self,
-target: Optional[Callable] = None,
-name: Optional[str] = None,
-args: Tuple = (),
-kwargs: Optional[Dict[str, Any]] = None,
+target: Callable | None = None,
+name: str | None = None,
+args: tuple = (),
+kwargs: dict[str, T.Any] | None = None,
*,
daemon: bool = True,
widget=None):
@@ -48,7 +50,7 @@ class LongRunningTask(Thread):
daemon)
super().__init__(target=target, name=name, args=args, kwargs=kwargs,
daemon=daemon)
-self.err: "_ErrorType" = None
+self.err: _ErrorType = None
self._widget = widget
self._config = get_config()
self._config.set_cursor_busy(widget=self._widget)
@@ -70,8 +72,8 @@ class LongRunningTask(Thread):
retval = self._target(*self._args, **self._kwargs)
self._queue.put(retval)
except Exception: # pylint: disable=broad-except
-self.err = cast(Tuple[Type[BaseException], BaseException, "TracebackType"],
-sys.exc_info())
+self.err = T.cast(tuple[type[BaseException], BaseException, "TracebackType"],
+sys.exc_info())
assert self.err is not None
logger.debug("Error in thread (%s): %s", self._name,
self.err[1].with_traceback(self.err[2]))
@@ -81,7 +83,7 @@ class LongRunningTask(Thread):
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs
-def get_result(self) -> Any:
+def get_result(self) -> T.Any:
""" Return the result from the given task.
Returns

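A minimal usage sketch for the LongRunningTask above, assuming the standard Thread.start() lifecycle and the get_result()/err API shown in the diff (the worker function is invented):

def expensive(count: int) -> int:
    # Stand-in for a call that would otherwise block the Tk event loop
    return sum(range(count))

task = LongRunningTask(target=expensive, args=(10_000_000,))
task.start()
# ... keep the GUI responsive, polling until the thread finishes ...
task.join()
result = task.get_result()  # task.err holds (type, value, traceback) on failure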

@@ -1,17 +1,17 @@
 #!/usr/bin python3
 """ Utilities for working with images and videos """
+from __future__ import annotations
 import logging
 import re
 import subprocess
 import os
 import struct
 import sys
+import typing as T

 from ast import literal_eval
 from bisect import bisect
 from concurrent import futures
-from typing import Optional, TYPE_CHECKING, Union
 from zlib import crc32

 import cv2
@@ -24,7 +24,7 @@ from lib.multithreading import MultiThread
 from lib.queue_manager import queue_manager, QueueEmpty
 from lib.utils import convert_to_secs, FaceswapError, _video_extensions, get_image_paths

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.alignments import PNGHeaderDict

 logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
@@ -558,7 +558,7 @@ def update_existing_metadata(filename, metadata):
 def encode_image(image: np.ndarray,
                  extension: str,
-                 metadata: Optional["PNGHeaderDict"] = None) -> bytes:
+                 metadata: PNGHeaderDict | None = None) -> bytes:
     """ Encode an image.

     Parameters
@@ -1433,8 +1433,8 @@ class ImagesSaver(ImageIO):
     def _save(self,
               filename: str,
-              image: Union[bytes, np.ndarray],
-              sub_folder: Optional[str]) -> None:
+              image: bytes | np.ndarray,
+              sub_folder: str | None) -> None:
         """ Save a single image inside a ThreadPoolExecutor

         Parameters
@@ -1468,8 +1468,8 @@ class ImagesSaver(ImageIO):
     def save(self,
              filename: str,
-             image: Union[bytes, np.ndarray],
-             sub_folder: Optional[str] = None) -> None:
+             image: bytes | np.ndarray,
+             sub_folder: str | None = None) -> None:
         """ Save the given image in the background thread

         Ensure that :func:`close` is called once all save operations are complete.

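A hedged sketch of the saver API above. Only the encode_image() and save() signatures come from this diff; the ImagesSaver constructor argument and the paths are assumptions for illustration:

import numpy as np

image = np.zeros((256, 256, 3), dtype="uint8")
png_bytes = encode_image(image, ".png")  # metadata is optional per the new signature
saver = ImagesSaver("/tmp/faces")        # assumed constructor: an output folder
saver.save("face_0001.png", png_bytes)   # sub_folder defaults to None
saver.close()                            # flush the background save thread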

@@ -117,7 +117,7 @@ class ColorSpaceConvert():  # pylint:disable=too-few-public-methods
         self._xyz_multipliers = K.constant([116, 500, 200], dtype="float32")

     @classmethod
-    def _get_rgb_xyz_map(cls) -> T.Tuple[Tensor, Tensor]:
+    def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]:
         """ Obtain the mapping and inverse mapping for rgb to xyz color space conversion.

         Returns


@@ -11,7 +11,17 @@ import time
 import traceback
 from datetime import datetime
-from typing import Union
+
+
+# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
+def _patched_format(self, record):
+    """ Autograph tf-2.10 has a bug with the 3.10 version of logging.PercentStyle._format(). It is
+    non-critical but spits out warnings. This is the Python 3.9 version of the function and should
+    be removed once fixed """
+    return self._fmt % record.__dict__  # pylint:disable=protected-access
+
+
+setattr(logging.PercentStyle, "_format", _patched_format)

 class FaceswapLogger(logging.Logger):
@@ -76,11 +86,11 @@ class ColoredFormatter(logging.Formatter):
     def __init__(self, fmt: str, pad_newlines: bool = False, **kwargs) -> None:
         super().__init__(fmt, **kwargs)
         self._use_color = self._get_color_compatibility()
-        self._level_colors = dict(CRITICAL="\033[31m",  # red
-                                  ERROR="\033[31m",  # red
-                                  WARNING="\033[33m",  # yellow
-                                  INFO="\033[32m",  # green
-                                  VERBOSE="\033[34m")  # blue
+        self._level_colors = {"CRITICAL": "\033[31m",  # red
+                              "ERROR": "\033[31m",  # red
+                              "WARNING": "\033[33m",  # yellow
+                              "INFO": "\033[32m",  # green
+                              "VERBOSE": "\033[34m"}  # blue
         self._default_color = "\033[0m"
         self._newline_padding = self._get_newline_padding(pad_newlines, fmt)
@@ -412,7 +422,7 @@ def _file_handler(loglevel,
     return handler

-def _stream_handler(loglevel: int, is_gui: bool) -> Union[logging.StreamHandler, TqdmHandler]:
+def _stream_handler(loglevel: int, is_gui: bool) -> logging.StreamHandler | TqdmHandler:
     """ Add a stream handler for the current Faceswap session. The stream handler will only ever
     output at a maximum of VERBOSE level to avoid spamming the console.
View file

@@ -1,8 +1,6 @@
 """ Auto clipper for clipping gradients. """
-from typing import List
+import numpy as np
 import tensorflow as tf
-import tensorflow_probability as tfp


 class AutoClipper():  # pylint:disable=too-few-public-methods
@@ -22,12 +20,56 @@ class AutoClipper():  # pylint:disable=too-few-public-methods
     original paper: https://arxiv.org/abs/2007.14469
     """
     def __init__(self, clip_percentile: int, history_size: int = 10000):
-        self._clip_percentile = clip_percentile
+        self._clip_percentile = tf.cast(clip_percentile, tf.float64)
         self._grad_history = tf.Variable(tf.zeros(history_size), trainable=False)
         self._index = tf.Variable(0, trainable=False)
         self._history_size = history_size

-    def __call__(self, grads_and_vars: List[tf.Tensor]) -> List[tf.Tensor]:
+    def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor:
+        """ Compute the clip percentile of the gradient history
+
+        Parameters
+        ----------
+        grad_history: :class:`tensorflow.Tensor`
+            The gradient history to calculate the clip percentile for
+
+        Returns
+        -------
+        :class:`tensorflow.Tensor`
+            A rank(:attr:`clip_percentile`) `Tensor`
+
+        Notes
+        -----
+        Adapted from
+        https://github.com/tensorflow/probability/blob/r0.14/tensorflow_probability/python/stats/quantiles.py
+        to remove reliance on the full tensorflow_probability library
+        """
+        with tf.name_scope("percentile"):
+            frac_at_q_or_below = self._clip_percentile / 100.
+            sorted_hist = tf.sort(grad_history, axis=-1, direction="ASCENDING")
+            num = tf.cast(tf.shape(grad_history)[-1], tf.float64)
+            # get indices
+            indices = tf.round((num - 1) * frac_at_q_or_below)
+            indices = tf.clip_by_value(tf.cast(indices, tf.int32),
+                                       0,
+                                       tf.shape(grad_history)[-1] - 1)
+            gathered_hist = tf.gather(sorted_hist, indices, axis=-1)
+
+            # Propagate NaNs. Apparently tf.is_nan doesn't like other dtypes
+            nan_batch_members = tf.reduce_any(tf.math.is_nan(grad_history), axis=None)
+            right_rank_matched_shape = tf.pad(tf.shape(nan_batch_members),
+                                              paddings=[[0, tf.rank(self._clip_percentile)]],
+                                              constant_values=1)
+            nan_batch_members = tf.reshape(nan_batch_members, shape=right_rank_matched_shape)
+
+            nan = np.array(np.nan, gathered_hist.dtype.as_numpy_dtype)
+            gathered_hist = tf.where(nan_batch_members, nan, gathered_hist)
+            return gathered_hist
+
+    def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]:
         """ Call the AutoClip function.

         Parameters
@@ -40,8 +82,7 @@ class AutoClipper():  # pylint:disable=too-few-public-methods
         assign_idx = tf.math.mod(self._index, self._history_size)
         self._grad_history = self._grad_history[assign_idx].assign(total_norm)
         self._index = self._index.assign_add(1)
-        clip_value = tfp.stats.percentile(self._grad_history[: self._index],
-                                          q=self._clip_percentile)
+        clip_value = self._percentile(self._grad_history[: self._index])
         return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars]

     @classmethod

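A hedged sketch of how AutoClipper can be wired into training. gradient_transformers is the hook on TF 2.4-2.10 Keras optimizers for callables that receive and return (gradient, variable) pairs, which matches __call__ above; the learning rate is illustrative:

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(
    learning_rate=5e-5,
    gradient_transformers=[AutoClipper(clip_percentile=10)])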

@@ -1,9 +1,9 @@
 #!/usr/bin/env python3
 """ Custom Feature Map Loss Functions for faceswap.py """
+from __future__ import annotations
 from dataclasses import dataclass, field
 import logging
+import typing as T
-from typing import Any, Callable, Dict, Optional, List, Tuple

 # Ignore linting errors from Tensorflow's thoroughly broken import system
 import tensorflow as tf
@@ -17,6 +17,9 @@ import numpy as np
 from lib.model.nets import AlexNet, SqueezeNet
 from lib.utils import GetModel

+if T.TYPE_CHECKING:
+    from collections.abc import Callable
+
 logger = logging.getLogger(__name__)
@@ -39,10 +42,10 @@ class NetInfo:
     """
     model_id: int = 0
     model_name: str = ""
-    net: Optional[Callable] = None
-    init_kwargs: Dict[str, Any] = field(default_factory=dict)
+    net: Callable | None = None
+    init_kwargs: dict[str, T.Any] = field(default_factory=dict)
     needs_init: bool = True
-    outputs: List[Layer] = field(default_factory=list)
+    outputs: list[Layer] = field(default_factory=list)


 class _LPIPSTrunkNet():  # pylint:disable=too-few-public-methods
@@ -67,7 +70,7 @@ class _LPIPSTrunkNet():  # pylint:disable=too-few-public-methods
         logger.debug("Initialized: %s ", self.__class__.__name__)

     @property
-    def _nets(self) -> Dict[str, NetInfo]:
+    def _nets(self) -> dict[str, NetInfo]:
         """ :class:`NetInfo`: The Information about the requested net."""
         return {
             "alex": NetInfo(model_id=15,
@@ -176,7 +179,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet):  # pylint:disable=too-few-public-methods
         logger.debug("Initialized: %s", self.__class__.__name__)

     @property
-    def _nets(self) -> Dict[str, NetInfo]:
+    def _nets(self) -> dict[str, NetInfo]:
         """ :class:`NetInfo`: The Information about the requested net."""
         return {
             "alex": NetInfo(model_id=18,
@@ -186,7 +189,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet):  # pylint:disable=too-few-public-methods
             "vgg16": NetInfo(model_id=20,
                              model_name="vgg16_lpips_v1.h5")}

-    def _linear_block(self, net_output_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         """ Build a linear block for a trunk network output.

         Parameters
@@ -319,7 +322,7 @@ class LPIPSLoss():  # pylint:disable=too-few-public-methods
             tf.keras.mixed_precision.set_global_policy("mixed_float16")
         logger.debug("Initialized: %s", self.__class__.__name__)

-    def _process_diffs(self, inputs: List[tf.Tensor]) -> List[tf.Tensor]:
+    def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]:
         """ Perform processing on the Trunk Network outputs.

         If :attr:`use_ldip` is enabled, process the diff values through the linear network,

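A side note on the NetInfo dataclass above: its dict/list members use field(default_factory=...) because mutable defaults would be shared between instances. A standalone illustration:

from dataclasses import dataclass, field
import typing as T

@dataclass
class Demo:
    init_kwargs: dict[str, T.Any] = field(default_factory=dict)

first, second = Demo(), Demo()
first.init_kwargs["key"] = 1
assert second.init_kwargs == {}  # each instance gets its own dict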

@@ -1,10 +1,9 @@
 #!/usr/bin/env python3
 """ Custom Loss Functions for faceswap.py """
-from __future__ import absolute_import
+from __future__ import annotations

 import logging
-from typing import Callable, List, Tuple
+import typing as T

 import numpy as np
 import tensorflow as tf
@@ -13,6 +12,9 @@ import tensorflow as tf
 from tensorflow.python.keras.engine import compile_utils  # pylint:disable=no-name-in-module
 from tensorflow.keras import backend as K  # pylint:disable=import-error

+if T.TYPE_CHECKING:
+    from collections.abc import Callable
+
 logger = logging.getLogger(__name__)
@@ -61,7 +63,7 @@ class FocalFrequencyLoss():  # pylint:disable=too-few-public-methods
         self._ave_spectrum = ave_spectrum
         self._log_matrix = log_matrix
         self._batch_matrix = batch_matrix
-        self._dims: Tuple[int, int] = (0, 0)
+        self._dims: tuple[int, int] = (0, 0)

     def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
         """ Crop the incoming batch of images into patches as defined by :attr:`_patch_factor.
@@ -470,7 +472,7 @@ class LaplacianPyramidLoss():  # pylint:disable=too-few-public-methods
         retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
         return retval

-    def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]:
+    def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> list[tf.Tensor]:
         """ Obtain the Laplacian Pyramid.

         Parameters
@@ -564,9 +566,9 @@ class LossWrapper(tf.keras.losses.Loss):
     def __init__(self) -> None:
         logger.debug("Initializing: %s", self.__class__.__name__)
         super().__init__(name="LossWrapper")
-        self._loss_functions: List[compile_utils.LossesContainer] = []
-        self._loss_weights: List[float] = []
-        self._mask_channels: List[int] = []
+        self._loss_functions: list[compile_utils.LossesContainer] = []
+        self._loss_weights: list[float] = []
+        self._mask_channels: list[int] = []
         logger.debug("Initialized: %s", self.__class__.__name__)

     def add_loss(self,
@@ -628,7 +630,7 @@ class LossWrapper(tf.keras.losses.Loss):
                     y_true: tf.Tensor,
                     y_pred: tf.Tensor,
                     mask_channel: int,
-                    mask_prop: float = 1.0) -> Tuple[tf.Tensor, tf.Tensor]:
+                    mask_prop: float = 1.0) -> tuple[tf.Tensor, tf.Tensor]:
         """ Apply the mask to the input y_true and y_pred. If a mask is not required then
         return the unmasked inputs.

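A hedged illustration of the masking convention that LossWrapper._apply_mask implies: the guide mask travels in the trailing channel(s) of y_true, with the face in channels 0-2 (shapes invented):

import numpy as np

y_true = np.zeros((4, 64, 64, 4), dtype="float32")  # batch of faces plus one mask channel
face = y_true[..., :3]
mask = y_true[..., 3:4]  # kept 4-D so it broadcasts against the face channels
assert face.shape == (4, 64, 64, 3) and mask.shape == (4, 64, 64, 1)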

@@ -2,9 +2,7 @@
 """ TF Keras implementation of Perceptual Loss Functions for faceswap.py """
 import logging
-import sys
-
-from typing import Dict, Optional, Tuple
+import typing as T

 import numpy as np
 import tensorflow as tf
@@ -14,11 +12,6 @@ from tensorflow.keras import backend as K  # pylint:disable=import-error

 from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad

-if sys.version_info < (3, 8):
-    from typing_extensions import Literal
-else:
-    from typing import Literal
-
 logger = logging.getLogger(__name__)
@@ -101,7 +94,7 @@ class DSSIMObjective():  # pylint:disable=too-few-public-methods
         """
         return K.depthwise_conv2d(image, kernel, strides=(1, 1), padding="valid")

-    def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         """ Obtain the structural similarity between a batch of true and predicted images.

         Parameters
@@ -330,8 +323,8 @@ class LDRFLIPLoss():  # pylint:disable=too-few-public-methods
                  lower_threshold_exponent: float = 0.4,
                  upper_threshold_exponent: float = 0.95,
                  epsilon: float = 1e-15,
-                 pixels_per_degree: Optional[float] = None,
-                 color_order: Literal["bgr", "rgb"] = "bgr") -> None:
+                 pixels_per_degree: float | None = None,
+                 color_order: T.Literal["bgr", "rgb"] = "bgr") -> None:
         logger.debug("Initializing: %s (computed_distance_exponent '%s', feature_exponent: %s, "
                      "lower_threshold_exponent: %s, upper_threshold_exponent: %s, epsilon: %s, "
                      "pixels_per_degree: %s, color_order: %s)", self.__class__.__name__,
@@ -525,7 +518,7 @@ class _SpatialFilters():  # pylint:disable=too-few-public-methods
         self._spatial_filters, self._radius = self._generate_spatial_filters()
         self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb")

-    def _generate_spatial_filters(self) -> Tuple[tf.Tensor, int]:
+    def _generate_spatial_filters(self) -> tuple[tf.Tensor, int]:
         """ Generates spatial contrast sensitivity filters with width depending on the number of
         pixels per degree of visual angle of the observer for channels "A", "RG" and "BY"
@@ -559,7 +552,7 @@ class _SpatialFilters():  # pylint:disable=too-few-public-methods
                            b1_rg: float,
                            b2_rg: float,
                            b1_by: float,
-                           b2_by: float) -> Tuple[np.ndarray, int]:
+                           b2_by: float) -> tuple[np.ndarray, int]:
         """ TODO docstring """
         max_scale_parameter = max([b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by])
         delta_x = 1.0 / self._pixels_per_degree
@@ -570,7 +563,7 @@ class _SpatialFilters():  # pylint:disable=too-few-public-methods
         return domain, radius

     @classmethod
-    def _generate_weights(cls, channel: Dict[str, float], domain: np.ndarray) -> tf.Tensor:
+    def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> tf.Tensor:
         """ TODO docstring """
         a_1, b_1, a_2, b_2 = channel["a1"], channel["b1"], channel["a2"], channel["b2"]
         grad = (a_1 * np.sqrt(np.pi / b_1) * np.exp(-np.pi ** 2 * domain / b_1) +
@@ -694,7 +687,7 @@ class MSSIMLoss():  # pylint:disable=too-few-public-methods
                  filter_size: int = 11,
                  filter_sigma: float = 1.5,
                  max_value: float = 1.0,
-                 power_factors: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+                 power_factors: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
                  ) -> None:
         self.filter_size = filter_size
         self.filter_sigma = filter_sigma


@@ -31,7 +31,7 @@ class _net():  # pylint:disable=too-few-public-methods
         The input shape for the model. Default: ``None``
     """
     def __init__(self,
-                 input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
+                 input_shape: tuple[int, int, int] | None = None) -> None:
         logger.debug("Initializing: %s (input_shape: %s)", self.__class__.__name__, input_shape)
         self._input_shape = (None, None, 3) if input_shape is None else input_shape
         assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, (
@@ -56,7 +56,7 @@ class AlexNet(_net):  # pylint:disable=too-few-public-methods
     input_shape, Tuple, optional
         The input shape for the model. Default: ``None``
     """
-    def __init__(self, input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
+    def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None:
         super().__init__(input_shape)
         self._feature_indices = [0, 3, 6, 8, 10]  # For naming equivalent to PyTorch
         self._filters = [64, 192, 384, 256, 256]  # Filters at each block
@@ -108,7 +108,7 @@ class AlexNet(_net):  # pylint:disable=too-few-public-methods
                              name=name)(var_x)
         return var_x

-    def __call__(self) -> Model:
+    def __call__(self) -> tf.keras.models.Model:
         """ Create the AlexNet Model

         Returns
@@ -189,7 +189,7 @@ class SqueezeNet(_net):  # pylint:disable=too-few-public-methods
                                name=f"{name}.expand3x3")(squeezed)
         return layers.Concatenate(axis=-1, name=name)([expand1, expand3])

-    def __call__(self) -> Model:
+    def __call__(self) -> tf.keras.models.Model:
         """ Create the SqueezeNet Model

         Returns

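The nets above are small factories: construct with an optional input shape, then call to build the model (the diff confirms __call__ returns tf.keras.models.Model). A hedged sketch, assuming no further arguments are required:

model = AlexNet(input_shape=(256, 256, 3))()
model.summary()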

@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 _CONFIG: dict = {}
-_NAMES: T.Dict[str, int] = {}
+_NAMES: dict[str, int] = {}


 def set_config(configuration: dict) -> None:
@@ -189,7 +189,7 @@ class Conv2DOutput():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int]],
+                 kernel_size: int | tuple[int],
                  activation: str = "sigmoid",
                  padding: str = "same", **kwargs) -> None:
         self._name = kwargs.pop("name") if "name" in kwargs else _get_name(
@@ -265,11 +265,11 @@ class Conv2DBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
-                 strides: T.Union[int, T.Tuple[int, int]] = 2,
+                 kernel_size: int | tuple[int, int] = 5,
+                 strides: int | tuple[int, int] = 2,
                  padding: str = "same",
-                 normalization: T.Optional[str] = None,
-                 activation: T.Optional[str] = "leakyrelu",
+                 normalization: str | None = None,
+                 activation: str | None = "leakyrelu",
                  use_depthwise: bool = False,
                  relu_alpha: float = 0.1,
                  **kwargs) -> None:
@@ -362,8 +362,8 @@ class SeparableConv2DBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
-                 strides: T.Union[int, T.Tuple[int, int]] = 2, **kwargs) -> None:
+                 kernel_size: int | tuple[int, int] = 5,
+                 strides: int | tuple[int, int] = 2, **kwargs) -> None:
         self._name = _get_name(f"separableconv2d_{filters}")
         logger.debug("name: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s)",
                      self._name, filters, kernel_size, strides, kwargs)
@@ -434,11 +434,11 @@ class UpscaleBlock():  # pylint:disable=too-few-public-methods
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
+                 kernel_size: int | tuple[int, int] = 3,
                  padding: str = "same",
                  scale_factor: int = 2,
-                 normalization: T.Optional[str] = None,
-                 activation: T.Optional[str] = "leakyrelu",
+                 normalization: str | None = None,
+                 activation: str | None = "leakyrelu",
                  **kwargs) -> None:
         self._name = _get_name(f"upscale_{filters}")
         logger.debug("name: %s. filters: %s, kernel_size: %s, padding: %s, scale_factor: %s, "
@@ -521,9 +521,9 @@ class Upscale2xBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
+                 kernel_size: int | tuple[int, int] = 3,
                  padding: str = "same",
-                 activation: T.Optional[str] = "leakyrelu",
+                 activation: str | None = "leakyrelu",
                  interpolation: str = "bilinear",
                  sr_ratio: float = 0.5,
                  scale_factor: int = 2,
@@ -615,9 +615,9 @@ class UpscaleResizeImagesBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
+                 kernel_size: int | tuple[int, int] = 3,
                  padding: str = "same",
-                 activation: T.Optional[str] = "leakyrelu",
+                 activation: str | None = "leakyrelu",
                  scale_factor: int = 2,
                  interpolation: str = "bilinear") -> None:
         self._name = _get_name(f"upscale_ri_{filters}")
@@ -700,9 +700,9 @@ class UpscaleDNYBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
+                 kernel_size: int | tuple[int, int] = 3,
                  padding: str = "same",
-                 activation: T.Optional[str] = "leakyrelu",
+                 activation: str | None = "leakyrelu",
                  size: int = 2,
                  interpolation: str = "bilinear",
                  **kwargs) -> None:
@@ -757,7 +757,7 @@ class ResidualBlock():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  filters: int,
-                 kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
+                 kernel_size: int | tuple[int, int] = 3,
                  padding: str = "same",
                  **kwargs) -> None:
         self._name = _get_name(f"residual_{filters}")

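The nn_blocks classes follow the Keras layer-like callable pattern: instantiate with hyper-parameters, then call on a tensor. A hedged sketch, assuming the blocks accept and return tensors as the surrounding code implies:

import tensorflow as tf

inputs = tf.keras.layers.Input((64, 64, 3))
var_x = Conv2DBlock(128, kernel_size=5, strides=2)(inputs)  # downscale to 32x32
var_x = UpscaleBlock(64, scale_factor=2)(var_x)             # back up to 64x64
model = tf.keras.models.Model(inputs, var_x)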

@@ -1,9 +1,9 @@
 #!/usr/bin python3
 """ Settings manager for Keras Backend """
+from __future__ import annotations

 from contextlib import nullcontext
 import logging
-from typing import Callable, ContextManager, List, Optional, Union
+import typing as T

 import numpy as np
 import tensorflow as tf
@@ -14,6 +14,9 @@ from tensorflow.keras.models import load_model as k_load_model, Model  # noqa:E5

 from lib.utils import get_backend

+if T.TYPE_CHECKING:
+    from collections.abc import Callable
+
 logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
@@ -52,9 +55,9 @@ class KSession():
     def __init__(self,
                  name: str,
                  model_path: str,
-                 model_kwargs: Optional[dict] = None,
+                 model_kwargs: dict | None = None,
                  allow_growth: bool = False,
-                 exclude_gpus: Optional[List[int]] = None,
+                 exclude_gpus: list[int] | None = None,
                  cpu_mode: bool = False) -> None:
         logger.trace("Initializing: %s (name: %s, model_path: %s, "  # type:ignore
                      "model_kwargs: %s, allow_growth: %s, exclude_gpus: %s, cpu_mode: %s)",
@@ -67,12 +70,12 @@ class KSession():
                      cpu_mode)
         self._model_path = model_path
         self._model_kwargs = {} if not model_kwargs else model_kwargs
-        self._model: Optional[Model] = None
+        self._model: Model | None = None
         logger.trace("Initialized: %s", self.__class__.__name__,)  # type:ignore

     def predict(self,
-                feed: Union[List[np.ndarray], np.ndarray],
-                batch_size: Optional[int] = None) -> Union[List[np.ndarray], np.ndarray]:
+                feed: list[np.ndarray] | np.ndarray,
+                batch_size: int | None = None) -> list[np.ndarray] | np.ndarray:
         """ Get predictions from the model.

         This method is a wrapper for :func:`keras.predict()` function. For Tensorflow backends
@@ -98,7 +101,7 @@ class KSession():
     def _set_session(self,
                      allow_growth: bool,
                      exclude_gpus: list,
-                     cpu_mode: bool) -> ContextManager:
+                     cpu_mode: bool) -> T.ContextManager:
         """ Sets the backend session options.

         For CPU backends, this hides any GPUs from Tensorflow.

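A hedged sketch of KSession; the name and model path are placeholders, and the model is assumed to have been loaded through the class's own loader before predict() is called:

import numpy as np

session = KSession("detector", "/path/to/model.h5", allow_growth=True)
feed = np.zeros((1, 256, 256, 3), dtype="float32")
prediction = session.predict(feed, batch_size=8)  # wraps keras predict per the docstring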

@@ -1,19 +1,22 @@
 #!/usr/bin/env python3
 """ Multithreading/processing utils for faceswap """
+from __future__ import annotations
 import logging
+import typing as T
 from multiprocessing import cpu_count
 import queue as Queue
 import sys
 import threading
 from types import TracebackType
-from typing import Any, Callable, Dict, Generator, List, Tuple, Type, Optional, Set, Union
+
+if T.TYPE_CHECKING:
+    from collections.abc import Callable, Generator

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

-_ErrorType = Optional[Union[Tuple[Type[BaseException], BaseException, TracebackType],
-                            Tuple[Any, Any, Any]]]
-_THREAD_NAMES: Set[str] = set()
+_ErrorType: T.TypeAlias = (tuple[type[BaseException], BaseException, TracebackType] |
+                           tuple[T.Any, T.Any, T.Any] | None)
+_THREAD_NAMES: set[str] = set()


 def total_cpus():
@@ -62,17 +65,17 @@ class FSThread(threading.Thread):
         keyword arguments for the target invocation. Default: {}.
     """
     _target: Callable
-    _args: Tuple
-    _kwargs: Dict[str, Any]
+    _args: tuple
+    _kwargs: dict[str, T.Any]
     _name: str

     def __init__(self,
-                 target: Optional[Callable] = None,
-                 name: Optional[str] = None,
-                 args: Tuple = (),
-                 kwargs: Optional[Dict[str, Any]] = None,
+                 target: Callable | None = None,
+                 name: str | None = None,
+                 args: tuple = (),
+                 kwargs: dict[str, T.Any] | None = None,
                  *,
-                 daemon: Optional[bool] = None) -> None:
+                 daemon: bool | None = None) -> None:
         super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
         self.err: _ErrorType = None
@@ -124,7 +127,7 @@ class MultiThread():
                  target: Callable,
                  *args,
                  thread_count: int = 1,
-                 name: Optional[str] = None,
+                 name: str | None = None,
                  **kwargs) -> None:
         self._name = _get_name(name if name else target.__name__)
         logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
@@ -132,7 +135,7 @@ class MultiThread():
         logger.trace("args: %s, kwargs: %s", args, kwargs)  # type:ignore
         self.daemon = True
         self._thread_count = thread_count
-        self._threads: List[FSThread] = []
+        self._threads: list[FSThread] = []
         self._target = target
         self._args = args
         self._kwargs = kwargs
@@ -144,7 +147,7 @@ class MultiThread():
         return any(thread.err for thread in self._threads)

     @property
-    def errors(self) -> List[_ErrorType]:
+    def errors(self) -> list[_ErrorType]:
         """ list: List of thread error values """
         return [thread.err for thread in self._threads if thread.err]
@@ -253,9 +256,9 @@ class BackgroundGenerator(MultiThread):
     def __init__(self,
                  generator: Callable,
                  prefetch: int = 1,
-                 name: Optional[str] = None,
-                 args: Optional[Tuple] = None,
-                 kwargs: Optional[Dict[str, Any]] = None) -> None:
+                 name: str | None = None,
+                 args: tuple | None = None,
+                 kwargs: dict[str, T.Any] | None = None) -> None:
         super().__init__(name=name, target=self._run)
         self.queue: Queue.Queue = Queue.Queue(prefetch)
         self.generator = generator

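A hedged usage sketch for MultiThread, assuming the usual start()/join() lifecycle around the has_error/errors properties shown above (the worker is invented):

def worker(side: str) -> None:
    print(f"processing side {side}")

threads = MultiThread(worker, "a", thread_count=2, name="demo")
threads.start()
threads.join()
assert not threads.has_error  # per-thread exceptions are collected in threads.errors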

@@ -6,7 +6,6 @@
 import logging
 import threading
-from typing import Dict

 from queue import Queue, Empty as QueueEmpty  # pylint: disable=unused-import; # noqa
 from time import sleep
@@ -45,7 +44,7 @@ class _QueueManager():
         logger.debug("Initializing %s", self.__class__.__name__)

         self.shutdown = threading.Event()
-        self.queues: Dict[str, EventQueue] = {}
+        self.queues: dict[str, EventQueue] = {}
         logger.debug("Initialized %s", self.__class__.__name__)

     def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str:


@@ -6,8 +6,8 @@ import locale
 import os
 import platform
 import sys

 from subprocess import PIPE, Popen
-from typing import List, Optional

 import psutil
@@ -21,14 +21,14 @@ class _SysInfo():  # pylint:disable=too-few-public-methods
     def __init__(self) -> None:
         self._state_file = _State().state_file
         self._configs = _Configs().configs
-        self._system = dict(platform=platform.platform(),
-                            system=platform.system().lower(),
-                            machine=platform.machine(),
-                            release=platform.release(),
-                            processor=platform.processor(),
-                            cpu_count=os.cpu_count())
-        self._python = dict(implementation=platform.python_implementation(),
-                            version=platform.python_version())
+        self._system = {"platform": platform.platform(),
+                        "system": platform.system().lower(),
+                        "machine": platform.machine(),
+                        "release": platform.release(),
+                        "processor": platform.processor(),
+                        "cpu_count": os.cpu_count()}
+        self._python = {"implementation": platform.python_implementation(),
+                        "version": platform.python_version()}
         self._gpu = self._get_gpu_info()
         self._cuda_check = CudaCheck()
@@ -66,7 +66,7 @@ class _SysInfo():  # pylint:disable=too-few-public-methods
                       (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix))
         else:
             prefix = os.path.dirname(sys.prefix)
-            retval = (os.path.basename(prefix) == "envs")
+            retval = os.path.basename(prefix) == "envs"
         return retval

     @property
@@ -295,7 +295,7 @@ class _Configs():  # pylint:disable=too-few-public-methods
         except FileNotFoundError:
             return ""

-    def _parse_configs(self, config_files: List[str]) -> str:
+    def _parse_configs(self, config_files: list[str]) -> str:
         """ Parse the given list of config files into a human readable format.

         Parameters
@@ -399,7 +399,7 @@ class _State():  # pylint:disable=too-few-public-methods
         return len(sys.argv) > 1 and sys.argv[1].lower() == "train"

     @staticmethod
-    def _get_arg(*args: str) -> Optional[str]:
+    def _get_arg(*args: str) -> str | None:
         """ Obtain the value for a given command line option from sys.argv.

         Returns


@@ -1,16 +1,16 @@
 #!/usr/bin/env python3
 """ Package for handling alignments files, detected faces and aligned faces along with their
 associated objects. """
+from __future__ import annotations

-from typing import Type, TYPE_CHECKING
+import typing as T

 from .augmentation import ImageAugmentation
 from .generator import PreviewDataGenerator, TrainingDataGenerator
 from .preview_cv import PreviewBuffer, TriggerType

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from .preview_cv import PreviewBase
-    Preview: Type[PreviewBase]
+    Preview: type[PreviewBase]

 try:
     from .preview_tk import PreviewTk as Preview


@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 """ Processes the augmentation of images for feeding into a Faceswap model. """
+from __future__ import annotations

 from dataclasses import dataclass
 import logging
-from typing import Dict, Tuple, TYPE_CHECKING
+import typing as T

 import cv2
 import numexpr as ne
@@ -11,7 +12,7 @@ from scipy.interpolate import griddata

 from lib.image import batch_convert_color

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.config import ConfigValueType

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -56,7 +57,7 @@ class AugConstants:
     transform_zoom: float
     transform_shift: float
     warp_maps: np.ndarray
-    warp_pad: Tuple[int, int]
+    warp_pad: tuple[int, int]
     warp_slices: slice
     warp_lm_edge_anchors: np.ndarray
     warp_lm_grids: np.ndarray
@@ -79,7 +80,7 @@ class ImageAugmentation():
     def __init__(self,
                  batchsize: int,
                  processing_size: int,
-                 config: Dict[str, "ConfigValueType"]) -> None:
+                 config: dict[str, ConfigValueType]) -> None:
         logger.debug("Initializing %s: (batchsize: %s, processing_size: %s, "
                      "config: %s)",
                      self.__class__.__name__, batchsize, processing_size, config)
@@ -332,7 +333,7 @@ class ImageAugmentation():
         slices = self._constants.warp_slices
         rands = np.random.normal(size=(self._batchsize, 2, 5, 5),
                                  scale=self._warp_scale).astype("float32")
-        batch_maps = ne.evaluate("m + r", local_dict=dict(m=self._constants.warp_maps, r=rands))
+        batch_maps = ne.evaluate("m + r", local_dict={"m": self._constants.warp_maps, "r": rands})
         batch_interp = np.array([[cv2.resize(map_, self._constants.warp_pad)[slices, slices]
                                   for map_ in maps]
                                  for maps in batch_maps])

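The warp step above leans on numexpr to add random jitter across the whole batch in a single vectorised expression. A standalone illustration of the ne.evaluate() call:

import numexpr as ne
import numpy as np

maps = np.zeros((4, 2, 5, 5), dtype="float32")  # a batch of warp maps
rands = np.random.normal(scale=5.0, size=maps.shape).astype("float32")
jittered = ne.evaluate("m + r", local_dict={"m": maps, "r": rands})
assert jittered.shape == maps.shape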

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 """ Holds the data cache for training data generators """
+from __future__ import annotations
 import logging
 import os
-import sys
+import typing as T
 from threading import Lock
-from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING

 import cv2
 import numpy as np
@@ -16,25 +16,19 @@ from lib.align.aligned_face import CenteringType
 from lib.image import read_image_batch, read_image_meta_batch
 from lib.utils import FaceswapError

-if sys.version_info < (3, 8):
-    from typing_extensions import get_args, Literal
-else:
-    from typing import get_args, Literal
-
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict
     from lib.config import ConfigValueType

 logger = logging.getLogger(__name__)

-_FACE_CACHES: Dict[str, "_Cache"] = {}
+_FACE_CACHES: dict[str, "_Cache"] = {}


-def get_cache(side: Literal["a", "b"],
-              filenames: Optional[List[str]] = None,
-              config: Optional[Dict[str, "ConfigValueType"]] = None,
-              size: Optional[int] = None,
-              coverage_ratio: Optional[float] = None) -> "_Cache":
+def get_cache(side: T.Literal["a", "b"],
+              filenames: list[str] | None = None,
+              config: dict[str, ConfigValueType] | None = None,
+              size: int | None = None,
+              coverage_ratio: float | None = None) -> "_Cache":
     """ Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then
     create it.
@@ -120,24 +114,24 @@ class _Cache():
         The coverage ratio that the model is using.
     """
     def __init__(self,
-                 filenames: List[str],
-                 config: Dict[str, "ConfigValueType"],
+                 filenames: list[str],
+                 config: dict[str, ConfigValueType],
                  size: int,
                  coverage_ratio: float) -> None:
         logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)",
                      self.__class__.__name__, len(filenames), size, coverage_ratio)
         self._lock = Lock()
-        self._cache_info = dict(cache_full=False, has_reset=False)
-        self._partially_loaded: List[str] = []
+        self._cache_info = {"cache_full": False, "has_reset": False}
+        self._partially_loaded: list[str] = []

         self._image_count = len(filenames)
-        self._cache: Dict[str, DetectedFace] = {}
-        self._aligned_landmarks: Dict[str, np.ndarray] = {}
+        self._cache: dict[str, DetectedFace] = {}
+        self._aligned_landmarks: dict[str, np.ndarray] = {}
         self._extract_version = 0.0
         self._size = size

-        assert config["centering"] in get_args(CenteringType)
-        self._centering: CenteringType = cast(CenteringType, config["centering"])
+        assert config["centering"] in T.get_args(CenteringType)
+        self._centering: CenteringType = T.cast(CenteringType, config["centering"])
         self._config = config
         self._coverage_ratio = coverage_ratio
@@ -153,7 +147,7 @@ class _Cache():
         return self._cache_info["cache_full"]

     @property
-    def aligned_landmarks(self) -> Dict[str, np.ndarray]:
+    def aligned_landmarks(self) -> dict[str, np.ndarray]:
         """ dict: The filename as key, aligned landmarks as value. """
         # Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate
         # all of the aligned landmarks for the entire cache.
@@ -185,7 +179,7 @@ class _Cache():
             self._cache_info["has_reset"] = False
         return retval

-    def get_items(self, filenames: List[str]) -> List[DetectedFace]:
+    def get_items(self, filenames: list[str]) -> list[DetectedFace]:
         """ Obtain the cached items for a list of filenames. The returned list is in the same order
         as the provided filenames.
@@ -202,7 +196,7 @@ class _Cache():
         """
         return [self._cache[os.path.basename(filename)] for filename in filenames]

-    def cache_metadata(self, filenames: List[str]) -> np.ndarray:
+    def cache_metadata(self, filenames: list[str]) -> np.ndarray:
         """ Obtain the batch with metadata for items that need caching and cache DetectedFace
         objects to :attr:`_cache`.
@@ -267,7 +261,7 @@ class _Cache():
         return batch

-    def pre_fill(self, filenames: List[str], side: Literal["a", "b"]) -> None:
+    def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None:
         """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs
         access to the other side's alignments.
@@ -294,7 +288,7 @@ class _Cache():
                 self._cache[key] = detected_face
                 self._partially_loaded.append(key)

-    def _validate_version(self, png_meta: "PNGHeaderDict", filename: str) -> None:
+    def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None:
         """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces.

         Parameters
@@ -350,7 +344,7 @@ class _Cache():
     def _load_detected_face(self,
                             filename: str,
-                            alignments: "PNGHeaderAlignmentsDict") -> DetectedFace:
+                            alignments: PNGHeaderAlignmentsDict) -> DetectedFace:
         """ Load a :class:`DetectedFace` object and load its associated `aligned` property.

         Parameters
@@ -387,13 +381,13 @@ class _Cache():
             The detected face object that holds the masks
         """
         masks = [(self._get_face_mask(filename, detected_face))]
-        for area in get_args(Literal["eye", "mouth"]):
+        for area in T.get_args(T.Literal["eye", "mouth"]):
             masks.append(self._get_localized_mask(filename, detected_face, area))

         detected_face.store_training_masks(masks, delete_masks=True)
         logger.trace("Stored masks for filename: %s)", filename)  # type: ignore

-    def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> Optional[np.ndarray]:
+    def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None:
         """ Obtain the training sized face mask from the :class:`DetectedFace` for the requested
         mask type.
@@ -448,7 +442,7 @@ class _Cache():
     def _get_localized_mask(self,
                             filename: str,
                             detected_face: DetectedFace,
-                            area: Literal["eye", "mouth"]) -> Optional[np.ndarray]:
+                            area: T.Literal["eye", "mouth"]) -> np.ndarray | None:
         """ Obtain a localized mask for the given area if it is required for training.

         Parameters
@@ -486,7 +480,7 @@ class RingBuffer():  # pylint: disable=too-few-public-methods
     """
     def __init__(self,
                  batch_size: int,
-                 image_shape: Tuple[int, int, int],
+                 image_shape: tuple[int, int, int],
                  buffer_size: int = 2,
                  dtype: str = "uint8") -> None:
         logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, "

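A hedged sketch of the get_cache() contract described above: the first call for a side must supply the construction arguments, later calls can retrieve the same object with just the side (the config shown is an illustrative subset of the real training config):

config = {"centering": "face"}
cache = get_cache("a",
                  filenames=["/faces/a/0001.png"],
                  config=config,
                  size=256,
                  coverage_ratio=0.875)
assert get_cache("a") is cache  # subsequent calls return the cached instance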

@ -1,13 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Handles Data Augmentation for feeding Faceswap Models """ """ Handles Data Augmentation for feeding Faceswap Models """
from __future__ import annotations
import logging import logging
import os import os
import sys import typing as T
from concurrent import futures
from concurrent import futures
from random import shuffle, choice from random import shuffle, choice
from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
@ -21,18 +20,14 @@ from lib.utils import FaceswapError
from . import ImageAugmentation from . import ImageAugmentation
from .cache import get_cache, RingBuffer from .cache import get_cache, RingBuffer
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import get_args, Literal from collections.abc import Generator
else:
from typing import get_args, Literal
if TYPE_CHECKING:
from lib.config import ConfigValueType from lib.config import ConfigValueType
from plugins.train.model._base import ModelBase from plugins.train.model._base import ModelBase
from .cache import _Cache from .cache import _Cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BatchType = Tuple[np.ndarray, List[np.ndarray]] BatchType = tuple[np.ndarray, list[np.ndarray]]
class DataGenerator(): class DataGenerator():
@ -57,10 +52,10 @@ class DataGenerator():
objects of this size from the iterator. objects of this size from the iterator.
""" """
def __init__(self, def __init__(self,
config: Dict[str, "ConfigValueType"], config: dict[str, ConfigValueType],
model: "ModelBase", model: ModelBase,
side: Literal["a", "b"], side: T.Literal["a", "b"],
images: List[str], images: list[str],
batch_size: int) -> None: batch_size: int) -> None:
logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore
"batch_size: %s, config: %s)", self.__class__.__name__, model.name, side, "batch_size: %s, config: %s)", self.__class__.__name__, model.name, side,
@ -83,11 +78,11 @@ class DataGenerator():
self._buffer = RingBuffer(batch_size, self._buffer = RingBuffer(batch_size,
(self._process_size, self._process_size, self._total_channels), (self._process_size, self._process_size, self._total_channels),
dtype="uint8") dtype="uint8")
self._face_cache: "_Cache" = get_cache(side, self._face_cache: _Cache = get_cache(side,
filenames=images, filenames=images,
config=self._config, config=self._config,
size=self._process_size, size=self._process_size,
coverage_ratio=self._coverage_ratio) coverage_ratio=self._coverage_ratio)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@property @property
@ -100,12 +95,12 @@ class DataGenerator():
channels += 1 channels += 1
mults = [area for area in ["eye", "mouth"] mults = [area for area in ["eye", "mouth"]
if cast(int, self._config[f"{area}_multiplier"]) > 1] if T.cast(int, self._config[f"{area}_multiplier"]) > 1]
if self._config["penalized_mask_loss"] and mults: if self._config["penalized_mask_loss"] and mults:
channels += len(mults) channels += len(mults)
return channels return channels
def _get_output_sizes(self, model: "ModelBase") -> List[int]: def _get_output_sizes(self, model: ModelBase) -> list[int]:
""" Obtain the size of each output tensor for the model. """ Obtain the size of each output tensor for the model.
Parameters Parameters
@ -222,7 +217,7 @@ class DataGenerator():
retval = self._process_batch(img_paths) retval = self._process_batch(img_paths)
yield retval yield retval
def _get_images_with_meta(self, filenames: List[str]) -> Tuple[np.ndarray, List[DetectedFace]]: def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]:
""" Obtain the raw face images with associated :class:`DetectedFace` objects for this """ Obtain the raw face images with associated :class:`DetectedFace` objects for this
batch. batch.
@ -253,9 +248,9 @@ class DataGenerator():
return raw_faces, detected_faces return raw_faces, detected_faces
def _crop_to_coverage(self, def _crop_to_coverage(self,
filenames: List[str], filenames: list[str],
images: np.ndarray, images: np.ndarray,
detected_faces: List[DetectedFace], detected_faces: list[DetectedFace],
batch: np.ndarray) -> None: batch: np.ndarray) -> None:
""" Crops the training image out of the full extract image based on the centering and """ Crops the training image out of the full extract image based on the centering and
coverage used in the user's configuration settings. coverage used in the user's configuration settings.
@ -286,7 +281,7 @@ class DataGenerator():
for future in futures.as_completed(proc): for future in futures.as_completed(proc):
batch[proc[future], ..., :3] = future.result() batch[proc[future], ..., :3] = future.result()
def _apply_mask(self, detected_faces: List[DetectedFace], batch: np.ndarray) -> None: def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None:
""" Applies the masks to the 4th channel of the batch. """ Applies the masks to the 4th channel of the batch.
If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1 If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1
@ -312,7 +307,7 @@ class DataGenerator():
logger.trace("side: %s, masks: %s, batch: %s", # type: ignore logger.trace("side: %s, masks: %s, batch: %s", # type: ignore
self._side, masks.shape, batch.shape) self._side, masks.shape, batch.shape)
def _process_batch(self, filenames: List[str]) -> BatchType: def _process_batch(self, filenames: list[str]) -> BatchType:
""" Prepares data for feeding through subclassed methods. """ Prepares data for feeding through subclassed methods.
If this is the first time a face has been loaded, then its metadata is extracted from the If this is the first time a face has been loaded, then its metadata is extracted from the
@ -345,9 +340,9 @@ class DataGenerator():
return feed, targets return feed, targets
def process_batch(self, def process_batch(self,
filenames: List[str], filenames: list[str],
images: np.ndarray, images: np.ndarray,
detected_faces: List[DetectedFace], detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType: batch: np.ndarray) -> BatchType:
""" Override for processing the batch for the current generator. """ Override for processing the batch for the current generator.
@ -391,7 +386,7 @@ class DataGenerator():
The input uint8 array The input uint8 array
""" """
return ne.evaluate("x / c", return ne.evaluate("x / c",
local_dict=dict(x=in_array, c=np.float32(255)), local_dict={"x": in_array, "c": np.float32(255)},
casting="unsafe") casting="unsafe")
@ -417,10 +412,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
objects of this size from the iterator. objects of this size from the iterator.
""" """
def __init__(self, def __init__(self,
config: Dict[str, "ConfigValueType"], config: dict[str, ConfigValueType],
model: "ModelBase", model: ModelBase,
side: Literal["a", "b"], side: T.Literal["a", "b"],
images: List[str], images: list[str],
batch_size: int) -> None: batch_size: int) -> None:
super().__init__(config, model, side, images, batch_size) super().__init__(config, model, side, images, batch_size)
self._augment_color = not model.command_line_arguments.no_augment_color self._augment_color = not model.command_line_arguments.no_augment_color
@ -434,10 +429,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
self._processing = ImageAugmentation(batch_size, self._processing = ImageAugmentation(batch_size,
self._process_size, self._process_size,
self._config) self._config)
self._nearest_landmarks: Dict[str, Tuple[str, ...]] = {} self._nearest_landmarks: dict[str, tuple[str, ...]] = {}
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def _create_targets(self, batch: np.ndarray) -> List[np.ndarray]: def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]:
""" Compile target images, with masks, for the model output sizes. """ Compile target images, with masks, for the model output sizes.
Parameters Parameters
@ -467,9 +462,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return retval return retval
def process_batch(self, def process_batch(self,
filenames: List[str], filenames: list[str],
images: np.ndarray, images: np.ndarray,
detected_faces: List[DetectedFace], detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType: batch: np.ndarray) -> BatchType:
""" Performs the augmentation and compiles target images and samples. """ Performs the augmentation and compiles target images and samples.
@ -525,7 +520,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
if self._warp_to_landmarks: if self._warp_to_landmarks:
landmarks = np.array([face.aligned.landmarks for face in detected_faces]) landmarks = np.array([face.aligned.landmarks for face in detected_faces])
batch_dst_pts = self._get_closest_match(filenames, landmarks) batch_dst_pts = self._get_closest_match(filenames, landmarks)
warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts) warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts}
else: else:
warp_kwargs = {} warp_kwargs = {}
@ -545,7 +540,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return feed, targets return feed, targets
def _get_closest_match(self, filenames: List[str], batch_src_points: np.ndarray) -> np.ndarray: def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray:
""" Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest
matched 68 point landmarks from the opposite training set. matched 68 point landmarks from the opposite training set.
@ -563,7 +558,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
""" """
logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore
"src_points: '%s')", filenames, batch_src_points) "src_points: '%s')", filenames, batch_src_points)
lm_side: Literal["a", "b"] = "a" if self._side == "b" else "b" lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b"
other_cache = get_cache(lm_side) other_cache = get_cache(lm_side)
landmarks = other_cache.aligned_landmarks landmarks = other_cache.aligned_landmarks
@ -584,9 +579,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
return batch_dst_points return batch_dst_points
def _cache_closest_matches(self, def _cache_closest_matches(self,
filenames: List[str], filenames: list[str],
batch_src_points: np.ndarray, batch_src_points: np.ndarray,
landmarks: Dict[str, np.ndarray]) -> List[Tuple[str, ...]]: landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]:
""" Cache the nearest landmarks for this batch """ Cache the nearest landmarks for this batch
Parameters Parameters
@ -602,7 +597,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
logger.trace("Caching closest matches") # type:ignore logger.trace("Caching closest matches") # type:ignore
dst_landmarks = list(landmarks.items()) dst_landmarks = list(landmarks.items())
dst_points = np.array([lm[1] for lm in dst_landmarks]) dst_points = np.array([lm[1] for lm in dst_landmarks])
batch_closest_matches: List[Tuple[str, ...]] = [] batch_closest_matches: list[tuple[str, ...]] = []
for filename, src_points in zip(filenames, batch_src_points): for filename, src_points in zip(filenames, batch_src_points):
closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10]
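The loop above ranks every face on the opposite side by mean squared landmark distance and keeps the ten best. The same computation in isolation, assuming made-up landmark arrays:

    import numpy as np

    rng = np.random.default_rng(0)
    src_points = rng.normal(size=(68, 2))       # landmarks for one face in this batch
    dst_points = rng.normal(size=(500, 68, 2))  # stacked landmarks from the other side

    # Broadcast (68, 2) against (500, 68, 2), average the squared error per face,
    # then take the indices of the 10 closest matches
    closest = np.mean(np.square(src_points - dst_points), axis=(1, 2)).argsort()[:10]
    print(closest)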
@ -637,7 +632,7 @@ class PreviewDataGenerator(DataGenerator):
""" """
def _create_samples(self, def _create_samples(self,
images: np.ndarray, images: np.ndarray,
detected_faces: List[DetectedFace]) -> List[np.ndarray]: detected_faces: list[DetectedFace]) -> list[np.ndarray]:
""" Compile the 'sample' images. These are the 100% coverage images which hold the model """ Compile the 'sample' images. These are the 100% coverage images which hold the model
output in the preview window. output in the preview window.
@ -658,24 +653,25 @@ class PreviewDataGenerator(DataGenerator):
output_size = self._output_sizes[-1] output_size = self._output_sizes[-1]
full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2)) full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2))
assert self._config["centering"] in get_args(CenteringType) assert self._config["centering"] in T.get_args(CenteringType)
retval = np.empty((full_size, full_size, 3), dtype="float32") retval = np.empty((full_size, full_size, 3), dtype="float32")
retval = self._to_float32(np.array([AlignedFace(face.landmarks_xy, retval = self._to_float32(np.array([
image=images[idx], AlignedFace(face.landmarks_xy,
centering=cast(CenteringType, image=images[idx],
self._config["centering"]), centering=T.cast(CenteringType,
size=full_size, self._config["centering"]),
dtype="uint8", size=full_size,
is_aligned=True).face dtype="uint8",
for idx, face in enumerate(detected_faces)])) is_aligned=True).face
for idx, face in enumerate(detected_faces)]))
logger.trace("Processed samples: %s", retval.shape) # type: ignore logger.trace("Processed samples: %s", retval.shape) # type: ignore
return [retval] return [retval]
def process_batch(self, def process_batch(self,
filenames: List[str], filenames: list[str],
images: np.ndarray, images: np.ndarray,
detected_faces: List[DetectedFace], detected_faces: list[DetectedFace],
batch: np.ndarray) -> BatchType: batch: np.ndarray) -> BatchType:
""" Creates the full size preview images and the sub-cropped images for feeding the model's """ Creates the full size preview images and the sub-cropped images for feeding the model's
predict function. predict function.


@ -4,36 +4,30 @@
If Tkinter is installed, then this will be used to manage the preview image, otherwise we If Tkinter is installed, then this will be used to manage the preview image, otherwise we
fallback to opencv's imshow fallback to opencv's imshow
""" """
from __future__ import annotations
import logging import logging
import sys import typing as T
from threading import Event, Lock from threading import Event, Lock
from time import sleep from time import sleep
from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING
import cv2 import cv2
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal from collections.abc import Generator
else:
from typing import Literal
if TYPE_CHECKING:
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TriggerType = Dict[Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event] TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]
TriggerKeysType = Literal["m", "r", "s", "enter"] TriggerKeysType = T.Literal["m", "r", "s", "enter"]
TriggerNamesType = Literal["toggle_mask", "refresh", "save", "quit"] TriggerNamesType = T.Literal["toggle_mask", "refresh", "save", "quit"]
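These aliases now use builtin dict with T.Literal keys. A hedged construction example; in the application the trainer thread would own and signal these events:

    import typing as T
    from threading import Event

    TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]

    triggers: TriggerType = {name: Event()
                             for name in ("toggle_mask", "refresh", "save", "quit", "shutdown")}
    triggers["refresh"].set()            # signal the preview to redraw
    print(triggers["refresh"].is_set())  # True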
class PreviewBuffer(): class PreviewBuffer():
""" A thread safe class for holding preview images """ """ A thread safe class for holding preview images """
def __init__(self) -> None: def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__) logger.debug("Initializing: %s", self.__class__.__name__)
self._images: Dict[str, "np.ndarray"] = {} self._images: dict[str, np.ndarray] = {}
self._lock = Lock() self._lock = Lock()
self._updated = Event() self._updated = Event()
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@ -43,7 +37,7 @@ class PreviewBuffer():
""" bool: ``True`` when new images have been loaded into the preview buffer """ """ bool: ``True`` when new images have been loaded into the preview buffer """
return self._updated.is_set() return self._updated.is_set()
def add_image(self, name: str, image: "np.ndarray") -> None: def add_image(self, name: str, image: np.ndarray) -> None:
""" Add an image to the preview buffer in a thread safe way """ """ Add an image to the preview buffer in a thread safe way """
logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape) logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape)
with self._lock: with self._lock:
@ -51,7 +45,7 @@ class PreviewBuffer():
logger.debug("Added images: %s", list(self._images)) logger.debug("Added images: %s", list(self._images))
self._updated.set() self._updated.set()
def get_images(self) -> Generator[Tuple[str, "np.ndarray"], None, None]: def get_images(self) -> Generator[tuple[str, np.ndarray], None, None]:
""" Get the latest images from the preview buffer. When iterator is exhausted clears the """ Get the latest images from the preview buffer. When iterator is exhausted clears the
:attr:`updated` event. :attr:`updated` event.
@ -86,15 +80,15 @@ class PreviewBase(): # pylint:disable=too-few-public-methods
""" """
def __init__(self, def __init__(self,
preview_buffer: PreviewBuffer, preview_buffer: PreviewBuffer,
triggers: Optional[TriggerType] = None) -> None: triggers: TriggerType | None = None) -> None:
logger.debug("Initializing %s parent (triggers: %s)", logger.debug("Initializing %s parent (triggers: %s)",
self.__class__.__name__, triggers) self.__class__.__name__, triggers)
self._triggers = triggers self._triggers = triggers
self._buffer = preview_buffer self._buffer = preview_buffer
self._keymaps: Dict[TriggerKeysType, TriggerNamesType] = dict(m="toggle_mask", self._keymaps: dict[TriggerKeysType, TriggerNamesType] = {"m": "toggle_mask",
r="refresh", "r": "refresh",
s="save", "s": "save",
enter="quit") "enter": "quit"}
self._title = "" self._title = ""
logger.debug("Initialized %s parent", self.__class__.__name__) logger.debug("Initialized %s parent", self.__class__.__name__)
@ -141,7 +135,7 @@ class PreviewCV(PreviewBase): # pylint:disable=too-few-public-methods
logger.debug("Unable to import Tkinter. Falling back to OpenCV") logger.debug("Unable to import Tkinter. Falling back to OpenCV")
super().__init__(preview_buffer, triggers=triggers) super().__init__(preview_buffer, triggers=triggers)
self._triggers: TriggerType = self._triggers self._triggers: TriggerType = self._triggers
self._windows: List[str] = [] self._windows: list[str] = []
self._lookup = {ord(key): val self._lookup = {ord(key): val
for key, val in self._keymaps.items() if key != "enter"} for key, val in self._keymaps.items() if key != "enter"}
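The OpenCV fallback maps key codes to trigger names via ord, which is what the _lookup comprehension above builds. A minimal sketch of how such a table is consumed; the waitKey poll is illustrative and returns -1 when no key is pressed:

    import cv2

    lookup = {ord(key): name
              for key, name in {"m": "toggle_mask", "r": "refresh", "s": "save"}.items()}

    key_code = cv2.waitKey(50) & 0xFF  # 255 when nothing was pressed within 50ms
    action = lookup.get(key_code)      # None unless m/r/s was pressed
    print(action)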


@ -4,24 +4,25 @@
If Tkinter is installed, then this will be used to manage the preview image, otherwise we If Tkinter is installed, then this will be used to manage the preview image, otherwise we
fallback to opencv's imshow fallback to opencv's imshow
""" """
from __future__ import annotations
import logging import logging
import os import os
import sys import sys
import tkinter as tk import tkinter as tk
import typing as T
from datetime import datetime from datetime import datetime
from platform import system from platform import system
from tkinter import ttk from tkinter import ttk
from math import ceil, floor from math import ceil, floor
from typing import cast, List, Optional, Tuple, TYPE_CHECKING
from PIL import Image, ImageTk from PIL import Image, ImageTk
import cv2 import cv2
from .preview_cv import PreviewBase, TriggerKeysType from .preview_cv import PreviewBase, TriggerKeysType
if TYPE_CHECKING: if T.TYPE_CHECKING:
import numpy as np import numpy as np
from .preview_cv import PreviewBuffer, TriggerType from .preview_cv import PreviewBuffer, TriggerType
@ -38,18 +39,18 @@ class _Taskbar():
taskbar: :class:`tkinter.ttk.Frame` or ``None`` taskbar: :class:`tkinter.ttk.Frame` or ``None``
None if preview is a pop-up window otherwise ttk.Frame if taskbar is managed by the GUI None if preview is a pop-up window otherwise ttk.Frame if taskbar is managed by the GUI
""" """
def __init__(self, parent: tk.Frame, taskbar: Optional[ttk.Frame]) -> None: def __init__(self, parent: tk.Frame, taskbar: ttk.Frame | None) -> None:
logger.debug("Initializing %s (parent: '%s', taskbar: %s)", logger.debug("Initializing %s (parent: '%s', taskbar: %s)",
self.__class__.__name__, parent, taskbar) self.__class__.__name__, parent, taskbar)
self._is_standalone = taskbar is None self._is_standalone = taskbar is None
self._gui_mapped: List[tk.Widget] = [] self._gui_mapped: list[tk.Widget] = []
self._frame = tk.Frame(parent) if taskbar is None else taskbar self._frame = tk.Frame(parent) if taskbar is None else taskbar
self._min_max_scales = (20, 400) self._min_max_scales = (20, 400)
self._vars = dict(save=tk.BooleanVar(), self._vars = {"save": tk.BooleanVar(),
scale=tk.StringVar(), "scale": tk.StringVar(),
slider=tk.IntVar(), "slider": tk.IntVar(),
interpolator=tk.IntVar()) "interpolator": tk.IntVar()}
self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST), self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST),
("bicubic", cv2.INTER_CUBIC)] ("bicubic", cv2.INTER_CUBIC)]
self._scale = self._add_scale_combo() self._scale = self._add_scale_combo()
@ -261,7 +262,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
def __init__(self, def __init__(self,
parent: tk.Frame, parent: tk.Frame,
scale_var: tk.StringVar, scale_var: tk.StringVar,
screen_dimensions: Tuple[int, int], screen_dimensions: tuple[int, int],
is_standalone: bool) -> None: is_standalone: bool) -> None:
logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)", logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)",
self.__class__.__name__, parent, scale_var, screen_dimensions) self.__class__.__name__, parent, scale_var, screen_dimensions)
@ -272,7 +273,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
self._screen_dimensions = screen_dimensions self._screen_dimensions = screen_dimensions
self._var_scale = scale_var self._var_scale = scale_var
self._configure_scrollbars(frame) self._configure_scrollbars(frame)
self._image: Optional[ImageTk.PhotoImage] = None self._image: ImageTk.PhotoImage | None = None
self._image_id = self.create_image(self.width / 2, self._image_id = self.create_image(self.width / 2,
self.height / 2, self.height / 2,
anchor=tk.CENTER, anchor=tk.CENTER,
@ -400,8 +401,8 @@ class _Image():
logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)", logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)",
self.__class__.__name__, save_variable, is_standalone) self.__class__.__name__, save_variable, is_standalone)
self._is_standalone = is_standalone self._is_standalone = is_standalone
self._source: Optional["np.ndarray"] = None self._source: np.ndarray | None = None
self._display: Optional[ImageTk.PhotoImage] = None self._display: ImageTk.PhotoImage | None = None
self._scale = 1.0 self._scale = 1.0
self._interpolation = cv2.INTER_NEAREST self._interpolation = cv2.INTER_NEAREST
@ -416,7 +417,7 @@ class _Image():
return self._display return self._display
@property @property
def source(self) -> "np.ndarray": def source(self) -> np.ndarray:
""" :class:`PIL.Image.Image`: The current source preview image """ """ :class:`PIL.Image.Image`: The current source preview image """
assert self._source is not None assert self._source is not None
return self._source return self._source
@ -426,7 +427,7 @@ class _Image():
"""int: The current display scale as a percentage of original image size """ """int: The current display scale as a percentage of original image size """
return int(self._scale * 100) return int(self._scale * 100)
def set_source_image(self, name: str, image: "np.ndarray") -> None: def set_source_image(self, name: str, image: np.ndarray) -> None:
""" Set the source image to :attr:`source` """ Set the source image to :attr:`source`
Parameters Parameters
@ -542,7 +543,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
self._taskbar = taskbar self._taskbar = taskbar
self._image = image self._image = image
self._drag_data: List[float] = [0., 0.] self._drag_data: list[float] = [0., 0.]
self._set_mouse_bindings() self._set_mouse_bindings()
self._set_key_bindings(is_standalone) self._set_key_bindings(is_standalone)
logger.debug("Initialized %s", self.__class__.__name__,) logger.debug("Initialized %s", self.__class__.__name__,)
@ -604,7 +605,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
The key press event The key press event
""" """
move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview
visible = (move_axis()[1] - move_axis()[0]) visible = move_axis()[1] - move_axis()[0]
amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25 amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25
logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, " # type: ignore logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, " # type: ignore
"amount: %s)", move_axis, visible, amount) "amount: %s)", move_axis, visible, amount)
@ -671,10 +672,10 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
Default: `None` Default: `None`
""" """
def __init__(self, def __init__(self,
preview_buffer: "PreviewBuffer", preview_buffer: PreviewBuffer,
parent: Optional[tk.Widget] = None, parent: tk.Widget | None = None,
taskbar: Optional[ttk.Frame] = None, taskbar: ttk.Frame | None = None,
triggers: Optional["TriggerType"] = None) -> None: triggers: TriggerType | None = None) -> None:
logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent) logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent)
super().__init__(preview_buffer, triggers=triggers) super().__init__(preview_buffer, triggers=triggers)
self._is_standalone = parent is None self._is_standalone = parent is None
@ -745,7 +746,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
logger.info(" Save Preview: Ctrl+s") logger.info(" Save Preview: Ctrl+s")
logger.info("---------------------------------------------------") logger.info("---------------------------------------------------")
def _get_geometry(self) -> Tuple[int, int]: def _get_geometry(self) -> tuple[int, int]:
""" Obtain the geometry of the current screen (standalone) or the dimensions of the widget """ Obtain the geometry of the current screen (standalone) or the dimensions of the widget
holding the preview window (GUI). holding the preview window (GUI).
@ -780,7 +781,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
half_screen = tuple(x // 2 for x in self._screen_dimensions) half_screen = tuple(x // 2 for x in self._screen_dimensions)
min_scales = (half_screen[0] / self._image.source.shape[1], min_scales = (half_screen[0] / self._image.source.shape[1],
half_screen[1] / self._image.source.shape[0]) half_screen[1] / self._image.source.shape[0])
min_scale = min(1.0, min(min_scales)) min_scale = min(1.0, *min_scales)
min_scale = (ceil(min_scale * 10)) * 10 min_scale = (ceil(min_scale * 10)) * 10
eight_screen = tuple(x * 8 for x in self._screen_dimensions) eight_screen = tuple(x * 8 for x in self._screen_dimensions)
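The min(1.0, *min_scales) form above unpacks the pair directly instead of nesting a second min call; the result is identical:

    min_scales = (0.42, 0.77)
    assert min(1.0, min(min_scales)) == min(1.0, *min_scales) == 0.42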
@ -884,7 +885,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
if self._triggers is None: # Don't need triggers for GUI if self._triggers is None: # Don't need triggers for GUI
return return
keypress = "enter" if event.keysym == "Return" else event.keysym keypress = "enter" if event.keysym == "Return" else event.keysym
key = cast(TriggerKeysType, keypress) key = T.cast(TriggerKeysType, keypress)
logger.debug("Processing keypress '%s'", key) logger.debug("Processing keypress '%s'", key)
if key == "r": if key == "r":
print("") # Let log print on different line from loss output print("") # Let log print on different line from loss output


@ -1,11 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Utilities available across all scripts """ """ Utilities available across all scripts """
from __future__ import annotations
import json import json
import logging import logging
import os import os
import sys import sys
import tkinter as tk import tkinter as tk
import typing as T
import warnings import warnings
import zipfile import zipfile
@ -14,18 +15,12 @@ from re import finditer
from socket import timeout as socket_timeout, error as socket_error from socket import timeout as socket_timeout, error as socket_error
from threading import get_ident from threading import get_ident
from time import time from time import time
from typing import cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from urllib import request, error as urlliberror from urllib import request, error as urlliberror
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if TYPE_CHECKING:
from http.client import HTTPResponse from http.client import HTTPResponse
# Global variables # Global variables
@ -34,8 +29,8 @@ _image_extensions = [ # pylint:disable=invalid-name
_video_extensions = [ # pylint:disable=invalid-name _video_extensions = [ # pylint:disable=invalid-name
".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv", ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
".ts", ".vob"] ".ts", ".vob"]
_TF_VERS: Optional[Tuple[int, int]] = None _TF_VERS: tuple[int, int] | None = None
ValidBackends = Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"] ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
class _Backend(): # pylint:disable=too-few-public-methods class _Backend(): # pylint:disable=too-few-public-methods
@ -44,7 +39,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
If file doesn't exist and a variable hasn't been set, create the config file. """ If file doesn't exist and a variable hasn't been set, create the config file. """
def __init__(self) -> None: def __init__(self) -> None:
self._backends: Dict[str, ValidBackends] = {"1": "cpu", self._backends: dict[str, ValidBackends] = {"1": "cpu",
"2": "directml", "2": "directml",
"3": "nvidia", "3": "nvidia",
"4": "apple_silicon", "4": "apple_silicon",
@ -78,9 +73,9 @@ class _Backend(): # pylint:disable=too-few-public-methods
""" """
# Check if environment variable is set, if so use that # Check if environment variable is set, if so use that
if "FACESWAP_BACKEND" in os.environ: if "FACESWAP_BACKEND" in os.environ:
fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower()) fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
assert fs_backend in get_args(ValidBackends), ( assert fs_backend in T.get_args(ValidBackends), (
f"Faceswap backend must be one of {get_args(ValidBackends)}") f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}") print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
return fs_backend return fs_backend
# Intercept for sphinx docs build # Intercept for sphinx docs build
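The environment-variable check above validates the value against the Literal alias with T.get_args. A hedged standalone version; backend_from_env is a hypothetical helper, not part of lib/utils:

    import os
    import typing as T

    ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]


    def backend_from_env(default: str = "cpu") -> ValidBackends:
        fs_backend = T.cast(ValidBackends,
                            os.environ.get("FACESWAP_BACKEND", default).lower())
        assert fs_backend in T.get_args(ValidBackends), (
            f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
        return fs_backend


    print(backend_from_env())  # 'cpu' unless FACESWAP_BACKEND is set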
@ -163,11 +158,11 @@ def set_backend(backend: str) -> None:
>>> set_backend("nvidia") >>> set_backend("nvidia")
""" """
global _FS_BACKEND # pylint:disable=global-statement global _FS_BACKEND # pylint:disable=global-statement
backend = cast(ValidBackends, backend.lower()) backend = T.cast(ValidBackends, backend.lower())
_FS_BACKEND = backend _FS_BACKEND = backend
def get_tf_version() -> Tuple[int, int]: def get_tf_version() -> tuple[int, int]:
""" Obtain the major. minor version of currently installed Tensorflow. """ Obtain the major. minor version of currently installed Tensorflow.
Returns Returns
@ -179,7 +174,7 @@ def get_tf_version() -> Tuple[int, int]:
------- -------
>>> from lib.utils import get_tf_version >>> from lib.utils import get_tf_version
>>> get_tf_version() >>> get_tf_version()
(2, 9) (2, 10)
""" """
global _TF_VERS # pylint:disable=global-statement global _TF_VERS # pylint:disable=global-statement
if _TF_VERS is None: if _TF_VERS is None:
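The module-level cache means Tensorflow is imported and its version parsed only once. A sketch of the parse, assuming tensorflow is importable and versioned as "major.minor.patch":

    from __future__ import annotations

    _TF_VERS: tuple[int, int] | None = None


    def get_tf_version() -> tuple[int, int]:
        global _TF_VERS
        if _TF_VERS is None:
            import tensorflow as tf  # local import keeps module start-up cheap
            major, minor = tf.__version__.split(".")[:2]
            _TF_VERS = (int(major), int(minor))
        return _TF_VERS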
@ -225,7 +220,7 @@ def get_folder(path: str, make_folder: bool = True) -> str:
return path return path
def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]: def get_image_paths(directory: str, extension: str | None = None) -> list[str]:
""" Gets the image paths from a given directory. """ Gets the image paths from a given directory.
The function searches for files with the specified extension(s) in the given directory, and The function searches for files with the specified extension(s) in the given directory, and
@ -274,7 +269,7 @@ def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str
return dir_contents return dir_contents
def get_dpi() -> Optional[float]: def get_dpi() -> float | None:
""" Gets the DPI (dots per inch) of the display screen. """ Gets the DPI (dots per inch) of the display screen.
Returns Returns
@ -338,7 +333,7 @@ def convert_to_secs(*args: int) -> int:
return retval return retval
def full_path_split(path: str) -> List[str]: def full_path_split(path: str) -> list[str]:
""" Split a file path into all of its parts. """ Split a file path into all of its parts.
Parameters Parameters
@ -360,7 +355,7 @@ def full_path_split(path: str) -> List[str]:
['relative', 'path', 'to', 'file.txt']] ['relative', 'path', 'to', 'file.txt']]
""" """
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
allparts: List[str] = [] allparts: list[str] = []
while True: while True:
parts = os.path.split(path) parts = os.path.split(path)
if parts[0] == path: # sentinel for absolute paths if parts[0] == path: # sentinel for absolute paths
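The sentinel checks stop the loop once os.path.split returns the same head (absolute root) or tail (bare relative name). A runnable sketch of the full loop, matching the docstring example above:

    import os


    def full_path_split(path: str) -> list[str]:
        allparts: list[str] = []
        while True:
            parts = os.path.split(path)
            if parts[0] == path:   # sentinel for absolute paths: '/' splits to itself
                allparts.insert(0, parts[0])
                break
            if parts[1] == path:   # sentinel for a bare relative name
                allparts.insert(0, parts[1])
                break
            path = parts[0]
            allparts.insert(0, parts[1])
        return allparts


    print(full_path_split("relative/path/to/file.txt"))  # ['relative', 'path', 'to', 'file.txt']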
@ -410,7 +405,7 @@ def set_system_verbosity(log_level: str):
warnings.simplefilter(action='ignore', category=warncat) warnings.simplefilter(action='ignore', category=warncat)
def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None: def deprecation_warning(function: str, additional_info: str | None = None) -> None:
""" Log a deprecation warning message. """ Log a deprecation warning message.
This function logs a warning message to indicate that the specified function has been This function logs a warning message to indicate that the specified function has been
@ -436,7 +431,7 @@ def deprecation_warning(function: str, additional_info: Optional[str] = None) ->
logger.warning(msg) logger.warning(msg)
def camel_case_split(identifier: str) -> List[str]: def camel_case_split(identifier: str) -> list[str]:
""" Split a camelCase string into a list of its individual parts """ Split a camelCase string into a list of its individual parts
Parameters Parameters
@ -541,7 +536,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
>>> model_downloader = GetModel("s3fd_keras_v2.h5", 11) >>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
""" """
def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None: def __init__(self, model_filename: str | list[str], git_model_id: int) -> None:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
if not isinstance(model_filename, list): if not isinstance(model_filename, list):
model_filename = [model_filename] model_filename = [model_filename]
@ -576,7 +571,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
return retval return retval
@property @property
def model_path(self) -> Union[str, List[str]]: def model_path(self) -> str | list[str]:
""" str or list[str]: The model path(s) in the cache folder. """ str or list[str]: The model path(s) in the cache folder.
Example Example
@ -587,7 +582,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
'/path/to/s3fd_keras_v2.h5' '/path/to/s3fd_keras_v2.h5'
""" """
paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename] paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths retval: str | list[str] = paths[0] if len(paths) == 1 else paths
self.logger.trace(retval) # type:ignore[attr-defined] self.logger.trace(retval) # type:ignore[attr-defined]
return retval return retval
@ -662,7 +657,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
self._url_download, self._cache_dir) self._url_download, self._cache_dir)
sys.exit(1) sys.exit(1)
def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None: def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None:
""" Write the model zip file to disk. """ Write the model zip file to disk.
Parameters Parameters
@ -762,8 +757,8 @@ class DebugTimes():
""" """
def __init__(self, def __init__(self,
show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None: show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None:
self._times: Dict[str, List[float]] = {} self._times: dict[str, list[float]] = {}
self._steps: Dict[str, float] = {} self._steps: dict[str, float] = {}
self._interval = 1 self._interval = 1
self._display = {"min": show_min, "mean": show_mean, "max": show_max} self._display = {"min": show_min, "mean": show_mean, "max": show_max}


@ -8,7 +8,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: PACKAGE VERSION\n" "Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-06-11 23:20+0100\n" "POT-Creation-Date: 2023-06-25 13:39+0100\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n" "Language-Team: LANGUAGE <LL@li.org>\n"
@ -227,7 +227,8 @@ msgid ""
msgstr "" msgstr ""
#: plugins/train/_config.py:198 plugins/train/_config.py:223 #: plugins/train/_config.py:198 plugins/train/_config.py:223
#: plugins/train/_config.py:238 plugins/train/_config.py:265 #: plugins/train/_config.py:238 plugins/train/_config.py:256
#: plugins/train/_config.py:290
msgid "optimizer" msgid "optimizer"
msgstr "" msgstr ""
@ -274,21 +275,40 @@ msgid ""
"epsilon to 0.001 (1e-3)." "epsilon to 0.001 (1e-3)."
msgstr "" msgstr ""
#: plugins/train/_config.py:258 #: plugins/train/_config.py:262
msgid "" msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the " "When to save the Optimizer Weights. Saving the optimizer weights is not "
"gradient weights and adjusts the normalization value dynamically to fit the " "necessary and will increase the model file size 3x (and by extension the "
"data. Can help prevent NaNs and improve model optimization at the expense of " "amount of time it takes to save the model). However, it can be useful to "
"VRAM. Ref: AutoClip: Adaptive Gradient Clipping for Source Separation " "save these weights if you want to guarantee that a resumed model carries off "
"Networks https://arxiv.org/abs/2007.14469" "exactly from where it left off, rather than spending a few hundred "
"iterations catching up.\n"
"\t never - Don't save optimizer weights.\n"
"\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have "
"the last saved optimizer state in your model file.\n"
"\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."
msgstr "" msgstr ""
#: plugins/train/_config.py:271 plugins/train/_config.py:283 #: plugins/train/_config.py:283
#: plugins/train/_config.py:297 plugins/train/_config.py:314 msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
"and adjusts the normalization value dynamically to fit the data. Can help "
"prevent NaNs and improve model optimization at the expense of VRAM. Ref: "
"AutoClip: Adaptive Gradient Clipping for Source Separation Networks https://"
"arxiv.org/abs/2007.14469"
msgstr ""
#: plugins/train/_config.py:296 plugins/train/_config.py:308
#: plugins/train/_config.py:322 plugins/train/_config.py:339
msgid "network" msgid "network"
msgstr "" msgstr ""
#: plugins/train/_config.py:273 #: plugins/train/_config.py:298
msgid "" msgid ""
"Use reflection padding rather than zero padding with convolutions. Each " "Use reflection padding rather than zero padding with convolutions. Each "
"convolution must pad the image boundaries to maintain the proper sizing. " "convolution must pad the image boundaries to maintain the proper sizing. "
@ -297,21 +317,21 @@ msgid ""
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt" "\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
msgstr "" msgstr ""
#: plugins/train/_config.py:286 #: plugins/train/_config.py:311
msgid "" msgid ""
"Enable the Tensorflow GPU 'allow_growth' configuration " "Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
"option. This option prevents Tensorflow from allocating all of the GPU VRAM " "prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
"at launch but can lead to higher VRAM fragmentation and slower performance. " "lead to higher VRAM fragmentation and slower performance. Should only be "
"Should only be enabled if you are receiving errors regarding 'cuDNN fails to " "enabled if you are receiving errors regarding 'cuDNN fails to initialize' "
"initialize' when commencing training." "when commencing training."
msgstr "" msgstr ""
#: plugins/train/_config.py:299 #: plugins/train/_config.py:324
msgid "" msgid ""
"NVIDIA GPUs can run operations in float16 faster than in " "NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
"float32. Mixed precision allows you to use a mix of float16 with float32, to " "precision allows you to use a mix of float16 with float32, to get the "
"get the performance benefits from float16 and the numeric stability benefits " "performance benefits from float16 and the numeric stability benefits from "
"from float32.\n" "float32.\n"
"\n" "\n"
"This is untested on DirectML backend, but will run on most Nvidia models. it " "This is untested on DirectML backend, but will run on most Nvidia models. it "
"will only speed up training on more recent GPUs. Those with compute " "will only speed up training on more recent GPUs. Those with compute "
@ -322,7 +342,7 @@ msgid ""
"the most benefit." "the most benefit."
msgstr "" msgstr ""
#: plugins/train/_config.py:316 #: plugins/train/_config.py:341
msgid "" msgid ""
"If a 'NaN' is generated in the model, this means that the model has " "If a 'NaN' is generated in the model, this means that the model has "
"corrupted and the model is likely to start deteriorating from this point on. " "corrupted and the model is likely to start deteriorating from this point on. "
@ -331,11 +351,11 @@ msgid ""
"rescue your model." "rescue your model."
msgstr "" msgstr ""
#: plugins/train/_config.py:329 #: plugins/train/_config.py:354
msgid "convert" msgid "convert"
msgstr "" msgstr ""
#: plugins/train/_config.py:331 #: plugins/train/_config.py:356
msgid "" msgid ""
"[GPU Only]. The number of faces to feed through the model at once when " "[GPU Only]. The number of faces to feed through the model at once when "
"running the Convert process.\n" "running the Convert process.\n"
@ -345,27 +365,27 @@ msgid ""
"size." "size."
msgstr "" msgstr ""
#: plugins/train/_config.py:350 #: plugins/train/_config.py:375
msgid "" msgid ""
"Loss configuration options\n" "Loss configuration options\n"
"Loss is the mechanism by which a Neural Network judges how well it thinks " "Loss is the mechanism by which a Neural Network judges how well it thinks "
"that it is recreating a face." "that it is recreating a face."
msgstr "" msgstr ""
#: plugins/train/_config.py:357 plugins/train/_config.py:369 #: plugins/train/_config.py:382 plugins/train/_config.py:394
#: plugins/train/_config.py:382 plugins/train/_config.py:402 #: plugins/train/_config.py:407 plugins/train/_config.py:427
#: plugins/train/_config.py:414 plugins/train/_config.py:434 #: plugins/train/_config.py:439 plugins/train/_config.py:459
#: plugins/train/_config.py:446 plugins/train/_config.py:466 #: plugins/train/_config.py:471 plugins/train/_config.py:491
#: plugins/train/_config.py:482 plugins/train/_config.py:498 #: plugins/train/_config.py:507 plugins/train/_config.py:523
#: plugins/train/_config.py:515 #: plugins/train/_config.py:540
msgid "loss" msgid "loss"
msgstr "" msgstr ""
#: plugins/train/_config.py:361 #: plugins/train/_config.py:386
msgid "The loss function to use." msgid "The loss function to use."
msgstr "" msgstr ""
#: plugins/train/_config.py:373 #: plugins/train/_config.py:398
msgid "" msgid ""
"The second loss function to use. If using a structural based loss (such as " "The second loss function to use. If using a structural based loss (such as "
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 " "SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
@ -373,7 +393,7 @@ msgid ""
"function with the loss_weight_2 option." "function with the loss_weight_2 option."
msgstr "" msgstr ""
#: plugins/train/_config.py:388 #: plugins/train/_config.py:413
msgid "" msgid ""
"The amount of weight to apply to the second loss function.\n" "The amount of weight to apply to the second loss function.\n"
"\n" "\n"
@ -391,13 +411,13 @@ msgid ""
"\t 0 - Disables the second loss function altogether." "\t 0 - Disables the second loss function altogether."
msgstr "" msgstr ""
#: plugins/train/_config.py:406 #: plugins/train/_config.py:431
msgid "" msgid ""
"The third loss function to use. You can adjust the weighting of this loss " "The third loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option." "function with the loss_weight_3 option."
msgstr "" msgstr ""
#: plugins/train/_config.py:420 #: plugins/train/_config.py:445
msgid "" msgid ""
"The amount of weight to apply to the third loss function.\n" "The amount of weight to apply to the third loss function.\n"
"\n" "\n"
@ -415,13 +435,13 @@ msgid ""
"\t 0 - Disables the third loss function altogether." "\t 0 - Disables the third loss function altogether."
msgstr "" msgstr ""
#: plugins/train/_config.py:438 #: plugins/train/_config.py:463
msgid "" msgid ""
"The fourth loss function to use. You can adjust the weighting of this loss " "The fourth loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option." "function with the loss_weight_3 option."
msgstr "" msgstr ""
#: plugins/train/_config.py:452 #: plugins/train/_config.py:477
msgid "" msgid ""
"The amount of weight to apply to the fourth loss function.\n" "The amount of weight to apply to the fourth loss function.\n"
"\n" "\n"
@ -439,7 +459,7 @@ msgid ""
"\t 0 - Disables the fourth loss function altogether." "\t 0 - Disables the fourth loss function altogether."
msgstr "" msgstr ""
#: plugins/train/_config.py:471 #: plugins/train/_config.py:496
msgid "" msgid ""
"The loss function to use when learning a mask.\n" "The loss function to use when learning a mask.\n"
"\t MAE - Mean absolute error will guide reconstructions of each pixel " "\t MAE - Mean absolute error will guide reconstructions of each pixel "
@ -451,7 +471,7 @@ msgid ""
"susceptible to outliers and typically produces slightly blurrier results." "susceptible to outliers and typically produces slightly blurrier results."
msgstr "" msgstr ""
#: plugins/train/_config.py:488 #: plugins/train/_config.py:513
msgid "" msgid ""
"The amount of priority to give to the eyes.\n" "The amount of priority to give to the eyes.\n"
"\n" "\n"
@ -464,7 +484,7 @@ msgid ""
"NB: Penalized Mask Loss must be enable to use this option." "NB: Penalized Mask Loss must be enable to use this option."
msgstr "" msgstr ""
#: plugins/train/_config.py:504 #: plugins/train/_config.py:529
msgid "" msgid ""
"The amount of priority to give to the mouth.\n" "The amount of priority to give to the mouth.\n"
"\n" "\n"
@ -477,7 +497,7 @@ msgid ""
"NB: Penalized Mask Loss must be enable to use this option." "NB: Penalized Mask Loss must be enable to use this option."
msgstr "" msgstr ""
#: plugins/train/_config.py:517 #: plugins/train/_config.py:542
msgid "" msgid ""
"Image loss function is weighted by mask presence. For areas of the image " "Image loss function is weighted by mask presence. For areas of the image "
"without the facial mask, reconstruction errors will be ignored while the " "without the facial mask, reconstruction errors will be ignored while the "
@ -485,12 +505,12 @@ msgid ""
"attention on the core face area." "attention on the core face area."
msgstr "" msgstr ""
#: plugins/train/_config.py:528 plugins/train/_config.py:570 #: plugins/train/_config.py:553 plugins/train/_config.py:595
#: plugins/train/_config.py:584 plugins/train/_config.py:593 #: plugins/train/_config.py:609 plugins/train/_config.py:618
msgid "mask" msgid "mask"
msgstr "" msgstr ""
#: plugins/train/_config.py:531 #: plugins/train/_config.py:556
msgid "" msgid ""
"The mask to be used for training. If you have selected 'Learn Mask' or " "The mask to be used for training. If you have selected 'Learn Mask' or "
"'Penalized Mask Loss' you must select a value other than 'none'. The " "'Penalized Mask Loss' you must select a value other than 'none'. The "
@ -528,7 +548,7 @@ msgid ""
"performance." "performance."
msgstr "" msgstr ""
#: plugins/train/_config.py:572 #: plugins/train/_config.py:597
msgid "" msgid ""
"Apply gaussian blur to the mask input. This has the effect of smoothing the " "Apply gaussian blur to the mask input. This has the effect of smoothing the "
"edges of the mask, which can help with poorly calculated masks and give less " "edges of the mask, which can help with poorly calculated masks and give less "
@ -538,13 +558,13 @@ msgid ""
"number." "number."
msgstr "" msgstr ""
#: plugins/train/_config.py:586 #: plugins/train/_config.py:611
msgid "" msgid ""
"Sets pixels that are near white to white and near black to black. Set to 0 " "Sets pixels that are near white to white and near black to black. Set to 0 "
"for off." "for off."
msgstr "" msgstr ""
#: plugins/train/_config.py:595 #: plugins/train/_config.py:620
msgid "" msgid ""
"Dedicate a portion of the model to learning how to duplicate the input mask. " "Dedicate a portion of the model to learning how to duplicate the input mask. "
"Increases VRAM usage in exchange for learning a quick ability to try to " "Increases VRAM usage in exchange for learning a quick ability to try to "


@ -7,8 +7,8 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: \n" "Project-Id-Version: \n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-06-11 23:20+0100\n" "POT-Creation-Date: 2023-06-25 13:39+0100\n"
"PO-Revision-Date: 2023-06-20 17:06+0100\n" "PO-Revision-Date: 2023-06-25 13:42+0100\n"
"Last-Translator: \n" "Last-Translator: \n"
"Language-Team: \n" "Language-Team: \n"
"Language: ru_RU\n" "Language: ru_RU\n"
@ -354,7 +354,8 @@ msgstr ""
"повлияет только на запуск новой модели." "повлияет только на запуск новой модели."
#: plugins/train/_config.py:198 plugins/train/_config.py:223 #: plugins/train/_config.py:198 plugins/train/_config.py:223
#: plugins/train/_config.py:238 plugins/train/_config.py:265 #: plugins/train/_config.py:238 plugins/train/_config.py:256
#: plugins/train/_config.py:290
msgid "optimizer" msgid "optimizer"
msgstr "оптимизатор" msgstr "оптимизатор"
@ -435,7 +436,41 @@ msgstr ""
"Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе " "Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе "
"значения \"-3\" эпсилон будет равен 0,001 (1e-3)." "значения \"-3\" эпсилон будет равен 0,001 (1e-3)."
#: plugins/train/_config.py:258 #: plugins/train/_config.py:262
msgid ""
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the "
"amount of time it takes to save the model). However, it can be useful to "
"save these weights if you want to guarantee that a resumed model carries off "
"exactly from where it left off, rather than spending a few hundred "
"iterations catching up.\n"
"\t never - Don't save optimizer weights.\n"
"\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have "
"the last saved optimizer state in your model file.\n"
"\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."
msgstr ""
"Когда сохранять веса оптимизатора. Сохранение весов оптимизатора не является "
"необходимым и увеличит размер файла модели в 3 раза (и соответственно время, "
"необходимое для сохранения модели). Однако может быть полезно сохранить эти "
"веса, если вы хотите гарантировать, что возобновленная модель продолжит "
"работу именно с того места, где она остановилась, а не тратит несколько "
"сотен итераций на догонялки.\n"
"\t never - не сохранять веса оптимизатора.\n"
"\t always - сохранять веса оптимизатора при каждой итерации сохранения. "
"Сохранение модели займет больше времени из-за увеличенного размера файла, но "
"в файле модели всегда будет последнее сохраненное состояние оптимизатора.\n"
"\t exit - сохранять веса оптимизатора только при явном завершении модели. "
"Это может быть, когда модель активно останавливается или когда выполняются "
"целевые итерации. Примечание. Если сеанс обучения завершается по другой "
"причине (например, отключение питания, ошибка нехватки памяти, обнаружение "
"NaN), веса оптимизатора НЕ будут сохранены."
#: plugins/train/_config.py:283
msgid "" msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights " "Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
"and adjusts the normalization value dynamically to fit the data. Can help " "and adjusts the normalization value dynamically to fit the data. Can help "
@ -449,12 +484,12 @@ msgstr ""
"ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping for Source " "ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping for Source "
"Separation Networks [ТОЛЬКО на английском] https://arxiv.org/abs/2007.14469" "Separation Networks [ТОЛЬКО на английском] https://arxiv.org/abs/2007.14469"
#: plugins/train/_config.py:271 plugins/train/_config.py:283 #: plugins/train/_config.py:296 plugins/train/_config.py:308
#: plugins/train/_config.py:297 plugins/train/_config.py:314 #: plugins/train/_config.py:322 plugins/train/_config.py:339
msgid "network" msgid "network"
msgstr "сеть" msgstr "сеть"
#: plugins/train/_config.py:273 #: plugins/train/_config.py:298
msgid "" msgid ""
"Use reflection padding rather than zero padding with convolutions. Each " "Use reflection padding rather than zero padding with convolutions. Each "
"convolution must pad the image boundaries to maintain the proper sizing. " "convolution must pad the image boundaries to maintain the proper sizing. "
@ -468,7 +503,7 @@ msgstr ""
"изображения.\n" "изображения.\n"
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt" "\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
#: plugins/train/_config.py:286 #: plugins/train/_config.py:311
msgid "" msgid ""
"Enable the Tensorflow GPU 'allow_growth' configuration option. This option " "Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
"prevents Tensorflow from allocating all of the GPU VRAM at launch but can " "prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
@ -483,7 +518,7 @@ msgstr ""
"случае, если у вас появляются ошибки, рода 'cuDNN fails to initialize'(cuDNN " "случае, если у вас появляются ошибки, рода 'cuDNN fails to initialize'(cuDNN "
"не может инициализироваться) при начале тренировки." "не может инициализироваться) при начале тренировки."
#: plugins/train/_config.py:299 #: plugins/train/_config.py:324
msgid "" msgid ""
"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed " "NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
"precision allows you to use a mix of float16 with float32, to get the " "precision allows you to use a mix of float16 with float32, to get the "
@ -512,7 +547,7 @@ msgstr ""
"ускорение. В основном RTX видеокарты и позже предлагают самое большое " "ускорение. В основном RTX видеокарты и позже предлагают самое большое "
"ускорение." "ускорение."
#: plugins/train/_config.py:316 #: plugins/train/_config.py:341
msgid "" msgid ""
"If a 'NaN' is generated in the model, this means that the model has " "If a 'NaN' is generated in the model, this means that the model has "
"corrupted and the model is likely to start deteriorating from this point on. " "corrupted and the model is likely to start deteriorating from this point on. "
@ -526,11 +561,11 @@ msgstr ""
"NaN. Последнее сохранение не будет содержать в себе NaN, так что у вас будет " "NaN. Последнее сохранение не будет содержать в себе NaN, так что у вас будет "
"возможность спасти вашу модель." "возможность спасти вашу модель."
#: plugins/train/_config.py:329 #: plugins/train/_config.py:354
msgid "convert" msgid "convert"
msgstr "конвертирование" msgstr "конвертирование"
#: plugins/train/_config.py:331 #: plugins/train/_config.py:356
msgid "" msgid ""
"[GPU Only]. The number of faces to feed through the model at once when " "[GPU Only]. The number of faces to feed through the model at once when "
"running the Convert process.\n" "running the Convert process.\n"
@ -546,7 +581,7 @@ msgstr ""
"конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда " "конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда "
"стоит снизить размер пачки." "стоит снизить размер пачки."
#: plugins/train/_config.py:350 #: plugins/train/_config.py:375
msgid "" msgid ""
"Loss configuration options\n" "Loss configuration options\n"
"Loss is the mechanism by which a Neural Network judges how well it thinks " "Loss is the mechanism by which a Neural Network judges how well it thinks "
@ -556,20 +591,20 @@ msgstr ""
"Потеря - механизм, по которому Нейронная Сеть судит, насколько хорошо она " "Потеря - механизм, по которому Нейронная Сеть судит, насколько хорошо она "
"воспроизводит лицо." "воспроизводит лицо."
#: plugins/train/_config.py:357 plugins/train/_config.py:369 #: plugins/train/_config.py:382 plugins/train/_config.py:394
#: plugins/train/_config.py:382 plugins/train/_config.py:402 #: plugins/train/_config.py:407 plugins/train/_config.py:427
#: plugins/train/_config.py:414 plugins/train/_config.py:434 #: plugins/train/_config.py:439 plugins/train/_config.py:459
#: plugins/train/_config.py:446 plugins/train/_config.py:466 #: plugins/train/_config.py:471 plugins/train/_config.py:491
#: plugins/train/_config.py:482 plugins/train/_config.py:498 #: plugins/train/_config.py:507 plugins/train/_config.py:523
#: plugins/train/_config.py:515 #: plugins/train/_config.py:540
msgid "loss" msgid "loss"
msgstr "потери" msgstr "потери"
#: plugins/train/_config.py:361 #: plugins/train/_config.py:386
msgid "The loss function to use." msgid "The loss function to use."
msgstr "Какую функцию потерь стоит использовать." msgstr "Какую функцию потерь стоит использовать."
#: plugins/train/_config.py:373 #: plugins/train/_config.py:398
msgid "" msgid ""
"The second loss function to use. If using a structural based loss (such as " "The second loss function to use. If using a structural based loss (such as "
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 " "SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
@ -581,7 +616,7 @@ msgstr ""
"регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес " "регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес "
"этой функции потерь с помощью параметра loss_weight_2." "этой функции потерь с помощью параметра loss_weight_2."
#: plugins/train/_config.py:388 #: plugins/train/_config.py:413
msgid "" msgid ""
"The amount of weight to apply to the second loss function.\n" "The amount of weight to apply to the second loss function.\n"
"\n" "\n"
@ -612,7 +647,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n" "4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает четвертую функцию потерь." "\t 0 - Полностью отключает четвертую функцию потерь."
#: plugins/train/_config.py:406 #: plugins/train/_config.py:431
msgid "" msgid ""
"The third loss function to use. You can adjust the weighting of this loss " "The third loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option." "function with the loss_weight_3 option."
@ -620,7 +655,7 @@ msgstr ""
"Третья используемая функция потерь. Вы можете настроить вес этой функции " "Третья используемая функция потерь. Вы можете настроить вес этой функции "
"потерь с помощью параметра loss_weight_3." "потерь с помощью параметра loss_weight_3."
#: plugins/train/_config.py:420 #: plugins/train/_config.py:445
msgid "" msgid ""
"The amount of weight to apply to the third loss function.\n" "The amount of weight to apply to the third loss function.\n"
"\n" "\n"
@ -651,7 +686,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n" "4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает четвертую функцию потерь." "\t 0 - Полностью отключает четвертую функцию потерь."
#: plugins/train/_config.py:438 #: plugins/train/_config.py:463
msgid "" msgid ""
"The fourth loss function to use. You can adjust the weighting of this loss " "The fourth loss function to use. You can adjust the weighting of this loss "
"function with the loss_weight_3 option." "function with the loss_weight_3 option."
@ -659,7 +694,7 @@ msgstr ""
"Четвертая используемая функция потерь. Вы можете настроить вес этой функции " "Четвертая используемая функция потерь. Вы можете настроить вес этой функции "
"потерь с помощью параметра 'loss_weight_4'." "потерь с помощью параметра 'loss_weight_4'."
#: plugins/train/_config.py:452 #: plugins/train/_config.py:477
msgid "" msgid ""
"The amount of weight to apply to the fourth loss function.\n" "The amount of weight to apply to the fourth loss function.\n"
"\n" "\n"
@ -690,7 +725,7 @@ msgstr ""
"4 раза перед добавлением к общей оценке потерь. \n" "4 раза перед добавлением к общей оценке потерь. \n"
"\t 0 - Полностью отключает четвертую функцию потерь." "\t 0 - Полностью отключает четвертую функцию потерь."
#: plugins/train/_config.py:471 #: plugins/train/_config.py:496
msgid "" msgid ""
"The loss function to use when learning a mask.\n" "The loss function to use when learning a mask.\n"
"\t MAE - Mean absolute error will guide reconstructions of each pixel " "\t MAE - Mean absolute error will guide reconstructions of each pixel "
@ -711,7 +746,7 @@ msgstr ""
"данных. Как среднее значение, оно чувствительно к выбросам и обычно дает " "данных. Как среднее значение, оно чувствительно к выбросам и обычно дает "
"немного более размытые результаты." "немного более размытые результаты."
#: plugins/train/_config.py:488 #: plugins/train/_config.py:513
msgid "" msgid ""
"The amount of priority to give to the eyes.\n" "The amount of priority to give to the eyes.\n"
"\n" "\n"
@ -731,7 +766,7 @@ msgstr ""
"\n" "\n"
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." "NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
#: plugins/train/_config.py:504 #: plugins/train/_config.py:529
msgid "" msgid ""
"The amount of priority to give to the mouth.\n" "The amount of priority to give to the mouth.\n"
"\n" "\n"
@ -751,7 +786,7 @@ msgstr ""
"\n" "\n"
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." "NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
#: plugins/train/_config.py:517 #: plugins/train/_config.py:542
msgid "" msgid ""
"Image loss function is weighted by mask presence. For areas of the image " "Image loss function is weighted by mask presence. For areas of the image "
"without the facial mask, reconstruction errors will be ignored while the " "without the facial mask, reconstruction errors will be ignored while the "
@ -763,12 +798,12 @@ msgstr ""
"время как область лица с маской является приоритетной. Может повысить общее " "время как область лица с маской является приоритетной. Может повысить общее "
"качество за счет концентрации внимания на основной области лица." "качество за счет концентрации внимания на основной области лица."
#: plugins/train/_config.py:528 plugins/train/_config.py:570 #: plugins/train/_config.py:553 plugins/train/_config.py:595
#: plugins/train/_config.py:584 plugins/train/_config.py:593 #: plugins/train/_config.py:609 plugins/train/_config.py:618
msgid "mask" msgid "mask"
msgstr "маска" msgstr "маска"
#: plugins/train/_config.py:531 #: plugins/train/_config.py:556
msgid "" msgid ""
"The mask to be used for training. If you have selected 'Learn Mask' or " "The mask to be used for training. If you have selected 'Learn Mask' or "
"'Penalized Mask Loss' you must select a value other than 'none'. The " "'Penalized Mask Loss' you must select a value other than 'none'. The "
@ -840,7 +875,7 @@ msgstr ""
"сообщества и для дальнейшего описания нуждается в тестировании. Профильные " "сообщества и для дальнейшего описания нуждается в тестировании. Профильные "
"лица могут иметь низкую производительность." "лица могут иметь низкую производительность."
#: plugins/train/_config.py:572 #: plugins/train/_config.py:597
msgid "" msgid ""
"Apply gaussian blur to the mask input. This has the effect of smoothing the " "Apply gaussian blur to the mask input. This has the effect of smoothing the "
"edges of the mask, which can help with poorly calculated masks and give less " "edges of the mask, which can help with poorly calculated masks and give less "
@ -856,7 +891,7 @@ msgstr ""
"должно быть нечетным, если передано четное число, то оно будет округлено до " "должно быть нечетным, если передано четное число, то оно будет округлено до "
"следующего нечетного числа." "следующего нечетного числа."
#: plugins/train/_config.py:586 #: plugins/train/_config.py:611
msgid "" msgid ""
"Sets pixels that are near white to white and near black to black. Set to 0 " "Sets pixels that are near white to white and near black to black. Set to 0 "
"for off." "for off."
@ -864,7 +899,7 @@ msgstr ""
"Устанавливает пиксели, которые почти белые - в белые и которые почти черные " "Устанавливает пиксели, которые почти белые - в белые и которые почти черные "
"- в черные. Установите 0, чтобы выключить." "- в черные. Установите 0, чтобы выключить."
#: plugins/train/_config.py:595 #: plugins/train/_config.py:620
msgid "" msgid ""
"Dedicate a portion of the model to learning how to duplicate the input mask. " "Dedicate a portion of the model to learning how to duplicate the input mask. "
"Increases VRAM usage in exchange for learning a quick ability to try to " "Increases VRAM usage in exchange for learning a quick ability to try to "

View file

@ -1,8 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Plugin to blend the edges of the face between the swap and the original face. """ """ Plugin to blend the edges of the face between the swap and the original face. """
import logging import logging
import sys import typing as T
from typing import List, Optional, Tuple
import cv2 import cv2
import numpy as np import numpy as np
@ -11,12 +10,6 @@ from lib.align import BlurMask, DetectedFace
from lib.config import FaceswapConfig from lib.config import FaceswapConfig
from plugins.convert._config import Config from plugins.convert._config import Config
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
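This removal is the pattern repeated across the commit: with the minimum Python raised to 3.10, the version-gated typing_extensions fallback collapses into a plain `import typing as T`, and PEP 604 unions replace Optional/Union. A before/after sketch with a hypothetical function:

    # Before (Python < 3.8 fallback):
    #   if sys.version_info < (3, 8):
    #       from typing_extensions import Literal
    #   else:
    #       from typing import Literal
    #   def fn(mode: Optional[Literal["min", "max"]]) -> None: ...

    # After (Python >= 3.10):
    import typing as T

    def fn(mode: T.Literal["min", "max"] | None) -> None:
        ...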
@ -44,8 +37,8 @@ class Mask(): # pylint:disable=too-few-public-methods
mask_type: str, mask_type: str,
output_size: int, output_size: int,
coverage_ratio: float, coverage_ratio: float,
configfile: Optional[str] = None, configfile: str | None = None,
config: Optional[FaceswapConfig] = None) -> None: config: FaceswapConfig | None = None) -> None:
logger.debug("Initializing %s: (mask_type: '%s', output_size: %s, coverage_ratio: %s, " logger.debug("Initializing %s: (mask_type: '%s', output_size: %s, coverage_ratio: %s, "
"configfile: %s, config: %s)", self.__class__.__name__, mask_type, "configfile: %s, config: %s)", self.__class__.__name__, mask_type,
coverage_ratio, output_size, configfile, config) coverage_ratio, output_size, configfile, config)
@ -61,8 +54,8 @@ class Mask(): # pylint:disable=too-few-public-methods
self._do_erode = any(amount != 0 for amount in self._erodes) self._do_erode = any(amount != 0 for amount in self._erodes)
def _set_config(self, def _set_config(self,
configfile: Optional[str], configfile: str | None,
config: Optional[FaceswapConfig]) -> dict: config: FaceswapConfig | None) -> dict:
""" Set the correct configuration for the plugin based on whether a config file """ Set the correct configuration for the plugin based on whether a config file
or pre-loaded config has been passed in. or pre-loaded config has been passed in.
@ -123,8 +116,8 @@ class Mask(): # pylint:disable=too-few-public-methods
detected_face: DetectedFace, detected_face: DetectedFace,
source_offset: np.ndarray, source_offset: np.ndarray,
target_offset: np.ndarray, target_offset: np.ndarray,
centering: Literal["legacy", "face", "head"], centering: T.Literal["legacy", "face", "head"],
predicted_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]: predicted_mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
""" Obtain the requested mask type and perform any defined mask manipulations. """ Obtain the requested mask type and perform any defined mask manipulations.
Parameters Parameters
@ -171,8 +164,8 @@ class Mask(): # pylint:disable=too-few-public-methods
def _get_mask(self, def _get_mask(self,
detected_face: DetectedFace, detected_face: DetectedFace,
predicted_mask: Optional[np.ndarray], predicted_mask: np.ndarray | None,
centering: Literal["legacy", "face", "head"], centering: T.Literal["legacy", "face", "head"],
source_offset: np.ndarray, source_offset: np.ndarray,
target_offset: np.ndarray) -> np.ndarray: target_offset: np.ndarray) -> np.ndarray:
""" Return the requested mask with any requested blurring applied. """ Return the requested mask with any requested blurring applied.
@ -229,7 +222,7 @@ class Mask(): # pylint:disable=too-few-public-methods
def _get_stored_mask(self, def _get_stored_mask(self,
detected_face: DetectedFace, detected_face: DetectedFace,
centering: Literal["legacy", "face", "head"], centering: T.Literal["legacy", "face", "head"],
source_offset: np.ndarray, source_offset: np.ndarray,
target_offset: np.ndarray) -> np.ndarray: target_offset: np.ndarray) -> np.ndarray:
""" get the requested stored mask from the detected face object. """ get the requested stored mask from the detected face object.
@ -303,7 +296,7 @@ class Mask(): # pylint:disable=too-few-public-methods
return eroded[..., None] return eroded[..., None]
def _get_erosion_kernels(self, mask: np.ndarray) -> List[np.ndarray]: def _get_erosion_kernels(self, mask: np.ndarray) -> list[np.ndarray]:
""" Get the erosion kernels for each of the center, left, top right and bottom erosions. """ Get the erosion kernels for each of the center, left, top right and bottom erosions.
An approximation is made based on the number of positive pixels within the mask to create An approximation is made based on the number of positive pixels within the mask to create

View file

@ -4,8 +4,7 @@
import logging import logging
import os import os
import re import re
import typing as T
from typing import Any, List, Optional
import numpy as np import numpy as np
@ -14,7 +13,7 @@ from plugins.convert._config import Config
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_config(plugin_name: str, configfile: Optional[str] = None) -> dict: def get_config(plugin_name: str, configfile: str | None = None) -> dict:
""" Obtain the configuration settings for the writer plugin. """ Obtain the configuration settings for the writer plugin.
Parameters Parameters
@ -44,7 +43,7 @@ class Output():
The full path to a custom configuration ini file. If ``None`` is passed The full path to a custom configuration ini file. If ``None`` is passed
then the file is loaded from the default location. Default: ``None``. then the file is loaded from the default location. Default: ``None``.
""" """
def __init__(self, output_folder: str, configfile: Optional[str] = None) -> None: def __init__(self, output_folder: str, configfile: str | None = None) -> None:
logger.debug("Initializing %s: (output_folder: '%s')", logger.debug("Initializing %s: (output_folder: '%s')",
self.__class__.__name__, output_folder) self.__class__.__name__, output_folder)
self.config: dict = get_config(".".join(self.__module__.split(".")[-2:]), self.config: dict = get_config(".".join(self.__module__.split(".")[-2:]),
@ -69,7 +68,7 @@ class Output():
retval = hasattr(self, "frame_order") retval = hasattr(self, "frame_order")
return retval return retval
def output_filename(self, filename: str, separate_mask: bool = False) -> List[str]: def output_filename(self, filename: str, separate_mask: bool = False) -> list[str]:
""" Obtain the full path for the output file, including the correct extension, for the """ Obtain the full path for the output file, including the correct extension, for the
given input filename. given input filename.
@ -124,7 +123,7 @@ class Output():
logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore
logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore
def write(self, filename: str, image: Any) -> None: def write(self, filename: str, image: T.Any) -> None:
""" Override for specific frame writing method. """ Override for specific frame writing method.
Parameters Parameters
@ -137,7 +136,7 @@ class Output():
""" """
raise NotImplementedError raise NotImplementedError
def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument def pre_encode(self, image: np.ndarray) -> T.Any: # pylint: disable=unused-argument
""" Some writer plugins support the pre-encoding of images prior to saving out. As """ Some writer plugins support the pre-encoding of images prior to saving out. As
patching is done in multiple threads, but writing is done in a single thread, it can patching is done in multiple threads, but writing is done in a single thread, it can
speed up the process to do any pre-encoding as part of the converter process. speed up the process to do any pre-encoding as part of the converter process.

View file

@ -1,9 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Video output writer for faceswap.py converter """ """ Video output writer for faceswap.py converter """
from __future__ import annotations
import os import os
import typing as T
from math import ceil from math import ceil
from subprocess import CalledProcessError, check_output, STDOUT from subprocess import CalledProcessError, check_output, STDOUT
from typing import cast, Generator, List, Optional, Tuple
import imageio import imageio
import imageio_ffmpeg as im_ffm import imageio_ffmpeg as im_ffm
@ -11,6 +13,9 @@ import numpy as np
from ._base import Output, logger from ._base import Output, logger
if T.TYPE_CHECKING:
from collections.abc import Generator
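The `from __future__ import annotations` plus `T.TYPE_CHECKING` combination seen here is what lets names such as `Generator` be imported for the type checker only: the future import turns every annotation into a lazily evaluated string, so the name never has to exist at runtime. A self-contained sketch of the idiom:

    from __future__ import annotations  # annotations become lazy strings

    import typing as T

    if T.TYPE_CHECKING:  # seen by the type checker, skipped at runtime
        from collections.abc import Generator

    def frames() -> Generator[int, None, None]:  # never evaluated at runtime
        yield 0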
class Writer(Output): class Writer(Output):
""" Video output writer using imageio-ffmpeg. """ Video output writer using imageio-ffmpeg.
@ -32,7 +37,7 @@ class Writer(Output):
def __init__(self, def __init__(self,
output_folder: str, output_folder: str,
total_count: int, total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]], frame_ranges: list[tuple[int, int]] | None,
source_video: str, source_video: str,
**kwargs) -> None: **kwargs) -> None:
super().__init__(output_folder, **kwargs) super().__init__(output_folder, **kwargs)
@ -40,11 +45,11 @@ class Writer(Output):
total_count, frame_ranges, source_video) total_count, frame_ranges, source_video)
self._source_video: str = source_video self._source_video: str = source_video
self._output_filename: str = self._get_output_filename() self._output_filename: str = self._get_output_filename()
self._frame_ranges: Optional[List[Tuple[int, int]]] = frame_ranges self._frame_ranges: list[tuple[int, int]] | None = frame_ranges
self.frame_order: List[int] = self._set_frame_order(total_count) self.frame_order: list[int] = self._set_frame_order(total_count)
self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame self._output_dimensions: str | None = None # Fix dims on 1st received frame
# Need to know dimensions of first frame, so set writer then # Need to know dimensions of first frame, so set writer then
self._writer: Optional[Generator[None, np.ndarray, None]] = None self._writer: Generator[None, np.ndarray, None] | None = None
@property @property
def _valid_tunes(self) -> dict: def _valid_tunes(self) -> dict:
@ -63,7 +68,7 @@ class Writer(Output):
return retval return retval
@property @property
def _output_params(self) -> List[str]: def _output_params(self) -> list[str]:
""" list: The FFMPEG Output parameters """ """ list: The FFMPEG Output parameters """
codec = self.config["codec"] codec = self.config["codec"]
tune = self.config["tune"] tune = self.config["tune"]
@ -86,11 +91,11 @@ class Writer(Output):
return output_args return output_args
@property @property
def _audio_codec(self) -> Optional[str]: def _audio_codec(self) -> str | None:
""" str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default) """ str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default)
or ``None`` if skip muxing has been selected in configuration options, or if frame ranges or ``None`` if skip muxing has been selected in configuration options, or if frame ranges
have been passed in the command line arguments. """ have been passed in the command line arguments. """
retval: Optional[str] = "copy" retval: str | None = "copy"
if self.config["skip_mux"]: if self.config["skip_mux"]:
logger.info("Skipping audio muxing due to configuration settings.") logger.info("Skipping audio muxing due to configuration settings.")
retval = None retval = None
@ -169,7 +174,7 @@ class Writer(Output):
logger.info("Outputting to: '%s'", retval) logger.info("Outputting to: '%s'", retval)
return retval return retval
def _set_frame_order(self, total_count: int) -> List[int]: def _set_frame_order(self, total_count: int) -> list[int]:
""" Obtain the full list of frames to be converted in order. """ Obtain the full list of frames to be converted in order.
Parameters Parameters
@ -191,7 +196,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval) logger.debug("frame_order: %s", retval)
return retval return retval
def _get_writer(self, frame_dims: Tuple[int, int]) -> Generator[None, np.ndarray, None]: def _get_writer(self, frame_dims: tuple[int, int]) -> Generator[None, np.ndarray, None]:
""" Add the requested encoding options and return the writer. """ Add the requested encoding options and return the writer.
Parameters Parameters
@ -238,13 +243,13 @@ class Writer(Output):
logger.trace("Received frame: (filename: '%s', shape: %s", # type:ignore[attr-defined] logger.trace("Received frame: (filename: '%s', shape: %s", # type:ignore[attr-defined]
filename, image.shape) filename, image.shape)
if not self._output_dimensions: if not self._output_dimensions:
input_dims = cast(Tuple[int, int], image.shape[:2]) input_dims = T.cast(tuple[int, int], image.shape[:2])
self._set_dimensions(input_dims) self._set_dimensions(input_dims)
self._writer = self._get_writer(input_dims) self._writer = self._get_writer(input_dims)
self.cache_frame(filename, image) self.cache_frame(filename, image)
self._save_from_cache() self._save_from_cache()
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None: def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received. """ Set the attribute :attr:`_output_dimensions` based on the first frame received.
This protects against different sized images coming in and ensures all images are written This protects against different sized images coming in and ensures all images are written
to ffmpeg at the same size. Dimensions are mapped to a macro block size of 8. to ffmpeg at the same size. Dimensions are mapped to a macro block size of 8.

View file

@ -1,14 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Animated GIF writer for faceswap.py converter """ """ Animated GIF writer for faceswap.py converter """
from __future__ import annotations
import os import os
from typing import Optional, List, Tuple, TYPE_CHECKING import typing as T
import cv2 import cv2
import imageio import imageio
from ._base import Output, logger from ._base import Output, logger
if TYPE_CHECKING: if T.TYPE_CHECKING:
from imageio.core import format as im_format # noqa:F401 from imageio.core import format as im_format # noqa:F401
@ -31,15 +32,16 @@ class Writer(Output):
def __init__(self, def __init__(self,
output_folder: str, output_folder: str,
total_count: int, total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]], frame_ranges: list[tuple[int, int]] | None,
**kwargs) -> None: **kwargs) -> None:
logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges) logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges)
super().__init__(output_folder, **kwargs) super().__init__(output_folder, **kwargs)
self.frame_order: List[int] = self._set_frame_order(total_count, frame_ranges) self.frame_order: list[int] = self._set_frame_order(total_count, frame_ranges)
self._output_dimensions: Optional[Tuple[int, int]] = None # Fix dims on 1st received frame # Fix dims on 1st received frame
self._output_dimensions: tuple[int, int] | None = None
# Need to know dimensions of first frame, so set writer then # Need to know dimensions of first frame, so set writer then
self._writer: Optional[imageio.plugins.pillowmulti.GIFFormat.Writer] = None self._writer: imageio.plugins.pillowmulti.GIFFormat.Writer | None = None
self._gif_file: Optional[str] = None # Set filename based on first file seen self._gif_file: str | None = None # Set filename based on first file seen
@property @property
def _gif_params(self) -> dict: def _gif_params(self) -> dict:
@ -50,7 +52,7 @@ class Writer(Output):
@staticmethod @staticmethod
def _set_frame_order(total_count: int, def _set_frame_order(total_count: int,
frame_ranges: Optional[List[Tuple[int, int]]]) -> List[int]: frame_ranges: list[tuple[int, int]] | None) -> list[int]:
""" Obtain the full list of frames to be converted in order. """ Obtain the full list of frames to be converted in order.
Parameters Parameters
@ -75,7 +77,7 @@ class Writer(Output):
logger.debug("frame_order: %s", retval) logger.debug("frame_order: %s", retval)
return retval return retval
def _get_writer(self) -> "im_format.Format.Writer": def _get_writer(self) -> im_format.Format.Writer:
""" Obtain the GIF writer with the requested GIF encoding options. """ Obtain the GIF writer with the requested GIF encoding options.
Returns Returns
@ -145,7 +147,7 @@ class Writer(Output):
self._gif_file = retval self._gif_file = retval
logger.info("Outputting to: '%s'", self._gif_file) logger.info("Outputting to: '%s'", self._gif_file)
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None: def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
""" Set the attribute :attr:`_output_dimensions` based on the first frame received. This """ Set the attribute :attr:`_output_dimensions` based on the first frame received. This
protects against different sized images coming in and ensures all images get written to the protects against different sized images coming in and ensures all images get written to the
Gif at the same dimensions. """ Gif at the same dimensions. """

View file

@ -2,8 +2,6 @@
""" Image output writer for faceswap.py converter """ Image output writer for faceswap.py converter
Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO
""" """
from typing import List, Tuple
import cv2 import cv2
import numpy as np import numpy as np
@ -37,7 +35,7 @@ class Writer(Output):
"transparency. Changing output format to 'png'") "transparency. Changing output format to 'png'")
self.config["format"] = "png" self.config["format"] = "png"
def _get_save_args(self) -> Tuple[int, ...]: def _get_save_args(self) -> tuple[int, ...]:
""" Obtain the save parameters for the file format. """ Obtain the save parameters for the file format.
Returns Returns
@ -46,7 +44,7 @@ class Writer(Output):
The OpenCV specific arguments for the selected file format The OpenCV specific arguments for the selected file format
""" """
filetype = self.config["format"] filetype = self.config["format"]
args: Tuple[int, ...] = tuple() args: tuple[int, ...] = tuple()
if filetype == "jpg" and self.config["jpg_quality"] > 0: if filetype == "jpg" and self.config["jpg_quality"] > 0:
args = (cv2.IMWRITE_JPEG_QUALITY, args = (cv2.IMWRITE_JPEG_QUALITY,
self.config["jpg_quality"]) self.config["jpg_quality"])
@ -56,7 +54,7 @@ class Writer(Output):
logger.debug(args) logger.debug(args)
return args return args
def write(self, filename: str, image: List[bytes]) -> None: def write(self, filename: str, image: list[bytes]) -> None:
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out """ Write out the pre-encoded image to disk. If separate mask has been selected, write out
the encoded mask to a sub-folder in the output directory. the encoded mask to a sub-folder in the output directory.
@ -77,7 +75,7 @@ class Writer(Output):
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
logger.error("Failed to save image '%s'. Original Error: %s", filename, err) logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
def pre_encode(self, image: np.ndarray) -> List[bytes]: def pre_encode(self, image: np.ndarray) -> list[bytes]:
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker. """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker.
Parameters Parameters

View file

@ -1,7 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Image output writer for faceswap.py converter """ """ Image output writer for faceswap.py converter """
from typing import Dict, List, Union
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
@ -25,7 +23,7 @@ class Writer(Output):
super().__init__(output_folder, **kwargs) super().__init__(output_folder, **kwargs)
self._check_transparency_format() self._check_transparency_format()
# Correct format namings for writing to byte stream # Correct format namings for writing to byte stream
self._format_dict = dict(jpg="JPEG", jp2="JPEG 2000", tif="TIFF") self._format_dict = {"jpg": "JPEG", "jp2": "JPEG 2000", "tif": "TIFF"}
self._separate_mask = self.config["draw_transparent"] and self.config["separate_mask"] self._separate_mask = self.config["draw_transparent"] and self.config["separate_mask"]
self._kwargs = self._get_save_kwargs() self._kwargs = self._get_save_kwargs()
@ -38,7 +36,7 @@ class Writer(Output):
"transparency. Changing output format to 'png'") "transparency. Changing output format to 'png'")
self.config["format"] = "png" self.config["format"] = "png"
def _get_save_kwargs(self) -> Dict[str, Union[bool, int, str]]: def _get_save_kwargs(self) -> dict[str, bool | int | str]:
""" Return the save parameters for the file format """ Return the save parameters for the file format
Returns Returns
@ -59,7 +57,7 @@ class Writer(Output):
logger.debug(kwargs) logger.debug(kwargs)
return kwargs return kwargs
def write(self, filename: str, image: List[BytesIO]) -> None: def write(self, filename: str, image: list[BytesIO]) -> None:
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out """ Write out the pre-encoded image to disk. If separate mask has been selected, write out
the encoded mask to a sub-folder in the output directory. the encoded mask to a sub-folder in the output directory.
@ -80,7 +78,7 @@ class Writer(Output):
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
logger.error("Failed to save image '%s'. Original Error: %s", filename, err) logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
def pre_encode(self, image: np.ndarray) -> List[BytesIO]: def pre_encode(self, image: np.ndarray) -> list[BytesIO]:
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker
Parameters Parameters

View file

@ -2,12 +2,11 @@
""" Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and """ Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and
:mod:`~plugins.extract.mask` Plugins :mod:`~plugins.extract.mask` Plugins
""" """
from __future__ import annotations
import logging import logging
import sys import typing as T
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Generator, List, Optional,
Sequence, Union, Tuple, TYPE_CHECKING)
import numpy as np import numpy as np
from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa
@ -18,12 +17,8 @@ from lib.utils import GetModel, FaceswapError
from ._config import Config from ._config import Config
from .pipeline import ExtractMedia from .pipeline import ExtractMedia
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal from collections.abc import Callable, Generator, Sequence
else:
from typing import Literal
if TYPE_CHECKING:
from queue import Queue from queue import Queue
import cv2 import cv2
from lib.align import DetectedFace from lib.align import DetectedFace
@ -37,7 +32,7 @@ logger = logging.getLogger(__name__)
# TODO Run with warnings mode # TODO Run with warnings mode
def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str, Any]: def _get_config(plugin_name: str, configfile: str | None = None) -> dict[str, T.Any]:
""" Return the configuration for the requested model """ Return the configuration for the requested model
Parameters Parameters
@ -56,7 +51,7 @@ def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str,
return Config(plugin_name, configfile=configfile).config_dict return Config(plugin_name, configfile=configfile).config_dict
BatchType = Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"] BatchType = T.Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"]
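Worth noting why this alias keeps `T.Union` while the rest of the commit switches to `|`: the operands are forward-reference strings, and this is a runtime assignment rather than an annotation, so the future import cannot defer its evaluation.

    # BatchType = "DetectorBatch" | "AlignerBatch"   # TypeError at runtime:
    #                                                # str does not support |
    # T.Union accepts string forward references and resolves them later:
    # BatchType = T.Union["DetectorBatch", "AlignerBatch"]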
@dataclass @dataclass
@ -84,13 +79,12 @@ class ExtractorBatch:
data: dict data: dict
Any specific data required during the processing phase for a particular plugin Any specific data required during the processing phase for a particular plugin
""" """
image: List[np.ndarray] = field(default_factory=list) image: list[np.ndarray] = field(default_factory=list)
detected_faces: Sequence[Union["DetectedFace", detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list)
List["DetectedFace"]]] = field(default_factory=list) filename: list[str] = field(default_factory=list)
filename: List[str] = field(default_factory=list)
feed: np.ndarray = np.array([]) feed: np.ndarray = np.array([])
prediction: np.ndarray = np.array([]) prediction: np.ndarray = np.array([])
data: List[Dict[str, Any]] = field(default_factory=list) data: list[dict[str, T.Any]] = field(default_factory=list)
class Extractor(): class Extractor():
@ -157,10 +151,10 @@ class Extractor():
""" """
def __init__(self, def __init__(self,
git_model_id: Optional[int] = None, git_model_id: int | None = None,
model_filename: Optional[Union[str, List[str]]] = None, model_filename: str | list[str] | None = None,
exclude_gpus: Optional[List[int]] = None, exclude_gpus: list[int] | None = None,
configfile: Optional[str] = None, configfile: str | None = None,
instance: int = 0) -> None: instance: int = 0) -> None:
logger.debug("Initializing %s: (git_model_id: %s, model_filename: %s, exclude_gpus: %s, " logger.debug("Initializing %s: (git_model_id: %s, model_filename: %s, exclude_gpus: %s, "
"configfile: %s, instance: %s, )", self.__class__.__name__, git_model_id, "configfile: %s, instance: %s, )", self.__class__.__name__, git_model_id,
@ -176,9 +170,9 @@ class Extractor():
be a list of strings """ be a list of strings """
# << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> # # << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> #
self.name: Optional[str] = None self.name: str | None = None
self.input_size = 0 self.input_size = 0
self.color_format: Literal["BGR", "RGB", "GRAY"] = "BGR" self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR"
self.vram = 0 self.vram = 0
self.vram_warnings = 0 # Will run at this with warnings self.vram_warnings = 0 # Will run at this with warnings
self.vram_per_batch = 0 self.vram_per_batch = 0
@ -187,7 +181,7 @@ class Extractor():
self.queue_size = 1 self.queue_size = 1
""" int: Queue size for all internal queues. Set in :func:`initialize()` """ """ int: Queue size for all internal queues. Set in :func:`initialize()` """
self.model: Optional[Union["KSession", "cv2.dnn.Net"]] = None self.model: KSession | cv2.dnn.Net | None = None
"""varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """ """varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """
# For detectors that support batching, this should be set to the calculated batch size # For detectors that support batching, this should be set to the calculated batch size
@ -196,26 +190,26 @@ class Extractor():
""" int: Batchsize for feeding this model. The number of images the model should """ int: Batchsize for feeding this model. The number of images the model should
feed through at once. """ feed through at once. """
self._queues: Dict[str, "Queue"] = {} self._queues: dict[str, Queue] = {}
""" dict: in + out queues and internal queues for this plugin, """ """ dict: in + out queues and internal queues for this plugin, """
self._threads: List[MultiThread] = [] self._threads: list[MultiThread] = []
""" list: Internal threads for this plugin """ """ list: Internal threads for this plugin """
self._extract_media: Dict[str, ExtractMedia] = {} self._extract_media: dict[str, ExtractMedia] = {}
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being """ dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
processed. Stored at input for pairing back up on output of extractor process """ processed. Stored at input for pairing back up on output of extractor process """
# << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >>> # # << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >>> #
self._plugin_type: Optional[Literal["align", "detect", "recognition", "mask"]] = None self._plugin_type: T.Literal["align", "detect", "recognition", "mask"] | None = None
""" str: Plugin type. ``detect`, ``align``, ``recognise`` or ``mask`` set in """ str: Plugin type. ``detect`, ``align``, ``recognise`` or ``mask`` set in
``<plugin_type>._base`` """ ``<plugin_type>._base`` """
# << Objects for splitting frame's detected faces and rejoining them >> # << Objects for splitting frame's detected faces and rejoining them >>
# << for post-detector plugins >> # << for post-detector plugins >>
self._faces_per_filename: Dict[str, int] = {} # Tracking for recompiling batches self._faces_per_filename: dict[str, int] = {} # Tracking for recompiling batches
self._rollover: Optional[ExtractMedia] = None # batch rollover items self._rollover: ExtractMedia | None = None # batch rollover items
self._output_faces: List["DetectedFace"] = [] # Recompiled output faces from plugin self._output_faces: list[DetectedFace] = [] # Recompiled output faces from plugin
logger.debug("Initialized _base %s", self.__class__.__name__) logger.debug("Initialized _base %s", self.__class__.__name__)
@ -361,7 +355,7 @@ class Extractor():
""" """
raise NotImplementedError raise NotImplementedError
def get_batch(self, queue: "Queue") -> Tuple[bool, BatchType]: def get_batch(self, queue: Queue) -> tuple[bool, BatchType]:
""" **Override method** (at `<plugin_type>` level) """ **Override method** (at `<plugin_type>` level)
This method should be overridden at the `<plugin_type>` level (IE. This method should be overridden at the `<plugin_type>` level (IE.
@ -403,7 +397,7 @@ class Extractor():
for thread in self._threads: for thread in self._threads:
thread.check_and_raise_error() thread.check_and_raise_error()
def rollover_collector(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia]: def rollover_collector(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia:
""" For extractors after the Detectors, the number of detected faces per frame vs extractor """ For extractors after the Detectors, the number of detected faces per frame vs extractor
batch size mean that faces will need to be split/re-joined with frames. The rollover batch size mean that faces will need to be split/re-joined with frames. The rollover
collector can be used to rollover items that don't fit in a batch. collector can be used to rollover items that don't fit in a batch.
@ -425,7 +419,7 @@ class Extractor():
if self._rollover is not None: if self._rollover is not None:
logger.trace("Getting from _rollover: (filename: `%s`, faces: %s)", # type:ignore logger.trace("Getting from _rollover: (filename: `%s`, faces: %s)", # type:ignore
self._rollover.filename, len(self._rollover.detected_faces)) self._rollover.filename, len(self._rollover.detected_faces))
item: Union[Literal["EOF"], ExtractMedia] = self._rollover item: T.Literal["EOF"] | ExtractMedia = self._rollover
self._rollover = None self._rollover = None
else: else:
next_item = self._get_item(queue) next_item = self._get_item(queue)
@ -442,9 +436,8 @@ class Extractor():
# <<< INIT METHODS >>> # # <<< INIT METHODS >>> #
@classmethod @classmethod
def _get_model(cls, def _get_model(cls,
git_model_id: Optional[int], git_model_id: int | None,
model_filename: Optional[Union[str, List[str]]] model_filename: str | list[str] | None) -> str | list[str] | None:
) -> Optional[Union[str, List[str]]]:
""" Check if model is available, if not, download and unzip it """ """ Check if model is available, if not, download and unzip it """
if model_filename is None: if model_filename is None:
logger.debug("No model_filename specified. Returning None") logger.debug("No model_filename specified. Returning None")
@ -496,9 +489,9 @@ class Extractor():
self.name, self._plugin_type.title(), self.batchsize) self.name, self._plugin_type.title(), self.batchsize)
def _add_queues(self, def _add_queues(self,
in_queue: "Queue", in_queue: Queue,
out_queue: "Queue", out_queue: Queue,
queues: List[str]) -> None: queues: list[str]) -> None:
""" Add the queues """ Add the queues
in_queue and out_queue should be previously created queue manager queues. in_queue and out_queue should be previously created queue manager queues.
queues should be a list of queue names """ queues should be a list of queue names """
@ -533,8 +526,8 @@ class Extractor():
def _add_thread(self, def _add_thread(self,
name: str, name: str,
function: Callable[[BatchType], BatchType], function: Callable[[BatchType], BatchType],
in_queue: "Queue", in_queue: Queue,
out_queue: "Queue") -> None: out_queue: Queue) -> None:
""" Add a MultiThread thread to self._threads """ """ Add a MultiThread thread to self._threads """
logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)", logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)",
name, function, in_queue, out_queue) name, function, in_queue, out_queue)
@ -546,8 +539,8 @@ class Extractor():
logger.debug("Added thread: %s", name) logger.debug("Added thread: %s", name)
def _obtain_batch_item(self, function: Callable[[BatchType], BatchType], def _obtain_batch_item(self, function: Callable[[BatchType], BatchType],
in_queue: "Queue", in_queue: Queue,
out_queue: "Queue") -> Optional[BatchType]: out_queue: Queue) -> BatchType | None:
""" Obtain the batch item from the in queue for the current process. """ Obtain the batch item from the in queue for the current process.
Parameters Parameters
@ -564,7 +557,7 @@ class Extractor():
:class:`ExtractorBatch` or ``None`` :class:`ExtractorBatch` or ``None``
The batch, if one exists, or ``None`` if queue is exhausted The batch, if one exists, or ``None`` if queue is exhausted
""" """
batch: Union[Literal["EOF"], BatchType, ExtractMedia] batch: T.Literal["EOF"] | BatchType | ExtractMedia
if function.__name__ == "_process_input": # Process input items to batches if function.__name__ == "_process_input": # Process input items to batches
exhausted, batch = self.get_batch(in_queue) exhausted, batch = self.get_batch(in_queue)
if exhausted: if exhausted:
@ -585,8 +578,8 @@ class Extractor():
def _thread_process(self, def _thread_process(self,
function: Callable[[BatchType], BatchType], function: Callable[[BatchType], BatchType],
in_queue: "Queue", in_queue: Queue,
out_queue: "Queue") -> None: out_queue: Queue) -> None:
""" Perform a plugin function in a thread """ Perform a plugin function in a thread
Parameters Parameters
@ -629,7 +622,7 @@ class Extractor():
out_queue.put("EOF") out_queue.put("EOF")
# <<< QUEUE METHODS >>> # # <<< QUEUE METHODS >>> #
def _get_item(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia, BatchType]: def _get_item(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia | BatchType:
""" Yield one item from a queue """ """ Yield one item from a queue """
item = queue.get() item = queue.get()
if isinstance(item, ExtractMedia): if isinstance(item, ExtractMedia):

View file

@ -12,12 +12,12 @@ For each source item, the plugin must pass a dict to finalize containing:
>>> "landmarks": [list of 68 point face landmarks] >>> "landmarks": [list of 68 point face landmarks]
>>> "detected_faces": [<list of DetectedFace objects>]} >>> "detected_faces": [<list of DetectedFace objects>]}
""" """
from __future__ import annotations
import logging import logging
import sys import typing as T
from dataclasses import dataclass, field from dataclasses import dataclass, field
from time import sleep from time import sleep
from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
@ -28,12 +28,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractMedia, ExtractorBatch from plugins.extract._base import BatchType, Extractor, ExtractMedia, ExtractorBatch
from .processing import AlignedFilter, ReAlign from .processing import AlignedFilter, ReAlign
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal from collections.abc import Generator
else:
from typing import Literal
if TYPE_CHECKING:
from queue import Queue from queue import Queue
from lib.align import DetectedFace from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType from lib.align.aligned_face import CenteringType
@ -77,9 +73,9 @@ class AlignerBatch(ExtractorBatch):
The masks used to filter out re-feed values for passing to the re-aligner. The masks used to filter out re-feed values for passing to the re-aligner.
""" """
batch_id: int = 0 batch_id: int = 0
detected_faces: List["DetectedFace"] = field(default_factory=list) detected_faces: list[DetectedFace] = field(default_factory=list)
landmarks: np.ndarray = np.array([]) landmarks: np.ndarray = np.array([])
refeeds: List[np.ndarray] = field(default_factory=list) refeeds: list[np.ndarray] = field(default_factory=list)
second_pass: bool = False second_pass: bool = False
second_pass_masks: np.ndarray = np.array([]) second_pass_masks: np.ndarray = np.array([])
@ -142,11 +138,11 @@ class Aligner(Extractor): # pylint:disable=abstract-method
""" """
def __init__(self, def __init__(self,
git_model_id: Optional[int] = None, git_model_id: int | None = None,
model_filename: Optional[str] = None, model_filename: str | None = None,
configfile: Optional[str] = None, configfile: str | None = None,
instance: int = 0, instance: int = 0,
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None, normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
re_feed: int = 0, re_feed: int = 0,
re_align: bool = False, re_align: bool = False,
disable_filter: bool = False, disable_filter: bool = False,
@ -160,9 +156,9 @@ class Aligner(Extractor): # pylint:disable=abstract-method
instance=instance, instance=instance,
**kwargs) **kwargs)
self._plugin_type = "align" self._plugin_type = "align"
self.realign_centering: "CenteringType" = "face" # override for plugin specific centering self.realign_centering: CenteringType = "face" # override for plugin specific centering
self._eof_seen = False self._eof_seen = False
self._normalize_method: Optional[Literal["clahe", "hist", "mean"]] = None self._normalize_method: T.Literal["clahe", "hist", "mean"] | None = None
self._re_feed = re_feed self._re_feed = re_feed
self._filter = AlignedFilter(feature_filter=self.config["aligner_features"], self._filter = AlignedFilter(feature_filter=self.config["aligner_features"],
min_scale=self.config["aligner_min_scale"], min_scale=self.config["aligner_min_scale"],
@ -181,8 +177,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
def set_normalize_method(self, def set_normalize_method(self, method: T.Literal["none", "clahe", "hist", "mean"] | None
method: Optional[Literal["none", "clahe", "hist", "mean"]]) -> None: ) -> None:
""" Set the normalization method for feeding faces into the aligner. """ Set the normalization method for feeding faces into the aligner.
Parameters Parameters
@ -191,14 +187,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method
The normalization method to apply to faces prior to feeding into the model The normalization method to apply to faces prior to feeding into the model
""" """
method = None if method is None or method.lower() == "none" else method method = None if method is None or method.lower() == "none" else method
self._normalize_method = cast(Optional[Literal["clahe", "hist", "mean"]], method) self._normalize_method = T.cast(T.Literal["clahe", "hist", "mean"] | None, method)
def initialize(self, *args, **kwargs) -> None: def initialize(self, *args, **kwargs) -> None:
""" Add a call to add model input size to the re-aligner """ """ Add a call to add model input size to the re-aligner """
self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering) self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering)
super().initialize(*args, **kwargs) super().initialize(*args, **kwargs)
def _handle_realigns(self, queue: "Queue") -> Optional[Tuple[bool, AlignerBatch]]: def _handle_realigns(self, queue: Queue) -> tuple[bool, AlignerBatch] | None:
""" Handle any items waiting for a second pass through the aligner. """ Handle any items waiting for a second pass through the aligner.
If EOF has been received and items are still being processed through the first pass If EOF has been received and items are still being processed through the first pass
@ -242,7 +238,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return None return None
def get_batch(self, queue: "Queue") -> Tuple[bool, AlignerBatch]: def get_batch(self, queue: Queue) -> tuple[bool, AlignerBatch]:
""" Get items for inputting into the aligner from the queue in batches """ Get items for inputting into the aligner from the queue in batches
Items are returned from the ``queue`` in batches of Items are returned from the ``queue`` in batches of
@ -548,7 +544,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
"\n3) Enable 'Single Process' mode.") "\n3) Enable 'Single Process' mode.")
raise FaceswapError(msg) from err raise FaceswapError(msg) from err
def _process_refeeds(self, batch: AlignerBatch) -> List[AlignerBatch]: def _process_refeeds(self, batch: AlignerBatch) -> list[AlignerBatch]:
""" Process the output for each selected re-feed """ Process the output for each selected re-feed
Parameters Parameters
@ -562,7 +558,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
List of :class:`AlignerBatch` objects. Each object in the list contains the List of :class:`AlignerBatch` objects. Each object in the list contains the
results for each selected re-feed results for each selected re-feed
""" """
retval: List[AlignerBatch] = [] retval: list[AlignerBatch] = []
if batch.second_pass: if batch.second_pass:
# Re-insert empty sub-patches for re-population in ReAlign for filtered out batches # Re-insert empty sub-patches for re-population in ReAlign for filtered out batches
selected_idx = 0 selected_idx = 0
@ -605,8 +601,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return retval return retval
def _get_refeed_filter_masks(self, def _get_refeed_filter_masks(self,
subbatches: List[AlignerBatch], subbatches: list[AlignerBatch],
original_masks: Optional[np.ndarray] = None) -> np.ndarray: original_masks: np.ndarray | None = None) -> np.ndarray:
""" Obtain the boolean mask array for masking out failed re-feed results if filter refeed """ Obtain the boolean mask array for masking out failed re-feed results if filter refeed
has been selected has been selected
@ -663,7 +659,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
landmarks.shape) landmarks.shape)
return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32") return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32")
def _process_output_first_pass(self, subbatches: List[AlignerBatch]) -> Tuple[np.ndarray, def _process_output_first_pass(self, subbatches: list[AlignerBatch]) -> tuple[np.ndarray,
np.ndarray]: np.ndarray]:
""" Process the output from the aligner if this is the first or only pass. """ Process the output from the aligner if this is the first or only pass.
@ -696,7 +692,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
return all_landmarks, masks return all_landmarks, masks
def _process_output_second_pass(self, def _process_output_second_pass(self,
subbatches: List[AlignerBatch], subbatches: list[AlignerBatch],
masks: np.ndarray) -> np.ndarray: masks: np.ndarray) -> np.ndarray:
""" Process the output from the aligner if this is the first or only pass. """ Process the output from the aligner if this is the first or only pass.

View file

@ -1,21 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Processing methods for aligner plugins """ """ Processing methods for aligner plugins """
from __future__ import annotations
import logging import logging
import sys import typing as T
from threading import Lock from threading import Lock
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
from lib.align import AlignedFace from lib.align import AlignedFace
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
from lib.align import DetectedFace from lib.align import DetectedFace
from .aligner import AlignerBatch from .aligner import AlignerBatch
from lib.align.aligned_face import CenteringType from lib.align.aligned_face import CenteringType
@@ -72,16 +67,16 @@ class AlignedFilter():
                         min_scale > 0.0 or
                         distance > 0.0 or
                         roll > 0.0)
-        self._counts: Dict[str, int] = dict(features=0,
-                                            min_scale=0,
-                                            max_scale=0,
-                                            distance=0,
-                                            roll=0)
+        self._counts: dict[str, int] = {"features": 0,
+                                        "min_scale": 0,
+                                        "max_scale": 0,
+                                        "distance": 0,
+                                        "roll": 0}
         logger.debug("Initialized %s: ", self.__class__.__name__)

     def _scale_test(self,
                     face: AlignedFace,
-                    minimum_dimension: int) -> Optional[Literal["min", "max"]]:
+                    minimum_dimension: int) -> T.Literal["min", "max"] | None:
         """ Test if a face is below or above the min/max size thresholds. Returns as soon as a test
         fails.
@@ -116,9 +111,9 @@ class AlignedFilter():
     def _handle_filtered(self,
                          key: str,
-                         face: "DetectedFace",
-                         faces: List["DetectedFace"],
-                         sub_folders: List[Optional[str]],
+                         face: DetectedFace,
+                         faces: list[DetectedFace],
+                         sub_folders: list[str | None],
                          sub_folder_index: int) -> None:
         """ Add the filtered item to the filter counts.
@@ -145,8 +140,8 @@ class AlignedFilter():
             faces.append(face)
             sub_folders[sub_folder_index] = f"_align_filt_{key}"

-    def __call__(self, faces: List["DetectedFace"], minimum_dimension: int
-                 ) -> Tuple[List["DetectedFace"], List[Optional[str]]]:
+    def __call__(self, faces: list[DetectedFace], minimum_dimension: int
+                 ) -> tuple[list[DetectedFace], list[str | None]]:
         """ Apply the filter to the incoming batch

         Parameters
@@ -165,11 +160,11 @@ class AlignedFilter():
             List of ``Nones`` if saving filtered faces has not been selected or list of ``Nones``
             and sub folder names corresponding the filtered face location
         """
-        sub_folders: List[Optional[str]] = [None for _ in range(len(faces))]
+        sub_folders: list[str | None] = [None for _ in range(len(faces))]
         if not self._active:
             return faces, sub_folders

-        retval: List["DetectedFace"] = []
+        retval: list[DetectedFace] = []
         for idx, face in enumerate(faces):
             aligned = AlignedFace(landmarks=face.landmarks_xy, centering="face")
@@ -194,8 +189,8 @@ class AlignedFilter():
         return retval, sub_folders

     def filtered_mask(self,
-                      batch: "AlignerBatch",
-                      skip: Optional[Union[np.ndarray, List[int]]] = None) -> np.ndarray:
+                      batch: AlignerBatch,
+                      skip: np.ndarray | list[int] | None = None) -> np.ndarray:
         """ Obtain a list of boolean values for the given batch indicating whether they pass the
         filter test.
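For reference, a minimal sketch of what a skip-aware boolean mask of this shape can look like. The semantics are assumed from the docstring and the names are illustrative, not the plugin's own code:

    import numpy as np

    def filtered_mask(passed: list[bool], skip: list[int] | None = None) -> np.ndarray:
        """ One boolean per batch item; indices listed in `skip` bypass the filter. """
        retval = np.array(passed, dtype=bool)
        if skip is not None and len(skip) > 0:
            retval[skip] = True  # skipped items are never reported as filtered
        return retval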
@@ -262,13 +257,14 @@ class ReAlign():
         self._active = active
         self._do_refeeds = do_refeeds
         self._do_filter = do_filter
-        self._centering: "CenteringType" = "face"
+        self._centering: CenteringType = "face"
         self._size = 0
         self._tracked_lock = Lock()
-        self._tracked_batchs: Dict[int, Dict[Literal["filtered_landmarks"], List[np.ndarray]]] = {}
+        self._tracked_batchs: dict[int,
+                                   dict[T.Literal["filtered_landmarks"], list[np.ndarray]]] = {}
         # TODO. Probably does not need to be a list, just alignerbatch
         self._queue_lock = Lock()
-        self._queued: List["AlignerBatch"] = []
+        self._queued: list[AlignerBatch] = []
         logger.debug("Initialized %s", self.__class__.__name__)

     @property
@@ -301,7 +297,7 @@ class ReAlign():
         with self._tracked_lock:
             return bool(self._tracked_batchs)

-    def set_input_size_and_centering(self, input_size: int, centering: "CenteringType") -> None:
+    def set_input_size_and_centering(self, input_size: int, centering: CenteringType) -> None:
         """ Set the input size of the loaded plugin once the model has been loaded

         Parameters
@@ -344,7 +340,7 @@ class ReAlign():
         with self._tracked_lock:
             del self._tracked_batchs[batch_id]

-    def add_batch(self, batch: "AlignerBatch") -> None:
+    def add_batch(self, batch: AlignerBatch) -> None:
         """ Add first pass alignments to the queue for picking up for re-alignment, update their
         :attr:`second_pass` attribute to ``True`` and clear attributes not required.
@@ -362,7 +358,7 @@ class ReAlign():
             batch.data = []
             self._queued.append(batch)

-    def get_batch(self) -> "AlignerBatch":
+    def get_batch(self) -> AlignerBatch:
         """ Retrieve the next batch currently queued for re-alignment

         Returns
@@ -376,7 +372,7 @@ class ReAlign():
                          retval.filename)
         return retval

-    def process_batch(self, batch: "AlignerBatch") -> List[np.ndarray]:
+    def process_batch(self, batch: AlignerBatch) -> list[np.ndarray]:
         """ Pre process a batch object for re-aligning through the aligner.

         Parameters
@@ -391,8 +387,8 @@ class ReAlign():
         """
         logger.trace("Processing batch: %s, landmarks: %s",  # type: ignore[attr-defined]
                      batch.filename, [b.shape for b in batch.landmarks])
-        retval: List[np.ndarray] = []
-        filtered_landmarks: List[np.ndarray] = []
+        retval: list[np.ndarray] = []
+        filtered_landmarks: list[np.ndarray] = []
         for landmarks, masks in zip(batch.landmarks, batch.second_pass_masks):
             if not np.all(masks):  # At least one face has not already been filtered
                 aligned_faces = [AlignedFace(lms,
@@ -415,7 +411,7 @@ class ReAlign():
         batch.landmarks = np.array([])  # Clear the old landmarks
         return retval

-    def _transform_to_frame(self, batch: "AlignerBatch") -> np.ndarray:
+    def _transform_to_frame(self, batch: AlignerBatch) -> np.ndarray:
         """ Transform the predicted landmarks from the aligned face image back into frame
         co-ordinates
@@ -430,14 +426,14 @@ class ReAlign():
         :class:`numpy.ndarray`
             The landmarks transformed to frame space
         """
-        faces: List[AlignedFace] = batch.data[0]["aligned_faces"]
+        faces: list[AlignedFace] = batch.data[0]["aligned_faces"]
         retval = np.array([aligned.transform_points(landmarks, invert=True)
                            for landmarks, aligned in zip(batch.landmarks, faces)])
         logger.trace("Transformed points: original max: %s, "  # type: ignore[attr-defined]
                      "new max: %s", batch.landmarks.max(), retval.max())
         return retval

-    def _re_insert_filtered(self, batch: "AlignerBatch", masks: np.ndarray) -> np.ndarray:
+    def _re_insert_filtered(self, batch: AlignerBatch, masks: np.ndarray) -> np.ndarray:
         """ Re-insert landmarks that were filtered out from the re-align process back into the
         landmark results
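The invert=True call above undoes the affine warp that produced the aligned face. The essential step, sketched with illustrative names (AlignedFace.transform_points handles the surrounding size bookkeeping itself):

    import cv2
    import numpy as np

    def to_frame_space(points: np.ndarray, matrix: np.ndarray) -> np.ndarray:
        """ Map (N, 2) aligned-face points back through a 2x3 affine matrix. """
        inverse = cv2.invertAffineTransform(matrix)  # invert the alignment matrix
        homogeneous = np.concatenate([points, np.ones((len(points), 1))], axis=-1)
        return homogeneous @ inverse.T  # back to frame co-ordinates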
@@ -473,7 +469,7 @@ class ReAlign():
         return retval

-    def process_output(self, subbatches: List["AlignerBatch"], batch_masks: np.ndarray) -> None:
+    def process_output(self, subbatches: list[AlignerBatch], batch_masks: np.ndarray) -> None:
         """ Process the output from the re-align pass.

         - Transform landmarks from aligned face space to face space


@@ -23,15 +23,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+from __future__ import annotations
 import logging
-from typing import cast, List, Tuple, TYPE_CHECKING
+import typing as T

 import cv2
 import numpy as np

 from ._base import Aligner, AlignerBatch, BatchType

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.detected_face import DetectedFace

 logger = logging.getLogger(__name__)
@@ -89,9 +90,9 @@ class Align(Aligner):
         assert isinstance(batch, AlignerBatch)
         lfaces, roi, offsets = self.align_image(batch)
         batch.feed = np.array(lfaces)[..., :3]
-        batch.data.append(dict(roi=roi, offsets=offsets))
+        batch.data.append({"roi": roi, "offsets": offsets})

-    def _get_box_and_offset(self, face: "DetectedFace") -> Tuple[List[int], int]:
+    def _get_box_and_offset(self, face: DetectedFace) -> tuple[list[int], int]:
         """Obtain the bounding box and offset from a detected face.
@@ -108,17 +109,17 @@ class Align(Aligner):
             The offset of the box (difference between half width vs height)
         """
-        box = cast(List[int], [face.left,
-                               face.top,
-                               face.right,
-                               face.bottom])
-        diff_height_width = cast(int, face.height) - cast(int, face.width)
+        box = T.cast(list[int], [face.left,
+                                 face.top,
+                                 face.right,
+                                 face.bottom])
+        diff_height_width = T.cast(int, face.height) - T.cast(int, face.width)
         offset = int(abs(diff_height_width / 2))
         return box, offset

-    def align_image(self, batch: AlignerBatch) -> Tuple[List[np.ndarray],
-                                                        List[List[int]],
-                                                        List[Tuple[int, int]]]:
+    def align_image(self, batch: AlignerBatch) -> tuple[list[np.ndarray],
+                                                        list[list[int]],
+                                                        list[tuple[int, int]]]:
         """ Align the incoming image for prediction

         Parameters
@@ -159,8 +160,8 @@ class Align(Aligner):
     @classmethod
     def move_box(cls,
-                 box: List[int],
-                 offset: Tuple[int, int]) -> List[int]:
+                 box: list[int],
+                 offset: tuple[int, int]) -> list[int]:
         """Move the box to direction specified by vector offset

         Parameters
@@ -182,7 +183,7 @@ class Align(Aligner):
         return [left, top, right, bottom]

     @staticmethod
-    def get_square_box(box: List[int]) -> List[int]:
+    def get_square_box(box: list[int]) -> list[int]:
         """Get a square box out of the given box, by expanding it.

         Parameters
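A rough sketch of the squaring the docstring describes (an assumption based on that description, not this plugin's exact arithmetic): grow the shorter sides symmetrically until width equals height:

    def square_box(left: int, top: int, right: int, bottom: int) -> list[int]:
        width, height = right - left, bottom - top
        delta = abs(height - width)
        if height > width:  # expand left/right
            left -= delta // 2
            right += delta - delta // 2
        else:               # expand top/bottom
            top -= delta // 2
            bottom += delta - delta // 2
        return [left, top, right, bottom]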
@@ -226,7 +227,7 @@ class Align(Aligner):
         return [left, top, right, bottom]

     @classmethod
-    def pad_image(cls, box: List[int], image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
+    def pad_image(cls, box: list[int], image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
         """Pad image if face-box falls outside of boundaries

         Parameters


@@ -3,8 +3,9 @@
 Code adapted and modified from:
 https://github.com/1adrianb/face-alignment
 """
+from __future__ import annotations
 import logging
-from typing import cast, List, TYPE_CHECKING
+import typing as T

 import cv2
 import numpy as np
@@ -12,7 +13,7 @@ import numpy as np
 from lib.model.session import KSession
 from ._base import Aligner, AlignerBatch, BatchType

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align import DetectedFace

 logger = logging.getLogger(__name__)
@@ -76,10 +77,10 @@ class Align(Aligner):
         logger.trace("Aligning faces around center")  # type:ignore[attr-defined]
         center_scale = self.get_center_scale(batch.detected_faces)
         batch.feed = np.array(self.crop(batch, center_scale))[..., :3]
-        batch.data.append(dict(center_scale=center_scale))
+        batch.data.append({"center_scale": center_scale})
         logger.trace("Aligned image around center")  # type:ignore[attr-defined]

-    def get_center_scale(self, detected_faces: List["DetectedFace"]) -> np.ndarray:
+    def get_center_scale(self, detected_faces: list[DetectedFace]) -> np.ndarray:
         """ Get the center and set scale of bounding box

         Parameters
@@ -95,11 +96,11 @@ class Align(Aligner):
         logger.trace("Calculating center and scale")  # type:ignore[attr-defined]
         center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32')
         for index, face in enumerate(detected_faces):
-            x_center = (cast(int, face.left) + face.right) / 2.0
-            y_center = (cast(int, face.top) + face.bottom) / 2.0 - cast(int, face.height) * 0.12
-            scale = (cast(int, face.width) + cast(int, face.height)) * self.reference_scale
-            center_scale[index, :, 0] = np.full(68, x_center, dtype='float32')
-            center_scale[index, :, 1] = np.full(68, y_center, dtype='float32')
+            x_ctr = (T.cast(int, face.left) + face.right) / 2.0
+            y_ctr = (T.cast(int, face.top) + face.bottom) / 2.0 - T.cast(int, face.height) * 0.12
+            scale = (T.cast(int, face.width) + T.cast(int, face.height)) * self.reference_scale
+            center_scale[index, :, 0] = np.full(68, x_ctr, dtype='float32')
+            center_scale[index, :, 1] = np.full(68, y_ctr, dtype='float32')
             center_scale[index, :, 2] = np.full(68, scale, dtype='float32')
         logger.trace("Calculated center and scale: %s", center_scale)  # type:ignore[attr-defined]
         return center_scale
@@ -144,7 +145,7 @@ class Align(Aligner):
                                dsize=(self.input_size, self.input_size),
                                interpolation=interp)

-    def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> List[np.ndarray]:
+    def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> list[np.ndarray]:
         """ Crop image around the center point

         Parameters


@@ -15,9 +15,11 @@ To get a :class:`~lib.align.DetectedFace` object use the function:

 >>> face = self._to_detected_face(<face left>, <face top>, <face right>, <face bottom>)
 """
+from __future__ import annotations
 import logging
+import typing as T
 from dataclasses import dataclass, field
-from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING, Union

 import cv2
 import numpy as np
@@ -30,7 +32,8 @@ from lib.utils import FaceswapError
 from plugins.extract._base import BatchType, Extractor, ExtractorBatch
 from plugins.extract.pipeline import ExtractMedia

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
     from queue import Queue

 logger = logging.getLogger(__name__)
@@ -53,10 +56,10 @@ class DetectorBatch(ExtractorBatch):
     initial_feed: :class:`numpy.ndarray`
         Used to hold the initial :attr:`feed` when rotate images is enabled
     """
-    detected_faces: List[List["DetectedFace"]] = field(default_factory=list)
-    rotation_matrix: List[np.ndarray] = field(default_factory=list)
-    scale: List[float] = field(default_factory=list)
-    pad: List[Tuple[int, int]] = field(default_factory=list)
+    detected_faces: list[list["DetectedFace"]] = field(default_factory=list)
+    rotation_matrix: list[np.ndarray] = field(default_factory=list)
+    scale: list[float] = field(default_factory=list)
+    pad: list[tuple[int, int]] = field(default_factory=list)
     initial_feed: np.ndarray = np.array([])
@@ -95,11 +98,11 @@ class Detector(Extractor):  # pylint:disable=abstract-method
     """
     def __init__(self,
-                 git_model_id: Optional[int] = None,
-                 model_filename: Optional[Union[str, List[str]]] = None,
-                 configfile: Optional[str] = None,
+                 git_model_id: int | None = None,
+                 model_filename: str | list[str] | None = None,
+                 configfile: str | None = None,
                  instance: int = 0,
-                 rotation: Optional[str] = None,
+                 rotation: str | None = None,
                  min_size: int = 0,
                  **kwargs) -> None:
         logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__,
@@ -117,7 +120,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         logger.debug("Initialized _base %s", self.__class__.__name__)

     # <<< QUEUE METHODS >>> #
-    def get_batch(self, queue: "Queue") -> Tuple[bool, DetectorBatch]:
+    def get_batch(self, queue: Queue) -> tuple[bool, DetectorBatch]:
         """ Get items for inputting to the detector plugin in batches

         Items are received as :class:`~plugins.extract.pipeline.ExtractMedia` objects and converted
@@ -271,7 +274,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         """ Wrap models predict function in rotations """
         assert isinstance(batch, DetectorBatch)
         batch.rotation_matrix = [np.array([]) for _ in range(len(batch.feed))]
-        found_faces: List[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
+        found_faces: list[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
         for angle in self.rotation:
             # Rotate the batch and insert placeholders for already found faces
             self._rotate_batch(batch, angle)
@@ -301,9 +304,9 @@ class Detector(Extractor):  # pylint:disable=abstract-method
                              "degrees",
                              angle)

-            found_faces = cast(List[np.ndarray], ([face if not found.any() else found
-                                                   for face, found in zip(batch.prediction,
-                                                                          found_faces)]))
+            found_faces = T.cast(list[np.ndarray], ([face if not found.any() else found
+                                                     for face, found in zip(batch.prediction,
+                                                                            found_faces)]))

             if all(face.any() for face in found_faces):
                 logger.trace("Faces found for all images")  # type:ignore[attr-defined]
@@ -317,7 +320,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
     # <<< DETECTION IMAGE COMPILATION METHODS >>> #
     def _compile_detection_image(self, item: ExtractMedia
-                                 ) -> Tuple[np.ndarray, float, Tuple[int, int]]:
+                                 ) -> tuple[np.ndarray, float, tuple[int, int]]:
         """ Compile the detection image for feeding into the model

         Parameters
@@ -345,7 +348,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
                      image.shape, scale, pad)
         return image, scale, pad

-    def _set_scale(self, image_size: Tuple[int, int]) -> float:
+    def _set_scale(self, image_size: tuple[int, int]) -> float:
         """ Set the scale factor for incoming image

         Parameters
@@ -362,7 +365,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         logger.trace("Detector scale: %s", scale)  # type:ignore[attr-defined]
         return scale

-    def _set_padding(self, image_size: Tuple[int, int], scale: float) -> Tuple[int, int]:
+    def _set_padding(self, image_size: tuple[int, int], scale: float) -> tuple[int, int]:
         """ Set the image padding for non-square images

         Parameters
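_set_scale and _set_padding together square up an arbitrary frame for the detector. A hedged sketch of the bookkeeping they describe (illustrative names and layout, not the methods' exact code):

    def scale_and_pad(width: int, height: int, target: int) -> tuple[float, tuple[int, int]]:
        """ Scale so the longest edge fits `target`, then centre-pad the short edge. """
        scale = target / max(width, height)
        pad_left = int(target - width * scale) // 2
        pad_top = int(target - height * scale) // 2
        return scale, (pad_left, pad_top)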
@@ -382,7 +385,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         return pad_left, pad_top

     @staticmethod
-    def _scale_image(image: np.ndarray, image_size: Tuple[int, int], scale: float) -> np.ndarray:
+    def _scale_image(image: np.ndarray, image_size: tuple[int, int], scale: float) -> np.ndarray:
         """ Scale the image and optional pad to given size

         Parameters
@@ -439,8 +442,8 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         return image

     # <<< FINALIZE METHODS >>> #
-    def _remove_zero_sized_faces(self, batch_faces: List[List[DetectedFace]]
-                                 ) -> List[List[DetectedFace]]:
+    def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]]
+                                 ) -> list[list[DetectedFace]]:
         """ Remove items from batch_faces where detected face is of zero size or face falls
         entirely outside of image
@@ -463,8 +466,8 @@ class Detector(Extractor):  # pylint:disable=abstract-method
         logger.trace("Output sizes: %s", [len(face) for face in retval])  # type: ignore
         return retval

-    def _filter_small_faces(self, detected_faces: List[List[DetectedFace]]
-                            ) -> List[List[DetectedFace]]:
+    def _filter_small_faces(self, detected_faces: list[list[DetectedFace]]
+                            ) -> list[list[DetectedFace]]:
         """ Filter out any faces smaller than the min size threshold

         Parameters
@@ -493,7 +496,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
     # <<< IMAGE ROTATION METHODS >>> #
     @staticmethod
-    def _get_rotation_angles(rotation: Optional[str]) -> List[int]:
+    def _get_rotation_angles(rotation: str | None) -> list[int]:
         """ Set the rotation angles.

         Parameters
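A sketch of turning the user's rotation option into a list of angles. The option format shown here ("on" for 90 degree steps, otherwise a comma separated list) is an assumption for illustration:

    def get_rotation_angles(rotation: str | None) -> list[int]:
        if not rotation:
            return [0]  # no rotation requested: detect on the original image only
        if rotation.lower() == "on":
            return [0, 90, 180, 270]
        angles = [int(angle) % 360 for angle in rotation.split(",")]
        return [0] + [angle for angle in angles if angle != 0]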
@@ -544,8 +547,8 @@ class Detector(Extractor):  # pylint:disable=abstract-method
             batch.initial_feed = batch.feed.copy()
             return

-        feeds: List[np.ndarray] = []
-        rotmats: List[np.ndarray] = []
+        feeds: list[np.ndarray] = []
+        rotmats: list[np.ndarray] = []
         for img, faces, rotmat in zip(batch.initial_feed,
                                       batch.prediction,
                                       batch.rotation_matrix):
@@ -605,7 +608,7 @@ class Detector(Extractor):  # pylint:disable=abstract-method
     def _rotate_image_by_angle(self,
                                image: np.ndarray,
-                               angle: int) -> Tuple[np.ndarray, np.ndarray]:
+                               angle: int) -> tuple[np.ndarray, np.ndarray]:
         """ Rotate an image by a given angle.

         Parameters


@@ -34,7 +34,7 @@ class Detect(Detector):
         self.kwargs = self._validate_kwargs()
         self.color_format = "RGB"

-    def _validate_kwargs(self) -> T.Dict[str, T.Union[int, float, T.List[float]]]:
+    def _validate_kwargs(self) -> dict[str, int | float | list[float]]:
         """ Validate that config options are correct. If not reset to default """
         valid = True
         threshold = [self.config["threshold_1"],
@@ -164,7 +164,7 @@ class PNet(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  cpu_mode: bool,
                  input_size: int,
                  min_size: int,
@@ -185,10 +185,10 @@ class PNet(KSession):
         self._pnet_scales = self._calculate_scales(min_size, factor)
         self._pnet_sizes = [(int(input_size * scale), int(input_size * scale))
                             for scale in self._pnet_scales]
-        self._pnet_input: T.Optional[T.List[np.ndarray]] = None
+        self._pnet_input: list[np.ndarray] | None = None

     @staticmethod
-    def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
+    def model_definition() -> tuple[list[Tensor], list[Tensor]]:
         """ Keras P-Network Definition for MTCNN """
         input_ = Input(shape=(None, None, 3))
         var_x = Conv2D(10, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@@ -204,7 +204,7 @@ class PNet(KSession):
     def _calculate_scales(self,
                           minsize: int,
-                          factor: float) -> T.List[float]:
+                          factor: float) -> list[float]:
         """ Calculate multi-scale

         Parameters
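Multi-scale here is the standard MTCNN image pyramid. A sketch of the usual reference computation (an assumption that this file follows it; names are illustrative): each scale shrinks the image by `factor` until P-Net's fixed 12px window can no longer cover a face of `minsize`:

    def calculate_scales(input_size: int, minsize: int, factor: float = 0.709) -> list[float]:
        scales: list[float] = []
        scale = 12.0 / minsize       # scale at which a `minsize` face fills 12px
        side = input_size * scale
        while side >= 12.0:          # stop once the image is smaller than the window
            scales.append(scale)
            scale *= factor
            side *= factor
        return scales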
@@ -231,7 +231,7 @@ class PNet(KSession):
         logger.trace(scales)  # type:ignore
         return scales

-    def __call__(self, images: np.ndarray) -> T.List[np.ndarray]:
+    def __call__(self, images: np.ndarray) -> list[np.ndarray]:
         """ first stage - fast proposal network (p-net) to obtain face candidates

         Parameters
@@ -245,8 +245,8 @@ class PNet(KSession):
             List of face candidates from P-Net
         """
         batch_size = images.shape[0]
-        rectangles: T.List[T.List[T.List[T.Union[int, float]]]] = [[] for _ in range(batch_size)]
-        scores: T.List[T.List[np.ndarray]] = [[] for _ in range(batch_size)]
+        rectangles: list[list[list[int | float]]] = [[] for _ in range(batch_size)]
+        scores: list[list[np.ndarray]] = [[] for _ in range(batch_size)]

         if self._pnet_input is None:
             self._pnet_input = [np.empty((batch_size, rheight, rwidth, 3), dtype="float32")
@@ -278,7 +278,7 @@ class PNet(KSession):
                              class_probabilities: np.ndarray,
                              roi: np.ndarray,
                              size: int,
-                             scale: float) -> T.Tuple[np.ndarray, np.ndarray]:
+                             scale: float) -> tuple[np.ndarray, np.ndarray]:
         """ Detect face position and calibrate bounding box on 12net feature map(matrix version)

         Parameters
@@ -344,7 +344,7 @@ class RNet(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  cpu_mode: bool,
                  input_size: int,
                  threshold: float) -> None:
@@ -360,7 +360,7 @@ class RNet(KSession):
         self._threshold = threshold

     @staticmethod
-    def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
+    def model_definition() -> tuple[list[Tensor], list[Tensor]]:
         """ Keras R-Network Definition for MTCNN """
         input_ = Input(shape=(24, 24, 3))
         var_x = Conv2D(28, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@@ -383,8 +383,8 @@ class RNet(KSession):
     def __call__(self,
                  images: np.ndarray,
-                 rectangle_batch: T.List[np.ndarray],
-                 ) -> T.List[np.ndarray]:
+                 rectangle_batch: list[np.ndarray],
+                 ) -> list[np.ndarray]:
         """ second stage - refinement of face candidates with r-net

         Parameters
@@ -399,7 +399,7 @@ class RNet(KSession):
         List
             List of :class:`numpy.ndarray` refined face candidates from R-Net
         """
-        ret: T.List[np.ndarray] = []
+        ret: list[np.ndarray] = []
         for idx, (rectangles, image) in enumerate(zip(rectangle_batch, images)):
             if not np.any(rectangles):
                 ret.append(np.array([]))
@@ -474,7 +474,7 @@ class ONet(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  cpu_mode: bool,
                  input_size: int,
                  threshold: float) -> None:
@@ -490,7 +490,7 @@ class ONet(KSession):
         self._threshold = threshold

     @staticmethod
-    def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
+    def model_definition() -> tuple[list[Tensor], list[Tensor]]:
         """ Keras O-Network for MTCNN """
         input_ = Input(shape=(48, 48, 3))
         var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv1')(input_)
@@ -516,8 +516,8 @@ class ONet(KSession):
     def __call__(self,
                  images: np.ndarray,
-                 rectangle_batch: T.List[np.ndarray]
-                 ) -> T.List[T.Tuple[np.ndarray, np.ndarray]]:
+                 rectangle_batch: list[np.ndarray]
+                 ) -> list[tuple[np.ndarray, np.ndarray]]:
         """ Third stage - further refinement and facial landmarks positions with o-net

         Parameters
@@ -532,7 +532,7 @@ class ONet(KSession):
         List
             List of refined final candidates, scores and landmark points from O-Net
         """
-        ret: T.List[T.Tuple[np.ndarray, np.ndarray]] = []
+        ret: list[tuple[np.ndarray, np.ndarray]] = []
         for idx, rectangles in enumerate(rectangle_batch):
             if not np.any(rectangles):
                 ret.append((np.empty((0, 5)), np.empty(0)))
@@ -552,7 +552,7 @@ class ONet(KSession):
     def _filter_face_48net(self, class_probabilities: np.ndarray,
                            roi: np.ndarray,
                            points: np.ndarray,
-                           rectangles: np.ndarray) -> T.Tuple[np.ndarray, np.ndarray]:
+                           rectangles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
         """ Filter face position and calibrate bounding box on 12net's output

         Parameters
@@ -623,13 +623,13 @@ class MTCNN():  # pylint: disable=too-few-public-methods
         Default: `0.709`
     """
     def __init__(self,
-                 model_path: T.List[str],
+                 model_path: list[str],
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  cpu_mode: bool,
                  input_size: int = 640,
                  minsize: int = 20,
-                 threshold: T.Optional[T.List[float]] = None,
+                 threshold: list[float] | None = None,
                  factor: float = 0.709) -> None:
         logger.debug("Initializing: %s: (model_path: '%s', allow_growth: %s, exclude_gpus: %s, "
                      "input_size: %s, minsize: %s, threshold: %s, factor: %s)",
@@ -660,7 +660,7 @@ class MTCNN():  # pylint: disable=too-few-public-methods
         logger.debug("Initialized: %s", self.__class__.__name__)

-    def detect_faces(self, batch: np.ndarray) -> T.Tuple[np.ndarray, T.Tuple[np.ndarray]]:
+    def detect_faces(self, batch: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray]]:
         """Detects faces in an image, and returns bounding boxes and points for them.

         Parameters
@@ -684,7 +684,7 @@ class MTCNN():  # pylint: disable=too-few-public-methods
 def nms(rectangles: np.ndarray,
         scores: np.ndarray,
         threshold: float,
-        method: str = "iom") -> T.Tuple[np.ndarray, np.ndarray]:
+        method: str = "iom") -> tuple[np.ndarray, np.ndarray]:
     """ apply non-maximum suppression on ROIs in same scale(matrix version)

     Parameters
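"iom" scores overlap as intersection over the smaller box's area, which suppresses boxes nested inside larger ones more aggressively than plain IoU. A self-contained sketch of greedy NMS with both measures (an illustration, not this module's exact vectorised code):

    import numpy as np

    def nms_indices(boxes: np.ndarray, scores: np.ndarray,
                    threshold: float, method: str = "iom") -> np.ndarray:
        """ boxes: (N, 4) as [x1, y1, x2, y2]; returns indices of kept boxes. """
        x_1, y_1, x_2, y_2 = boxes.T
        areas = (x_2 - x_1) * (y_2 - y_1)
        order = scores.argsort()[::-1]  # highest score first
        keep = []
        while order.size > 0:
            best, rest = order[0], order[1:]
            keep.append(best)
            inter_w = np.maximum(0.0, np.minimum(x_2[best], x_2[rest])
                                 - np.maximum(x_1[best], x_1[rest]))
            inter_h = np.maximum(0.0, np.minimum(y_2[best], y_2[rest])
                                 - np.maximum(y_1[best], y_1[rest]))
            inter = inter_w * inter_h
            if method == "iom":  # intersection over minimum area
                overlap = inter / np.minimum(areas[best], areas[rest])
            else:                # standard intersection over union
                overlap = inter / (areas[best] + areas[rest] - inter)
            order = rest[overlap <= threshold]
        return np.array(keep)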

View file

@@ -125,10 +125,10 @@ class L2Norm(keras.layers.Layer):
 class SliceO2K(keras.layers.Layer):
     """ Custom Keras Slice layer generated by onnx2keras. """
     def __init__(self,
-                 starts: T.List[int],
-                 ends: T.List[int],
-                 axes: T.Optional[T.List[int]] = None,
-                 steps: T.Optional[T.List[int]] = None,
+                 starts: list[int],
+                 ends: list[int],
+                 axes: list[int] | None = None,
+                 steps: list[int] | None = None,
                  **kwargs) -> None:
         self._starts = starts
         self._ends = ends
@@ -136,7 +136,7 @@ class SliceO2K(keras.layers.Layer):
         self._steps = steps
         super().__init__(**kwargs)

-    def _get_slices(self, dimensions: int) -> T.List[T.Tuple[int, ...]]:
+    def _get_slices(self, dimensions: int) -> list[tuple[int, ...]]:
         """ Obtain slices for the given number of dimensions.

         Parameters
@@ -154,7 +154,7 @@ class SliceO2K(keras.layers.Layer):
         assert len(axes) == len(steps) == len(self._starts) == len(self._ends)
         return list(zip(axes, self._starts, self._ends, steps))

-    def compute_output_shape(self, input_shape: T.Tuple[int, ...]) -> T.Tuple[int, ...]:
+    def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
         """Computes the output shape of the layer.

         Assumes that the layer will be built to match that input shape provided.
@@ -230,7 +230,7 @@ class S3fd(KSession):
                  model_path: str,
                  model_kwargs: dict,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  confidence: float) -> None:
         logger.debug("Initializing: %s: (model_path: '%s', model_kwargs: %s, allow_growth: %s, "
                      "exclude_gpus: %s, confidence: %s)", self.__class__.__name__, model_path,
@@ -246,7 +246,7 @@ class S3fd(KSession):
         self.average_img = np.array([104.0, 117.0, 123.0])
         logger.debug("Initialized: %s", self.__class__.__name__)

-    def model_definition(self) -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
+    def model_definition(self) -> tuple[list[Tensor], list[Tensor]]:
         """ Keras S3FD Model Definition, adapted from FAN pytorch implementation. """
         input_ = Input(shape=(640, 640, 3))
         var_x = self.conv_block(input_, 64, 1, 2)
@@ -396,7 +396,7 @@ class S3fd(KSession):
         batch = batch - self.average_img
         return batch

-    def finalize_predictions(self, bounding_boxes_scales: T.List[np.ndarray]) -> np.ndarray:
+    def finalize_predictions(self, bounding_boxes_scales: list[np.ndarray]) -> np.ndarray:
         """ Process the output from the model to obtain faces

         Parameters
@@ -413,7 +413,7 @@ class S3fd(KSession):
             ret.append(finallist)
         return np.array(ret, dtype="object")

-    def _post_process(self, bboxlist: T.List[np.ndarray]) -> np.ndarray:
+    def _post_process(self, bboxlist: list[np.ndarray]) -> np.ndarray:
         """ Perform post processing on output
         TODO: do this on the batch.
         """


@@ -12,9 +12,11 @@ For each source item, the plugin must pass a dict to finalize containing:

 >>> {"filename": <filename of source frame>,
 >>>  "detected_faces": <list of bounding box dicts from lib/plugins/extract/detect/_base>}
 """
+from __future__ import annotations
 import logging
+import typing as T
 from dataclasses import dataclass, field
-from typing import Generator, List, Optional, Tuple, TYPE_CHECKING

 import cv2
 import numpy as np
@@ -25,7 +27,8 @@ from lib.align import AlignedFace, transform_image
 from lib.utils import FaceswapError
 from plugins.extract._base import BatchType, Extractor, ExtractorBatch, ExtractMedia

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+    from collections.abc import Generator
     from queue import Queue
     from lib.align import DetectedFace
     from lib.align.aligned_face import CenteringType
@@ -44,9 +47,9 @@ class MaskerBatch(ExtractorBatch):
     roi_masks: list
         The region of interest masks for the batch
     """
-    detected_faces: List["DetectedFace"] = field(default_factory=list)
-    roi_masks: List[np.ndarray] = field(default_factory=list)
-    feed_faces: List[AlignedFace] = field(default_factory=list)
+    detected_faces: list[DetectedFace] = field(default_factory=list)
+    roi_masks: list[np.ndarray] = field(default_factory=list)
+    feed_faces: list[AlignedFace] = field(default_factory=list)


 class Masker(Extractor):  # pylint:disable=abstract-method
@@ -77,9 +80,9 @@ class Masker(Extractor):  # pylint:disable=abstract-method
     """
     def __init__(self,
-                 git_model_id: Optional[int] = None,
-                 model_filename: Optional[str] = None,
-                 configfile: Optional[str] = None,
+                 git_model_id: int | None = None,
+                 model_filename: str | None = None,
+                 configfile: str | None = None,
                  instance: int = 0,
                  **kwargs) -> None:
         logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile)
@@ -93,11 +96,11 @@ class Masker(Extractor):  # pylint:disable=abstract-method
         self._plugin_type = "mask"
         self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-")
-        self._storage_centering: "CenteringType" = "face"  # Centering to store the mask at
+        self._storage_centering: CenteringType = "face"  # Centering to store the mask at
         self._storage_size = 128  # Size to store masks at. Leave this at default
         logger.debug("Initialized %s", self.__class__.__name__)

-    def get_batch(self, queue: "Queue") -> Tuple[bool, MaskerBatch]:
+    def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]:
         """ Get items for inputting into the masker from the queue in batches

         Items are returned from the ``queue`` in batches of


@@ -49,7 +49,7 @@ class Mask(Masker):
         # Separate storage for face and head masks
         self._storage_name = f"{self._storage_name}_{self._storage_centering}"

-    def _check_weights_selection(self, configfile: T.Optional[str]) -> T.Tuple[bool, int]:
+    def _check_weights_selection(self, configfile: str | None) -> tuple[bool, int]:
         """ Check which weights have been selected.

         This is required for passing along the correct file name for the corresponding weights
@@ -73,7 +73,7 @@ class Mask(Masker):
         version = 1 if not is_faceswap else 2 if config.get("include_hair") else 3
         return is_faceswap, version

-    def _get_segment_indices(self) -> T.List[int]:
+    def _get_segment_indices(self) -> list[int]:
         """ Obtain the segment indices to include within the face mask area based on user
         configuration settings.
@@ -163,7 +163,7 @@ class Mask(Masker):
 # SOFTWARE.

-_NAME_TRACKER: T.Set[str] = set()
+_NAME_TRACKER: set[str] = set()


 def _get_name(name: str, start_idx: int = 1) -> str:
@@ -554,7 +554,7 @@ class BiSeNet(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]],
+                 exclude_gpus: list[int] | None,
                  input_size: int,
                  num_classes: int,
                  cpu_mode: bool) -> None:
@@ -569,7 +569,7 @@ class BiSeNet(KSession):
         self.define_model(self._model_definition)
         self.load_model_weights()

-    def _model_definition(self) -> T.Tuple[Tensor, T.List[Tensor]]:
+    def _model_definition(self) -> tuple[Tensor, list[Tensor]]:
         """ Definition of the VGG Obstructed Model.

         Returns


@@ -1,14 +1,15 @@
 #!/usr/bin/env python3
 """ Components Mask for faceswap.py """
+from __future__ import annotations
 import logging
-from typing import List, Tuple, TYPE_CHECKING
+import typing as T

 import cv2
 import numpy as np

 from ._base import BatchType, Masker

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.aligned_face import AlignedFace

 logger = logging.getLogger(__name__)
@@ -36,7 +37,7 @@ class Mask(Masker):
     def predict(self, feed: np.ndarray) -> np.ndarray:
         """ Run model to get predictions """
-        faces: List["AlignedFace"] = feed[1]
+        faces: list[AlignedFace] = feed[1]
         feed = feed[0]
         for mask, face in zip(feed, faces):
             parts = self.parse_parts(np.array(face.landmarks))
@@ -51,7 +52,7 @@ class Mask(Masker):
             return

     @staticmethod
-    def parse_parts(landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
+    def parse_parts(landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
         """ Component face hull mask """
         r_jaw = (landmarks[0:9], landmarks[17:18])
         l_jaw = (landmarks[8:17], landmarks[26:27])
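Each tuple above groups the 68-point landmark indices for one face component. Rasterising such groups into a mask typically means filling the convex hull of each group; a hedged sketch of that step (an illustration, not this plugin's exact loop):

    import cv2
    import numpy as np

    def hull_mask(parts: list[tuple[np.ndarray, ...]], size: int) -> np.ndarray:
        mask = np.zeros((size, size, 1), dtype="float32")
        for part in parts:
            points = np.concatenate(part).astype("int32")  # merge the index groups
            hull = cv2.convexHull(points)                  # outline of the component
            cv2.fillConvexPoly(mask, hull, 1.0)
        return mask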


@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 """ Extended Mask for faceswap.py """
+from __future__ import annotations
 import logging
-from typing import List, Tuple, TYPE_CHECKING
+import typing as T

 import cv2
 import numpy as np
@@ -9,7 +10,7 @@ from ._base import BatchType, Masker

 logger = logging.getLogger(__name__)

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.aligned_face import AlignedFace
@@ -35,7 +36,7 @@ class Mask(Masker):
     def predict(self, feed: np.ndarray) -> np.ndarray:
         """ Run model to get predictions """
-        faces: List["AlignedFace"] = feed[1]
+        faces: list[AlignedFace] = feed[1]
         feed = feed[0]
         for mask, face in zip(feed, faces):
             parts = self.parse_parts(np.array(face.landmarks))
@@ -78,7 +79,7 @@ class Mask(Masker):
         landmarks[17:22] = top_l + ((top_l - bot_l) // 2)
         landmarks[22:27] = top_r + ((top_r - bot_r) // 2)

-    def parse_parts(self, landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
+    def parse_parts(self, landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
         """ Extended face hull mask """
         self._adjust_mask_top(landmarks)


@@ -13,7 +13,7 @@ Model file sourced from...
 https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
 """
 import logging
-from typing import cast
+import typing as T

 import numpy as np

 from lib.model.session import KSession
@@ -52,7 +52,7 @@ class Mask(Masker):
     def process_input(self, batch: BatchType) -> None:
         """ Compile the detected faces for prediction """
         assert isinstance(batch, MaskerBatch)
-        batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3]
+        batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
                                for feed in batch.feed_faces], dtype="float32") / 255.0
         logger.trace("feed shape: %s", batch.feed.shape)  # type: ignore


@@ -94,7 +94,7 @@ class VGGClear(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]]):
+                 exclude_gpus: list[int] | None):
         super().__init__("VGG Obstructed",
                          model_path,
                          allow_growth=allow_growth,
@@ -103,7 +103,7 @@ class VGGClear(KSession):
         self.load_model_weights()

     @classmethod
-    def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
+    def _model_definition(cls) -> tuple[Tensor, Tensor]:
         """ Definition of the VGG Obstructed Model.

         Returns
@@ -210,7 +210,7 @@ class _ScorePool():  # pylint:disable=too-few-public-methods
     crop: tuple
         The amount of 2D cropping to apply. Tuple of `ints`
     """
-    def __init__(self, level: int, scale: float, crop: T.Tuple[int, int]):
+    def __init__(self, level: int, scale: float, crop: tuple[int, int]):
         self._name = f"_pool{level}"
         self._cropping = (crop, crop)
         self._scale = scale


@@ -90,7 +90,7 @@ class VGGObstructed(KSession):
     def __init__(self,
                  model_path: str,
                  allow_growth: bool,
-                 exclude_gpus: T.Optional[T.List[int]]) -> None:
+                 exclude_gpus: list[int] | None) -> None:
         super().__init__("VGG Obstructed",
                          model_path,
                          allow_growth=allow_growth,
@@ -99,7 +99,7 @@ class VGGObstructed(KSession):
         self.load_model_weights()

     @classmethod
-    def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
+    def _model_definition(cls) -> tuple[Tensor, Tensor]:
         """ Definition of the VGG Obstructed Model.

         Returns


@ -8,10 +8,9 @@ together.
This module sets up a pipeline for the extraction workflow, loading detect, align and mask This module sets up a pipeline for the extraction workflow, loading detect, align and mask
plugins either in parallel or in series, giving easy access to input and output. plugins either in parallel or in series, giving easy access to input and output.
""" """
from __future__ import annotations
import logging import logging
import sys import typing as T
from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import cv2 import cv2
@ -20,13 +19,9 @@ from lib.queue_manager import EventQueue, queue_manager, QueueEmpty
from lib.utils import get_backend from lib.utils import get_backend
from plugins.plugin_loader import PluginLoader from plugins.plugin_loader import PluginLoader
if sys.version_info < (3, 8): if T.TYPE_CHECKING:
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
import numpy as np import numpy as np
from collections.abc import Generator
from lib.align.alignments import PNGHeaderSourceDict from lib.align.alignments import PNGHeaderSourceDict
from lib.align.detected_face import DetectedFace from lib.align.detected_face import DetectedFace
from plugins.extract._base import Extractor as PluginExtractor from plugins.extract._base import Extractor as PluginExtractor
@ -102,16 +97,16 @@ class Extractor():
:attr:`final_pass` to indicate to the caller which phase is being processed :attr:`final_pass` to indicate to the caller which phase is being processed
""" """
def __init__(self, def __init__(self,
detector: Optional[str], detector: str | None,
aligner: Optional[str], aligner: str | None,
masker: Optional[Union[str, List[str]]], masker: str | list[str] | None,
recognition: Optional[str] = None, recognition: str | None = None,
configfile: Optional[str] = None, configfile: str | None = None,
multiprocess: bool = False, multiprocess: bool = False,
exclude_gpus: Optional[List[int]] = None, exclude_gpus: list[int] | None = None,
rotate_images: Optional[str] = None, rotate_images: str | None = None,
min_size: int = 0, min_size: int = 0,
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None, normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
re_feed: int = 0, re_feed: int = 0,
re_align: bool = False, re_align: bool = False,
disable_filter: bool = False) -> None: disable_filter: bool = False) -> None:
@ -122,8 +117,9 @@ class Extractor():
recognition, configfile, multiprocess, exclude_gpus, rotate_images, min_size, recognition, configfile, multiprocess, exclude_gpus, rotate_images, min_size,
normalize_method, re_feed, re_align, disable_filter) normalize_method, re_feed, re_align, disable_filter)
self._instance = _get_instance() self._instance = _get_instance()
maskers = [cast(Optional[str], maskers = [T.cast(str | None,
masker)] if not isinstance(masker, list) else cast(List[Optional[str]], masker) masker)] if not isinstance(masker, list) else T.cast(list[str | None],
masker)
self._flow = self._set_flow(detector, aligner, maskers, recognition) self._flow = self._set_flow(detector, aligner, maskers, recognition)
self._exclude_gpus = exclude_gpus self._exclude_gpus = exclude_gpus
# We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting # We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting
@ -220,13 +216,13 @@ class Extractor():
return retval return retval
@property @property
def aligner(self) -> "Aligner": def aligner(self) -> Aligner:
""" The currently selected aligner plugin """ """ The currently selected aligner plugin """
assert self._align is not None assert self._align is not None
return self._align return self._align
@property @property
def recognition(self) -> "Identity": def recognition(self) -> Identity:
""" The currently selected recognition plugin """ """ The currently selected recognition plugin """
assert self._recognition is not None assert self._recognition is not None
return self._recognition return self._recognition
@ -237,7 +233,7 @@ class Extractor():
self._phase_index = 0 self._phase_index = 0
def set_batchsize(self, def set_batchsize(self,
plugin_type: Literal["align", "detect"], plugin_type: T.Literal["align", "detect"],
batchsize: int) -> None: batchsize: int) -> None:
""" Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`. """ Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`.
@ -311,7 +307,7 @@ class Extractor():
# <<< INTERNAL METHODS >>> # # <<< INTERNAL METHODS >>> #
@property @property
def _parallel_scaling(self) -> Dict[int, float]: def _parallel_scaling(self) -> dict[int, float]:
""" dict: key is number of parallel plugins being loaded, value is the scaling factor that """ dict: key is number of parallel plugins being loaded, value is the scaling factor that
the total base vram for those plugins should be scaled by the total base vram for those plugins should be scaled by
@ -335,7 +331,7 @@ class Extractor():
return retval return retval
@property @property
def _vram_per_phase(self) -> Dict[str, float]: def _vram_per_phase(self) -> dict[str, float]:
""" dict: The amount of vram required for each phase in :attr:`_flow`. """ """ dict: The amount of vram required for each phase in :attr:`_flow`. """
retval = {} retval = {}
for phase in self._flow: for phase in self._flow:
@ -359,7 +355,7 @@ class Extractor():
return retval return retval
@property @property
def _current_phase(self) -> List[str]: def _current_phase(self) -> list[str]:
""" list: The current phase from :attr:`_phases` that is running through the extractor. """ """ list: The current phase from :attr:`_phases` that is running through the extractor. """
retval = self._phases[self._phase_index] retval = self._phases[self._phase_index]
logger.trace(retval) # type: ignore logger.trace(retval) # type: ignore
@ -384,7 +380,7 @@ class Extractor():
return retval return retval
@property @property
def _all_plugins(self) -> List["PluginExtractor"]: def _all_plugins(self) -> list[PluginExtractor]:
""" Return list of all plugin objects in this pipeline """ """ Return list of all plugin objects in this pipeline """
retval = [] retval = []
for phase in self._flow: for phase in self._flow:
@ -396,7 +392,7 @@ class Extractor():
return retval return retval
@property @property
def _active_plugins(self) -> List["PluginExtractor"]: def _active_plugins(self) -> list[PluginExtractor]:
""" Return the plugins that are currently active based on pass """ """ Return the plugins that are currently active based on pass """
retval = [] retval = []
for phase in self._current_phase: for phase in self._current_phase:
@ -407,10 +403,10 @@ class Extractor():
return retval return retval
@staticmethod @staticmethod
def _set_flow(detector: Optional[str], def _set_flow(detector: str | None,
aligner: Optional[str], aligner: str | None,
masker: List[Optional[str]], masker: list[str | None],
recognition: Optional[str]) -> List[str]: recognition: str | None) -> list[str]:
""" Set the flow list based on the input plugins """ Set the flow list based on the input plugins
Parameters Parameters
@ -441,7 +437,7 @@ class Extractor():
return retval return retval
@staticmethod @staticmethod
def _get_plugin_type_and_index(flow_phase: str) -> Tuple[str, Optional[int]]: def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]:
""" Obtain the plugin type and index for the plugin for the given flow phase. """ Obtain the plugin type and index for the plugin for the given flow phase.
When multiple plugins for the same phase are allowed (e.g. Mask) this will return When multiple plugins for the same phase are allowed (e.g. Mask) this will return
@ -463,14 +459,14 @@ class Extractor():
""" """
sidx = flow_phase.split("_")[-1] sidx = flow_phase.split("_")[-1]
if sidx.isdigit(): if sidx.isdigit():
idx: Optional[int] = int(sidx) idx: int | None = int(sidx)
plugin_type = "_".join(flow_phase.split("_")[:-1]) plugin_type = "_".join(flow_phase.split("_")[:-1])
else: else:
plugin_type = flow_phase plugin_type = flow_phase
idx = None idx = None
return plugin_type, idx return plugin_type, idx
def _add_queues(self) -> Dict[str, EventQueue]: def _add_queues(self) -> dict[str, EventQueue]:
""" Add the required processing queues to Queue Manager """ """ Add the required processing queues to Queue Manager """
queues = {} queues = {}
tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow] tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow]
@ -483,7 +479,7 @@ class Extractor():
return queues return queues
@staticmethod @staticmethod
def _get_vram_stats() -> Dict[str, Union[int, str]]: def _get_vram_stats() -> dict[str, int | str]:
""" Obtain statistics on available VRAM and subtract a constant buffer from available vram. """ Obtain statistics on available VRAM and subtract a constant buffer from available vram.
Returns Returns
@ -494,10 +490,10 @@ class Extractor():
vram_buffer = 256 # Leave a buffer for VRAM allocation vram_buffer = 256 # Leave a buffer for VRAM allocation
gpu_stats = GPUStats() gpu_stats = GPUStats()
stats = gpu_stats.get_card_most_free() stats = gpu_stats.get_card_most_free()
retval: Dict[str, Union[int, str]] = {"count": gpu_stats.device_count, retval: dict[str, int | str] = {"count": gpu_stats.device_count,
"device": stats.device, "device": stats.device,
"vram_free": int(stats.free - vram_buffer), "vram_free": int(stats.free - vram_buffer),
"vram_total": int(stats.total)} "vram_total": int(stats.total)}
logger.debug(retval) logger.debug(retval)
return retval return retval
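
A side note on the `dict[str, int | str]` annotation: because the values mix a device name with MB counts, a type checker cannot narrow `retval["vram_free"]` to `int` by key, which is why the callers below wrap lookups in `T.cast`. A minimal illustration (invented values):

import typing as T

stats: dict[str, int | str] = {"device": "GPU 0", "vram_free": 8192}
# mypy sees stats["vram_free"] as `int | str`; the cast narrows it for arithmetic
free = T.cast(int, stats["vram_free"]) - 256
print(free)  # 7936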
@ -521,13 +517,13 @@ class Extractor():
self._vram_stats["device"], self._vram_stats["device"],
self._vram_stats["vram_free"], self._vram_stats["vram_free"],
self._vram_stats["vram_total"]) self._vram_stats["vram_total"])
if cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required: if T.cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required:
logger.warning("Not enough free VRAM for parallel processing. " logger.warning("Not enough free VRAM for parallel processing. "
"Switching to serial") "Switching to serial")
return False return False
return True return True
def _set_phases(self, multiprocess: bool) -> List[List[str]]: def _set_phases(self, multiprocess: bool) -> list[list[str]]:
""" If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit """ If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit
into VRAM, otherwise return the single flow. into VRAM, otherwise return the single flow.
@ -541,9 +537,9 @@ class Extractor():
list: list:
The jobs to be undertaken split into phases that fit into GPU RAM The jobs to be undertaken split into phases that fit into GPU RAM
""" """
phases: List[List[str]] = [] phases: list[list[str]] = []
current_phase: List[str] = [] current_phase: list[str] = []
available = cast(int, self._vram_stats["vram_free"]) available = T.cast(int, self._vram_stats["vram_free"])
for phase in self._flow: for phase in self._flow:
num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0]) num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0])
num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0 num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0
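
Stripped of the parallel-scaling factor and plugin bookkeeping used in the real method, the chunking is a greedy packing pass; a simplified sketch with invented MB figures:

def chunk_phases(flow: list[str],
                 vram_per_phase: dict[str, float],
                 available: int) -> list[list[str]]:
    """ Greedily pack phases into groups whose summed VRAM requirement fits. """
    phases: list[list[str]] = []
    current: list[str] = []
    for phase in flow:
        needed = sum(vram_per_phase[p] for p in current) + vram_per_phase[phase]
        if current and needed > available:  # group is full; start a new one
            phases.append(current)
            current = []
        current.append(phase)
    if current:
        phases.append(current)
    return phases

# With 4096MB free, detect and align share a pass; the heavier masker runs alone
print(chunk_phases(["detect", "align", "mask_0"],
                   {"detect": 1792, "align": 1152, "mask_0": 2304}, 4096))
# [['detect', 'align'], ['mask_0']]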
@ -576,12 +572,12 @@ class Extractor():
# << INTERNAL PLUGIN HANDLING >> # # << INTERNAL PLUGIN HANDLING >> #
def _load_align(self, def _load_align(self,
aligner: Optional[str], aligner: str | None,
configfile: Optional[str], configfile: str | None,
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]], normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None,
re_feed: int, re_feed: int,
re_align: bool, re_align: bool,
disable_filter: bool) -> Optional["Aligner"]: disable_filter: bool) -> Aligner | None:
""" Set global arguments and load aligner plugin """ Set global arguments and load aligner plugin
Parameters Parameters
@ -619,10 +615,10 @@ class Extractor():
return plugin return plugin
def _load_detect(self, def _load_detect(self,
detector: Optional[str], detector: str | None,
rotation: Optional[str], rotation: str | None,
min_size: int, min_size: int,
configfile: Optional[str]) -> Optional["Detector"]: configfile: str | None) -> Detector | None:
""" Set global arguments and load detector plugin """ """ Set global arguments and load detector plugin """
if detector is None or detector.lower() == "none": if detector is None or detector.lower() == "none":
logger.debug("No detector selected. Returning None") logger.debug("No detector selected. Returning None")
@ -637,8 +633,8 @@ class Extractor():
return plugin return plugin
def _load_mask(self, def _load_mask(self,
masker: Optional[str], masker: str | None,
configfile: Optional[str]) -> Optional["Masker"]: configfile: str | None) -> Masker | None:
""" Set global arguments and load masker plugin """ Set global arguments and load masker plugin
Parameters Parameters
@ -664,8 +660,8 @@ class Extractor():
return plugin return plugin
def _load_recognition(self, def _load_recognition(self,
recognition: Optional[str], recognition: str | None,
configfile: Optional[str]) -> Optional["Identity"]: configfile: str | None) -> Identity | None:
""" Set global arguments and load recognition plugin """ """ Set global arguments and load recognition plugin """
if recognition is None or recognition.lower() == "none": if recognition is None or recognition.lower() == "none":
logger.debug("No recognition selected. Returning None") logger.debug("No recognition selected. Returning None")
@ -716,16 +712,16 @@ class Extractor():
gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0] gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback) scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling
if plugins_required + batch_required <= cast(int, self._vram_stats["vram_free"]): if plugins_required + batch_required <= T.cast(int, self._vram_stats["vram_free"]):
logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, " logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
"vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"]) "vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
return return
# Hacky split across plugins that use vram # Hacky split across plugins that use vram
available_vram = (cast(int, self._vram_stats["vram_free"]) available_vram = (T.cast(int, self._vram_stats["vram_free"])
- plugins_required) // len(gpu_plugins) - plugins_required) // len(gpu_plugins)
self._set_plugin_batchsize(gpu_plugins, available_vram) self._set_plugin_batchsize(gpu_plugins, available_vram)
def _set_plugin_batchsize(self, gpu_plugins: List[str], available_vram: float) -> None: def _set_plugin_batchsize(self, gpu_plugins: list[str], available_vram: float) -> None:
""" Set the batch size for the given plugin based on given available vram. """ Set the batch size for the given plugin based on given available vram.
Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to
zero division error. zero division error.
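
The guard matters because the batch size is derived by integer division of the plugin's VRAM share; a hypothetical helper (not the faceswap API) showing the shape of that calculation:

def plugin_batchsize(vram_per_batch: float, vram_share: float, requested: int) -> int:
    """ Cap the requested batch size to what fits in this plugin's VRAM share. """
    if vram_per_batch == 0:  # CPU plugin: leave the batch size untouched
        return requested
    return max(1, min(requested, int(vram_share // vram_per_batch)))

print(plugin_batchsize(64.0, 1024.0, 32))  # 16: 1024MB share // 64MB per batch item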
@ -802,20 +798,20 @@ class ExtractMedia():
def __init__(self, def __init__(self,
filename: str, filename: str,
image: "np.ndarray", image: np.ndarray,
detected_faces: Optional[List["DetectedFace"]] = None, detected_faces: list[DetectedFace] | None = None,
is_aligned: bool = False) -> None: is_aligned: bool = False) -> None:
logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore
"detected_faces: %s, is_aligned: %s)", self.__class__.__name__, filename, "detected_faces: %s, is_aligned: %s)", self.__class__.__name__, filename,
image.shape, detected_faces, is_aligned) image.shape, detected_faces, is_aligned)
self._filename = filename self._filename = filename
self._image: Optional["np.ndarray"] = image self._image: np.ndarray | None = image
self._image_shape = cast(Tuple[int, int, int], image.shape) self._image_shape = T.cast(tuple[int, int, int], image.shape)
self._detected_faces: List["DetectedFace"] = ([] if detected_faces is None self._detected_faces: list[DetectedFace] = ([] if detected_faces is None
else detected_faces) else detected_faces)
self._is_aligned = is_aligned self._is_aligned = is_aligned
self._frame_metadata: Optional["PNGHeaderSourceDict"] = None self._frame_metadata: PNGHeaderSourceDict | None = None
self._sub_folders: List[Optional[str]] = [] self._sub_folders: list[str | None] = []
@property @property
def filename(self) -> str: def filename(self) -> str:
@ -823,23 +819,23 @@ class ExtractMedia():
return self._filename return self._filename
@property @property
def image(self) -> "np.ndarray": def image(self) -> np.ndarray:
""" :class:`numpy.ndarray`: The source frame for this object. """ """ :class:`numpy.ndarray`: The source frame for this object. """
assert self._image is not None assert self._image is not None
return self._image return self._image
@property @property
def image_shape(self) -> Tuple[int, int, int]: def image_shape(self) -> tuple[int, int, int]:
""" tuple: The shape of the stored :attr:`image`. """ """ tuple: The shape of the stored :attr:`image`. """
return self._image_shape return self._image_shape
@property @property
def image_size(self) -> Tuple[int, int]: def image_size(self) -> tuple[int, int]:
""" tuple: The (`height`, `width`) of the stored :attr:`image`. """ """ tuple: The (`height`, `width`) of the stored :attr:`image`. """
return self._image_shape[:2] return self._image_shape[:2]
@property @property
def detected_faces(self) -> List["DetectedFace"]: def detected_faces(self) -> list[DetectedFace]:
"""list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """ """list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """
return self._detected_faces return self._detected_faces
@ -849,7 +845,7 @@ class ExtractMedia():
return self._is_aligned return self._is_aligned
@property @property
def frame_metadata(self) -> "PNGHeaderSourceDict": def frame_metadata(self) -> PNGHeaderSourceDict:
""" dict: The frame metadata that has been added from an aligned image. This property """ dict: The frame metadata that has been added from an aligned image. This property
should only be called after :func:`add_frame_metadata` has been called when processing should only be called after :func:`add_frame_metadata` has been called when processing
an aligned face. For all other instances an assertion error will be raised. an aligned face. For all other instances an assertion error will be raised.
@ -863,13 +859,13 @@ class ExtractMedia():
return self._frame_metadata return self._frame_metadata
@property @property
def sub_folders(self) -> List[Optional[str]]: def sub_folders(self) -> list[str | None]:
""" list: The sub_folders that the faces should be output to. Used when binning filter """ list: The sub_folders that the faces should be output to. Used when binning filter
output is enabled. The list corresponds to the list of detected faces output is enabled. The list corresponds to the list of detected faces
""" """
return self._sub_folders return self._sub_folders
def get_image_copy(self, color_format: Literal["BGR", "RGB", "GRAY"]) -> "np.ndarray": def get_image_copy(self, color_format: T.Literal["BGR", "RGB", "GRAY"]) -> np.ndarray:
""" Get a copy of the image in the requested color format. """ Get a copy of the image in the requested color format.
Parameters Parameters
@ -887,7 +883,7 @@ class ExtractMedia():
image = getattr(self, f"_image_as_{color_format.lower()}")() image = getattr(self, f"_image_as_{color_format.lower()}")()
return image return image
def add_detected_faces(self, faces: List["DetectedFace"]) -> None: def add_detected_faces(self, faces: list[DetectedFace]) -> None:
""" Add detected faces to the object. Called at the end of each extraction phase. """ Add detected faces to the object. Called at the end of each extraction phase.
Parameters Parameters
@ -900,7 +896,7 @@ class ExtractMedia():
[(face.left, face.right, face.top, face.bottom) for face in faces]) [(face.left, face.right, face.top, face.bottom) for face in faces])
self._detected_faces = faces self._detected_faces = faces
def add_sub_folders(self, folders: List[Optional[str]]) -> None: def add_sub_folders(self, folders: list[str | None]) -> None:
""" Add detected faces to the object. Called at the end of each extraction phase. """ Add detected faces to the object. Called at the end of each extraction phase.
Parameters Parameters
@ -922,7 +918,7 @@ class ExtractMedia():
del self._image del self._image
self._image = None self._image = None
def set_image(self, image: "np.ndarray") -> None: def set_image(self, image: np.ndarray) -> None:
""" Add the image back into :attr:`image` """ Add the image back into :attr:`image`
Required for multi-phase extraction to add the image back to this object. Required for multi-phase extraction to add the image back to this object.
@ -936,7 +932,7 @@ class ExtractMedia():
self._filename, image.shape) self._filename, image.shape)
self._image = image self._image = image
def add_frame_metadata(self, metadata: "PNGHeaderSourceDict") -> None: def add_frame_metadata(self, metadata: PNGHeaderSourceDict) -> None:
""" Add the source frame metadata from an aligned PNG's header data. """ Add the source frame metadata from an aligned PNG's header data.
metadata: dict metadata: dict
@ -944,11 +940,11 @@ class ExtractMedia():
""" """
logger.trace("Adding PNG Source data for '%s': %s", # type:ignore logger.trace("Adding PNG Source data for '%s': %s", # type:ignore
self._filename, metadata) self._filename, metadata)
dims = cast(Tuple[int, int], metadata["source_frame_dims"]) dims = T.cast(tuple[int, int], metadata["source_frame_dims"])
self._image_shape = (*dims, 3) self._image_shape = (*dims, 3)
self._frame_metadata = metadata self._frame_metadata = metadata
def _image_as_bgr(self) -> "np.ndarray": def _image_as_bgr(self) -> np.ndarray:
""" Get a copy of the source frame in BGR format. """ Get a copy of the source frame in BGR format.
Returns Returns
@ -957,7 +953,7 @@ class ExtractMedia():
A copy of :attr:`image` in BGR color format """ A copy of :attr:`image` in BGR color format """
return self.image[..., :3].copy() return self.image[..., :3].copy()
def _image_as_rgb(self) -> "np.ndarray": def _image_as_rgb(self) -> np.ndarray:
""" Get a copy of the source frame in RGB format. """ Get a copy of the source frame in RGB format.
Returns Returns
@ -966,7 +962,7 @@ class ExtractMedia():
A copy of :attr:`image` in RGB color format """ A copy of :attr:`image` in RGB color format """
return self.image[..., 2::-1].copy() return self.image[..., 2::-1].copy()
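
The `[..., 2::-1]` slice walks the channel axis from index 2 back to 0, converting BGR to RGB while silently dropping any alpha channel (just as `[..., :3]` does for BGR above); a quick demonstration:

import numpy as np

bgra = np.zeros((2, 2, 4), dtype=np.uint8)
bgra[..., 0] = 255                        # blue channel in BGR order, plus alpha
rgb = bgra[..., 2::-1]                    # channels 2, 1, 0: BGR -> RGB, alpha gone
assert rgb.shape == (2, 2, 3)
assert rgb[0, 0].tolist() == [0, 0, 255]  # blue now sits in the final position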
def _image_as_gray(self) -> "np.ndarray": def _image_as_gray(self) -> np.ndarray:
""" Get a copy of the source frame in gray-scale format. """ Get a copy of the source frame in gray-scale format.
Returns Returns

View file

@ -17,7 +17,6 @@ To get a :class:`~lib.align.DetectedFace` object use the function:
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import sys
import typing as T import typing as T
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -31,13 +30,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch from plugins.extract._base import BatchType, Extractor, ExtractorBatch
from plugins.extract.pipeline import ExtractMedia from plugins.extract.pipeline import ExtractMedia
if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue from queue import Queue
from lib.align.aligned_face import CenteringType from lib.align.aligned_face import CenteringType
@ -50,8 +44,8 @@ class RecogBatch(ExtractorBatch):
Inherits from :class:`~plugins.extract._base.ExtractorBatch` Inherits from :class:`~plugins.extract._base.ExtractorBatch`
""" """
detected_faces: T.List["DetectedFace"] = field(default_factory=list) detected_faces: list[DetectedFace] = field(default_factory=list)
feed_faces: T.List[AlignedFace] = field(default_factory=list) feed_faces: list[AlignedFace] = field(default_factory=list)
class Identity(Extractor): # pylint:disable=abstract-method class Identity(Extractor): # pylint:disable=abstract-method
@ -82,9 +76,9 @@ class Identity(Extractor): # pylint:disable=abstract-method
""" """
def __init__(self, def __init__(self,
git_model_id: T.Optional[int] = None, git_model_id: int | None = None,
model_filename: T.Optional[str] = None, model_filename: str | None = None,
configfile: T.Optional[str] = None, configfile: str | None = None,
instance: int = 0, instance: int = 0,
**kwargs): **kwargs):
logger.debug("Initializing %s", self.__class__.__name__) logger.debug("Initializing %s", self.__class__.__name__)
@ -119,7 +113,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
logger.debug("Obtained detected face: (filename: %s, detected_face: %s)", logger.debug("Obtained detected face: (filename: %s, detected_face: %s)",
item.filename, item.detected_faces) item.filename, item.detected_faces)
def get_batch(self, queue: Queue) -> T.Tuple[bool, RecogBatch]: def get_batch(self, queue: Queue) -> tuple[bool, RecogBatch]:
""" Get items for inputting into the recognition from the queue in batches """ Get items for inputting into the recognition from the queue in batches
Items are returned from the ``queue`` in batches of Items are returned from the ``queue`` in batches of
@ -226,7 +220,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
"\n3) Enable 'Single Process' mode.") "\n3) Enable 'Single Process' mode.")
raise FaceswapError(msg) from err raise FaceswapError(msg) from err
def finalize(self, batch: BatchType) -> T.Generator[ExtractMedia, None, None]: def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
""" Finalize the output from Masker """ Finalize the output from Masker
This should be called as the final task of each `plugin`. This should be called as the final task of each `plugin`.
@ -301,8 +295,8 @@ class IdentityFilter():
def __init__(self, save_output: bool) -> None: def __init__(self, save_output: bool) -> None:
logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output) logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output)
self._save_output = save_output self._save_output = save_output
self._filter: T.Optional[np.ndarray] = None self._filter: np.ndarray | None = None
self._nfilter: T.Optional[np.ndarray] = None self._nfilter: np.ndarray | None = None
self._threshold = 0.0 self._threshold = 0.0
self._filter_enabled: bool = False self._filter_enabled: bool = False
self._nfilter_enabled: bool = False self._nfilter_enabled: bool = False
@ -357,7 +351,7 @@ class IdentityFilter():
return retval return retval
def _get_matches(self, def _get_matches(self,
filter_type: Literal["filter", "nfilter"], filter_type: T.Literal["filter", "nfilter"],
identities: np.ndarray) -> np.ndarray: identities: np.ndarray) -> np.ndarray:
""" Obtain the average and minimum distances for each face against the source identities """ Obtain the average and minimum distances for each face against the source identities
to test against to test against
@ -386,9 +380,9 @@ class IdentityFilter():
return retval return retval
def _filter_faces(self, def _filter_faces(self,
faces: T.List[DetectedFace], faces: list[DetectedFace],
sub_folders: T.List[T.Optional[str]], sub_folders: list[str | None],
should_filter: T.List[bool]) -> T.List[DetectedFace]: should_filter: list[bool]) -> list[DetectedFace]:
""" Filter the detected faces, either removing filtered faces from the list of detected """ Filter the detected faces, either removing filtered faces from the list of detected
faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if
saving output is enabled. saving output is enabled.
@ -410,7 +404,7 @@ class IdentityFilter():
The filtered list of detected face objects, if saving filtered faces has not been The filtered list of detected face objects, if saving filtered faces has not been
selected or the full list of detected faces selected or the full list of detected faces
""" """
retval: T.List[DetectedFace] = [] retval: list[DetectedFace] = []
self._counts += sum(should_filter) self._counts += sum(should_filter)
for idx, face in enumerate(faces): for idx, face in enumerate(faces):
fldr = sub_folders[idx] fldr = sub_folders[idx]
@ -429,8 +423,8 @@ class IdentityFilter():
return retval return retval
def __call__(self, def __call__(self,
faces: T.List[DetectedFace], faces: list[DetectedFace],
sub_folders: T.List[T.Optional[str]]) -> T.List[DetectedFace]: sub_folders: list[str | None]) -> list[DetectedFace]:
""" Call the identity filter function """ Call the identity filter function
Parameters Parameters
@ -459,14 +453,14 @@ class IdentityFilter():
logger.trace("All faces already filtered: %s", sub_folders) # type: ignore logger.trace("All faces already filtered: %s", sub_folders) # type: ignore
return faces return faces
should_filter: T.List[np.ndarray] = [] should_filter: list[np.ndarray] = []
for f_type in get_args(Literal["filter", "nfilter"]): for f_type in T.get_args(T.Literal["filter", "nfilter"]):
if not getattr(self, f"_{f_type}_enabled"): if not getattr(self, f"_{f_type}_enabled"):
continue continue
should_filter.append(self._get_matches(f_type, identities)) should_filter.append(self._get_matches(f_type, identities))
# If any of the filter or nfilter evaluate to 'should filter' then filter out face # If any of the filter or nfilter evaluate to 'should filter' then filter out face
final_filter: T.List[bool] = np.array(should_filter).max(axis=0).tolist() final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist()
logger.trace("should_filter: %s, final_filter: %s", # type: ignore logger.trace("should_filter: %s, final_filter: %s", # type: ignore
should_filter, final_filter) should_filter, final_filter)
return self._filter_faces(faces, sub_folders, final_filter) return self._filter_faces(faces, sub_folders, final_filter)
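
The `max(axis=0)` reduction is an element-wise OR across the per-face verdicts of whichever filters are enabled; worked through with invented values:

import numpy as np

should_filter = [np.array([True, False, False]),   # verdicts from 'filter'
                 np.array([False, False, True])]   # verdicts from 'nfilter'
final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist()
assert final_filter == [True, False, True]  # a face is dropped if either test flags it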


@ -1,10 +1,9 @@
#!/usr/bin python3 #!/usr/bin python3
""" VGG_Face2 inference and sorting """ """ VGG_Face2 inference and sorting """
from __future__ import annotations
import logging import logging
import sys import typing as T
from typing import cast, Dict, Generator, List, Tuple, Optional
import numpy as np import numpy as np
import psutil import psutil
@ -15,11 +14,8 @@ from lib.model.session import KSession
from lib.utils import FaceswapError from lib.utils import FaceswapError
from ._base import BatchType, RecogBatch, Identity from ._base import BatchType, RecogBatch, Identity
if T.TYPE_CHECKING:
if sys.version_info < (3, 8): from collections.abc import Generator
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -64,7 +60,7 @@ class Recognition(Identity):
def init_model(self) -> None: def init_model(self) -> None:
""" Initialize VGG Face 2 Model. """ """ Initialize VGG Face 2 Model. """
assert isinstance(self.model_path, str) assert isinstance(self.model_path, str)
model_kwargs = dict(custom_objects={'L2_normalize': L2_normalize}) model_kwargs = {"custom_objects": {"L2_normalize": L2_normalize}}
self.model = KSession(self.name, self.model = KSession(self.name,
self.model_path, self.model_path,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
@ -76,7 +72,7 @@ class Recognition(Identity):
def process_input(self, batch: BatchType) -> None: def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """ """ Compile the detected faces for prediction """
assert isinstance(batch, RecogBatch) assert isinstance(batch, RecogBatch)
batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3] batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
for feed in batch.feed_faces], for feed in batch.feed_faces],
dtype="float32") - self._average_img dtype="float32") - self._average_img
logger.trace("feed shape: %s", batch.feed.shape) # type:ignore logger.trace("feed shape: %s", batch.feed.shape) # type:ignore
@ -121,15 +117,15 @@ class Cluster(): # pylint: disable=too-few-public-methods
def __init__(self, def __init__(self,
predictions: np.ndarray, predictions: np.ndarray,
method: Literal["single", "centroid", "median", "ward"], method: T.Literal["single", "centroid", "median", "ward"],
threshold: Optional[float] = None) -> None: threshold: float | None = None) -> None:
logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)", logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)",
self.__class__.__name__, predictions.shape, method, threshold) self.__class__.__name__, predictions.shape, method, threshold)
self._num_predictions = predictions.shape[0] self._num_predictions = predictions.shape[0]
self._should_output_bins = threshold is not None self._should_output_bins = threshold is not None
self._threshold = 0.0 if threshold is None else threshold self._threshold = 0.0 if threshold is None else threshold
self._bins: Dict[int, int] = {} self._bins: dict[int, int] = {}
self._iterator = self._integer_iterator() self._iterator = self._integer_iterator()
self._result_linkage = self._do_linkage(predictions, method) self._result_linkage = self._do_linkage(predictions, method)
@ -192,7 +188,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
def _do_linkage(self, def _do_linkage(self,
predictions: np.ndarray, predictions: np.ndarray,
method: Literal["single", "centroid", "median", "ward"]) -> np.ndarray: method: T.Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
""" Use FastCluster to perform vector or standard linkage """ Use FastCluster to perform vector or standard linkage
Parameters Parameters
@ -218,7 +214,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
def _process_leaf_node(self, def _process_leaf_node(self,
current_index: int, current_index: int,
current_bin: int) -> List[Tuple[int, int]]: current_bin: int) -> list[tuple[int, int]]:
""" Process the output when we have hit a leaf node """ """ Process the output when we have hit a leaf node """
if not self._should_output_bins: if not self._should_output_bins:
return [(current_index, 0)] return [(current_index, 0)]
@ -263,7 +259,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
tree: np.ndarray, tree: np.ndarray,
points: int, points: int,
current_index: int, current_index: int,
current_bin: int = 0) -> List[Tuple[int, int]]: current_bin: int = 0) -> list[tuple[int, int]]:
""" Seriation method for sorted similarity. """ Seriation method for sorted similarity.
Seriation computes the order implied by a hierarchical tree (dendrogram). Seriation computes the order implied by a hierarchical tree (dendrogram).
@ -298,7 +294,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
return serate_left + serate_right # type: ignore return serate_left + serate_right # type: ignore
def __call__(self) -> List[Tuple[int, int]]: def __call__(self) -> list[tuple[int, int]]:
""" Process the linkages. """ Process the linkages.
Transforms a distance matrix into a sorted distance matrix according to the order implied Transforms a distance matrix into a sorted distance matrix according to the order implied
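
The same leaf ordering can be reproduced with SciPy's hierarchy utilities (used here purely for illustration; the plugin itself calls fastcluster, and its recursive `_seriation` additionally emits bin ids when a threshold is set):

import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list

embeddings = np.random.default_rng(0).normal(size=(6, 8))  # six toy identity vectors
tree = linkage(embeddings, method="ward")  # agglomerative clustering (dendrogram)
order = leaves_list(tree)                  # leaf order implied by the dendrogram
print(order)                               # similar embeddings end up adjacent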


@ -1,13 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Plugin loader for Faceswap extract, training and convert tasks """ """ Plugin loader for Faceswap extract, training and convert tasks """
from __future__ import annotations
import logging import logging
import os import os
import sys import typing as T
from importlib import import_module
from typing import Callable, List, Type, TYPE_CHECKING
if TYPE_CHECKING: from importlib import import_module
if T.TYPE_CHECKING:
from collections.abc import Callable
from plugins.extract.detect._base import Detector from plugins.extract.detect._base import Detector
from plugins.extract.align._base import Aligner from plugins.extract.align._base import Aligner
from plugins.extract.mask._base import Masker from plugins.extract.mask._base import Masker
@ -15,11 +16,6 @@ if TYPE_CHECKING:
from plugins.train.model._base import ModelBase from plugins.train.model._base import ModelBase
from plugins.train.trainer._base import TrainerBase from plugins.train.trainer._base import TrainerBase
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@ -36,7 +32,7 @@ class PluginLoader():
>>> aligner = PluginLoader.get_aligner('cv2-dnn') >>> aligner = PluginLoader.get_aligner('cv2-dnn')
""" """
@staticmethod @staticmethod
def get_detector(name: str, disable_logging: bool = False) -> Type["Detector"]: def get_detector(name: str, disable_logging: bool = False) -> type[Detector]:
""" Return requested detector plugin """ Return requested detector plugin
Parameters Parameters
@ -55,7 +51,7 @@ class PluginLoader():
return PluginLoader._import("extract.detect", name, disable_logging) return PluginLoader._import("extract.detect", name, disable_logging)
@staticmethod @staticmethod
def get_aligner(name: str, disable_logging: bool = False) -> Type["Aligner"]: def get_aligner(name: str, disable_logging: bool = False) -> type[Aligner]:
""" Return requested aligner plugin """ Return requested aligner plugin
Parameters Parameters
@ -74,7 +70,7 @@ class PluginLoader():
return PluginLoader._import("extract.align", name, disable_logging) return PluginLoader._import("extract.align", name, disable_logging)
@staticmethod @staticmethod
def get_masker(name: str, disable_logging: bool = False) -> Type["Masker"]: def get_masker(name: str, disable_logging: bool = False) -> type[Masker]:
""" Return requested masker plugin """ Return requested masker plugin
Parameters Parameters
@ -93,7 +89,7 @@ class PluginLoader():
return PluginLoader._import("extract.mask", name, disable_logging) return PluginLoader._import("extract.mask", name, disable_logging)
@staticmethod @staticmethod
def get_recognition(name: str, disable_logging: bool = False) -> Type["Identity"]: def get_recognition(name: str, disable_logging: bool = False) -> type[Identity]:
""" Return requested recognition plugin """ Return requested recognition plugin
Parameters Parameters
@ -112,7 +108,7 @@ class PluginLoader():
return PluginLoader._import("extract.recognition", name, disable_logging) return PluginLoader._import("extract.recognition", name, disable_logging)
@staticmethod @staticmethod
def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]: def get_model(name: str, disable_logging: bool = False) -> type[ModelBase]:
""" Return requested training model plugin """ Return requested training model plugin
Parameters Parameters
@ -131,7 +127,7 @@ class PluginLoader():
return PluginLoader._import("train.model", name, disable_logging) return PluginLoader._import("train.model", name, disable_logging)
@staticmethod @staticmethod
def get_trainer(name: str, disable_logging: bool = False) -> Type["TrainerBase"]: def get_trainer(name: str, disable_logging: bool = False) -> type[TrainerBase]:
""" Return requested training trainer plugin """ Return requested training trainer plugin
Parameters Parameters
@ -198,9 +194,9 @@ class PluginLoader():
return getattr(module, ttl) return getattr(module, ttl)
@staticmethod @staticmethod
def get_available_extractors(extractor_type: Literal["align", "detect", "mask"], def get_available_extractors(extractor_type: T.Literal["align", "detect", "mask"],
add_none: bool = False, add_none: bool = False,
extend_plugin: bool = False) -> List[str]: extend_plugin: bool = False) -> list[str]:
""" Return a list of available extractors of the given type """ Return a list of available extractors of the given type
Parameters Parameters
@ -243,7 +239,7 @@ class PluginLoader():
return extractors return extractors
@staticmethod @staticmethod
def get_available_models() -> List[str]: def get_available_models() -> list[str]:
""" Return a list of available training models """ Return a list of available training models
Returns Returns
@ -273,7 +269,7 @@ class PluginLoader():
return 'original' if 'original' in models else models[0] return 'original' if 'original' in models else models[0]
@staticmethod @staticmethod
def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> List[str]: def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> list[str]:
""" Return a list of available converter plugins in the given category """ Return a list of available converter plugins in the given category
Parameters Parameters


@ -249,6 +249,31 @@ class Config(FaceswapConfig):
"NB: The value given here is the 'exponent' to the epsilon. For example, " "NB: The value given here is the 'exponent' to the epsilon. For example, "
"choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the epsilon " "choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the epsilon "
"to 0.001 (1e-3).")) "to 0.001 (1e-3)."))
self.add_item(
section=section,
title="save_optimizer",
datatype=str,
group=_("optimizer"),
default="exit",
fixed=False,
gui_radio=True,
choices=["never", "always", "exit"],
info=_(
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the amount "
"of time it takes to save the model). However, it can be useful to save these "
"weights if you want to guarantee that a resumed model carries off exactly from "
"where it left off, rather than spending a few hundred iterations catching up."
"\n\t never - Don't save optimizer weights."
"\n\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have the "
"last saved optimizer state in your model file."
"\n\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."))
self.add_item( self.add_item(
section=section, section=section,
title="autoclip", title="autoclip",


@ -21,11 +21,6 @@ from tensorflow.keras.models import load_model, Model as KModel # noqa:E501 #
from lib.model.backup_restore import Backup from lib.model.backup_restore import Backup
from lib.utils import FaceswapError from lib.utils import FaceswapError
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
from tensorflow import keras from tensorflow import keras
from .model import ModelBase from .model import ModelBase
@ -35,7 +30,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def get_all_sub_models( def get_all_sub_models(
model: keras.models.Model, model: keras.models.Model,
models: T.Optional[T.List[keras.models.Model]] = None) -> T.List[keras.models.Model]: models: list[keras.models.Model] | None = None) -> list[keras.models.Model]:
""" For a given model, return all sub-models that occur (recursively) as children. """ For a given model, return all sub-models that occur (recursively) as children.
Parameters Parameters
@ -85,12 +80,12 @@ class IO():
plugin: ModelBase, plugin: ModelBase,
model_dir: str, model_dir: str,
is_predict: bool, is_predict: bool,
save_optimizer: Literal["never", "always", "exit"]) -> None: save_optimizer: T.Literal["never", "always", "exit"]) -> None:
self._plugin = plugin self._plugin = plugin
self._is_predict = is_predict self._is_predict = is_predict
self._model_dir = model_dir self._model_dir = model_dir
self._save_optimizer = save_optimizer self._save_optimizer = save_optimizer
self._history: T.List[T.List[float]] = [[], []] # Loss histories per save iteration self._history: list[list[float]] = [[], []] # Loss histories per save iteration
self._backup = Backup(self._model_dir, self._plugin.name) self._backup = Backup(self._model_dir, self._plugin.name)
@property @property
@ -106,12 +101,12 @@ class IO():
return os.path.isfile(self._filename) return os.path.isfile(self._filename)
@property @property
def history(self) -> T.List[T.List[float]]: def history(self) -> list[list[float]]:
""" list: list of loss histories per side for the current save iteration. """ """ list: list of loss histories per side for the current save iteration. """
return self._history return self._history
@property @property
def multiple_models_in_folder(self) -> T.Optional[T.List[str]]: def multiple_models_in_folder(self) -> list[str] | None:
""" :list: or ``None`` If there are multiple model types in the requested folder, or model """ :list: or ``None`` If there are multiple model types in the requested folder, or model
types that don't correspond to the requested plugin type, then returns the list of plugin types that don't correspond to the requested plugin type, then returns the list of plugin
names that exist in the folder, otherwise returns ``None`` """ names that exist in the folder, otherwise returns ``None`` """
@ -210,7 +205,7 @@ class IO():
msg += f" - Average loss since last save: {', '.join(lossmsg)}" msg += f" - Average loss since last save: {', '.join(lossmsg)}"
logger.info(msg) logger.info(msg)
def _get_save_averages(self) -> T.List[float]: def _get_save_averages(self) -> list[float]:
""" Return the average loss since the last save iteration and reset historical loss """ """ Return the average loss since the last save iteration and reset historical loss """
logger.debug("Getting save averages") logger.debug("Getting save averages")
if not all(loss for loss in self._history): if not all(loss for loss in self._history):
@ -222,7 +217,7 @@ class IO():
logger.debug("Average losses since last save: %s", retval) logger.debug("Average losses since last save: %s", retval)
return retval return retval
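
The averages are simply the per-side means of the loss values accumulated since the last save; with illustrative numbers:

history: list[list[float]] = [[0.041, 0.039, 0.038],   # side A losses since last save
                              [0.052, 0.050, 0.048]]   # side B losses since last save
save_averages = [sum(side) / len(side) for side in history]
print([round(avg, 4) for avg in save_averages])  # [0.0393, 0.05]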
def _should_backup(self, save_averages: T.List[float]) -> bool: def _should_backup(self, save_averages: list[float]) -> bool:
""" Check whether the loss averages for this save iteration is the lowest that has been """ Check whether the loss averages for this save iteration is the lowest that has been
seen. seen.
@ -301,7 +296,7 @@ class Weights():
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@classmethod @classmethod
def _check_weights_file(cls, weights_file: str) -> T.Optional[str]: def _check_weights_file(cls, weights_file: str) -> str | None:
""" Validate that we have a valid path to a .h5 file. """ Validate that we have a valid path to a .h5 file.
Parameters Parameters
@ -403,7 +398,7 @@ class Weights():
"different settings than you have set for your current model.", "different settings than you have set for your current model.",
skipped_ops) skipped_ops)
def _get_weights_model(self) -> T.List[keras.models.Model]: def _get_weights_model(self) -> list[keras.models.Model]:
""" Obtain a list of all sub-models contained within the weights model. """ Obtain a list of all sub-models contained within the weights model.
Returns Returns
@ -429,7 +424,7 @@ class Weights():
def _load_layer_weights(self, def _load_layer_weights(self,
layer: keras.layers.Layer, layer: keras.layers.Layer,
sub_weights: keras.layers.Layer, sub_weights: keras.layers.Layer,
model_name: str) -> Literal[-1, 0, 1]: model_name: str) -> T.Literal[-1, 0, 1]:
""" Load the weights for a single layer. """ Load the weights for a single layer.
Parameters Parameters


@ -29,18 +29,12 @@ from plugins.train._config import Config
from .io import IO, get_all_sub_models, Weights from .io import IO, get_all_sub_models, Weights
from .settings import Loss, Optimizer, Settings from .settings import Loss, Optimizer, Settings
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
import argparse import argparse
from lib.config import ConfigValueType from lib.config import ConfigValueType
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_CONFIG: T.Dict[str, ConfigValueType] = {} _CONFIG: dict[str, ConfigValueType] = {}
class ModelBase(): class ModelBase():
@ -79,13 +73,13 @@ class ModelBase():
self.__class__.__name__, model_dir, arguments, predict) self.__class__.__name__, model_dir, arguments, predict)
# Input shape must be set within the plugin after initializing # Input shape must be set within the plugin after initializing
self.input_shape: T.Tuple[int, ...] = () self.input_shape: tuple[int, ...] = ()
self.trainer = "original" # Override for plugin specific trainer self.trainer = "original" # Override for plugin specific trainer
self.color_order: Literal["bgr", "rgb"] = "bgr" # Override for image color channel order self.color_order: T.Literal["bgr", "rgb"] = "bgr" # Override for image color channel order
self._args = arguments self._args = arguments
self._is_predict = predict self._is_predict = predict
self._model: T.Optional[tf.keras.models.Model] = None self._model: tf.keras.models.Model | None = None
self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None
self._load_config() self._load_config()
@ -100,14 +94,7 @@ class ModelBase():
"use. Please select a mask or disable 'Learn Mask'.") "use. Please select a mask or disable 'Learn Mask'.")
self._mixed_precision = self.config["mixed_precision"] self._mixed_precision = self.config["mixed_precision"]
# self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"]) self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
# TODO - Re-enable saving of optimizer once this bug is fixed:
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
# File "h5py/h5d.pyx", line 87, in h5py.h5d.create
# ValueError: Unable to create dataset (name already exists)
self._io = IO(self, model_dir, self._is_predict, "never")
self._check_multiple_models() self._check_multiple_models()
self._state = State(model_dir, self._state = State(model_dir,
@ -175,16 +162,16 @@ class ModelBase():
return self.name return self.name
@property @property
def input_shapes(self) -> T.List[T.Tuple[None, int, int, int]]: def input_shapes(self) -> list[tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the inputs to the model. """ """ list: A flattened list corresponding to all of the inputs to the model. """
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(inputs)) shapes = [T.cast(tuple[None, int, int, int], K.int_shape(inputs))
for inputs in self.model.inputs] for inputs in self.model.inputs]
return shapes return shapes
@property @property
def output_shapes(self) -> T.List[T.Tuple[None, int, int, int]]: def output_shapes(self) -> list[tuple[None, int, int, int]]:
""" list: A flattened list corresponding to all of the outputs of the model. """ """ list: A flattened list corresponding to all of the outputs of the model. """
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(output)) shapes = [T.cast(tuple[None, int, int, int], K.int_shape(output))
for output in self.model.outputs] for output in self.model.outputs]
return shapes return shapes
@ -333,7 +320,7 @@ class ModelBase():
a list of 2 shape tuples of 3 dimensions. """ a list of 2 shape tuples of 3 dimensions. """
assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple" assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple"
def _get_inputs(self) -> T.List[tf.keras.layers.Input]: def _get_inputs(self) -> list[tf.keras.layers.Input]:
""" Obtain the standardized inputs for the model. """ Obtain the standardized inputs for the model.
The inputs will be returned for the "A" and "B" sides in the shape as defined by The inputs will be returned for the "A" and "B" sides in the shape as defined by
@ -352,7 +339,7 @@ class ModelBase():
logger.debug("inputs: %s", inputs) logger.debug("inputs: %s", inputs)
return inputs return inputs
def build_model(self, inputs: T.List[tf.keras.layers.Input]) -> tf.keras.models.Model: def build_model(self, inputs: list[tf.keras.layers.Input]) -> tf.keras.models.Model:
""" Override for Model Specific autoencoder builds. """ Override for Model Specific autoencoder builds.
Parameters Parameters
@ -427,7 +414,7 @@ class ModelBase():
self._state.add_session_loss_names(self._loss.names) self._state.add_session_loss_names(self._loss.names)
logger.debug("Compiled Model: %s", self.model) logger.debug("Compiled Model: %s", self.model)
def _legacy_mapping(self) -> T.Optional[dict]: def _legacy_mapping(self) -> dict | None:
""" The mapping of separate model files to single model layers for transferring of legacy """ The mapping of separate model files to single model layers for transferring of legacy
weights. weights.
@ -439,7 +426,7 @@ class ModelBase():
""" """
return None return None
def add_history(self, loss: T.List[float]) -> None: def add_history(self, loss: list[float]) -> None:
""" Add the current iteration's loss history to :attr:`_io.history`. """ Add the current iteration's loss history to :attr:`_io.history`.
Called from the trainer after each iteration, for tracking loss drop over time between Called from the trainer after each iteration, for tracking loss drop over time between
@ -482,18 +469,18 @@ class State():
self._filename = os.path.join(model_dir, filename) self._filename = os.path.join(model_dir, filename)
self._name = model_name self._name = model_name
self._iterations = 0 self._iterations = 0
self._mixed_precision_layers: T.List[str] = [] self._mixed_precision_layers: list[str] = []
self._rebuild_model = False self._rebuild_model = False
self._sessions: T.Dict[int, dict] = {} self._sessions: dict[int, dict] = {}
self._lowest_avg_loss: T.Dict[str, float] = {} self._lowest_avg_loss: dict[str, float] = {}
self._config: T.Dict[str, ConfigValueType] = {} self._config: dict[str, ConfigValueType] = {}
self._load(config_changeable_items) self._load(config_changeable_items)
self._session_id = self._new_session_id() self._session_id = self._new_session_id()
self._create_new_session(no_logs, config_changeable_items) self._create_new_session(no_logs, config_changeable_items)
logger.debug("Initialized %s:", self.__class__.__name__) logger.debug("Initialized %s:", self.__class__.__name__)
@property @property
def loss_names(self) -> T.List[str]: def loss_names(self) -> list[str]:
""" list: The loss names for the current session """ """ list: The loss names for the current session """
return self._sessions[self._session_id]["loss_names"] return self._sessions[self._session_id]["loss_names"]
@ -518,7 +505,7 @@ class State():
return self._session_id return self._session_id
@property @property
def mixed_precision_layers(self) -> T.List[str]: def mixed_precision_layers(self) -> list[str]:
"""list: Layers that can be switched between mixed-float16 and float32. """ """list: Layers that can be switched between mixed-float16 and float32. """
return self._mixed_precision_layers return self._mixed_precision_layers
@ -564,7 +551,7 @@ class State():
"iterations": 0, "iterations": 0,
"config": config_changeable_items} "config": config_changeable_items}
def add_session_loss_names(self, loss_names: T.List[str]) -> None: def add_session_loss_names(self, loss_names: list[str]) -> None:
""" Add the session loss names to the sessions dictionary. """ Add the session loss names to the sessions dictionary.
The loss names are used for Tensorboard logging The loss names are used for Tensorboard logging
@ -593,7 +580,7 @@ class State():
self._iterations += 1 self._iterations += 1
self._sessions[self._session_id]["iterations"] += 1 self._sessions[self._session_id]["iterations"] += 1
def add_mixed_precision_layers(self, layers: T.List[str]) -> None: def add_mixed_precision_layers(self, layers: list[str]) -> None:
""" Add the list of model's layers that are compatible for mixed precision to the """ Add the list of model's layers that are compatible for mixed precision to the
state dictionary """ state dictionary """
logger.debug("Storing mixed precision layers: %s", layers) logger.debug("Storing mixed precision layers: %s", layers)
@ -655,11 +642,11 @@ class State():
legacy_update = self._update_legacy_config() legacy_update = self._update_legacy_config()
# Add any new items to state config for legacy purposes where the new default may be # Add any new items to state config for legacy purposes where the new default may be
# detrimental to an existing model. # detrimental to an existing model.
legacy_defaults: T.Dict[str, T.Union[str, int, bool]] = {"centering": "legacy", legacy_defaults: dict[str, str | int | bool] = {"centering": "legacy",
"mask_loss_function": "mse", "mask_loss_function": "mse",
"l2_reg_term": 100, "l2_reg_term": 100,
"optimizer": "adam", "optimizer": "adam",
"mixed_precision": False} "mixed_precision": False}
for key, val in _CONFIG.items(): for key, val in _CONFIG.items():
if key not in self._config.keys(): if key not in self._config.keys():
setting: ConfigValueType = legacy_defaults.get(key, val) setting: ConfigValueType = legacy_defaults.get(key, val)
@ -807,7 +794,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
""" :class:`keras.models.Model`: The Faceswap model, compiled for inference. """ """ :class:`keras.models.Model`: The Faceswap model, compiled for inference. """
return self._model return self._model
def _get_nodes(self, nodes: np.ndarray) -> T.List[T.Tuple[str, int]]: def _get_nodes(self, nodes: np.ndarray) -> list[tuple[str, int]]:
""" Given in input list of nodes from a :attr:`keras.models.Model.get_config` dictionary, """ Given in input list of nodes from a :attr:`keras.models.Model.get_config` dictionary,
filters the layer name(s) and output index of the node, splitting to the correct output filters the layer name(s) and output index of the node, splitting to the correct output
index in the event of multiple inputs. index in the event of multiple inputs.
@ -849,7 +836,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
logger.debug("Compiling inference model. saved_model: %s", saved_model) logger.debug("Compiling inference model. saved_model: %s", saved_model)
struct = self._get_filtered_structure() struct = self._get_filtered_structure()
model_inputs = self._get_inputs(saved_model.inputs) model_inputs = self._get_inputs(saved_model.inputs)
compiled_layers: T.Dict[str, tf.keras.layers.Layer] = {} compiled_layers: dict[str, tf.keras.layers.Layer] = {}
for layer in saved_model.layers: for layer in saved_model.layers:
if layer.name not in struct: if layer.name not in struct:
logger.debug("Skipping unused layer: '%s'", layer.name) logger.debug("Skipping unused layer: '%s'", layer.name)


@ -14,7 +14,6 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
import platform import platform
import sys
import typing as T import typing as T
from contextlib import nullcontext from contextlib import nullcontext
@ -28,12 +27,9 @@ from lib.model import losses, optimizers
from lib.model.autoclip import AutoClipper from lib.model.autoclip import AutoClipper
from lib.utils import get_backend from lib.utils import get_backend
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
from collections.abc import Callable
from contextlib import AbstractContextManager as ContextManager
from argparse import Namespace from argparse import Namespace
from .model import State from .model import State
@ -58,9 +54,9 @@ class LossClass:
kwargs: dict kwargs: dict
Any keyword arguments to supply to the loss function at initialization. Any keyword arguments to supply to the loss function at initialization.
""" """
function: T.Union[T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor], T.Any] = k_losses.mae function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | T.Any = k_losses.mae
init: bool = True init: bool = True
kwargs: T.Dict[str, T.Any] = field(default_factory=dict) kwargs: dict[str, T.Any] = field(default_factory=dict)
class Loss(): class Loss():
@ -73,13 +69,13 @@ class Loss():
color_order: str color_order: str
Color order of the model. One of `"BGR"` or `"RGB"` Color order of the model. One of `"BGR"` or `"RGB"`
""" """
def __init__(self, config: dict, color_order: Literal["bgr", "rgb"]) -> None: def __init__(self, config: dict, color_order: T.Literal["bgr", "rgb"]) -> None:
         logger.debug("Initializing %s: (color_order: %s)", self.__class__.__name__, color_order)
         self._config = config
         self._mask_channels = self._get_mask_channels()
-        self._inputs: T.List[tf.keras.layers.Layer] = []
-        self._names: T.List[str] = []
-        self._funcs: T.Dict[str, T.Callable] = {}
+        self._inputs: list[tf.keras.layers.Layer] = []
+        self._names: list[str] = []
+        self._funcs: dict[str, Callable] = {}
         self._loss_dict = {"ffl": LossClass(function=losses.FocalFrequencyLoss),
                            "flip": LossClass(function=losses.LDRFLIPLoss,
@@ -104,7 +100,7 @@ class Loss():
         logger.debug("Initialized: %s", self.__class__.__name__)

     @property
-    def names(self) -> T.List[str]:
+    def names(self) -> list[str]:
         """ list: The list of loss names for the model. """
         return self._names
@@ -114,14 +110,14 @@ class Loss():
         return self._funcs

     @property
-    def _mask_inputs(self) -> T.Optional[list]:
+    def _mask_inputs(self) -> list | None:
         """ list: The list of input tensors to the model that contain the mask. Returns ``None``
         if there is no mask input to the model. """
         mask_inputs = [inp for inp in self._inputs if inp.name.startswith("mask")]
         return None if not mask_inputs else mask_inputs

     @property
-    def _mask_shapes(self) -> T.Optional[T.List[tuple]]:
+    def _mask_shapes(self) -> list[tuple] | None:
         """ list: The list of shape tuples for the mask input tensors for the model. Returns
         ``None`` if there is no mask input. """
         if self._mask_inputs is None:
@@ -141,7 +137,7 @@ class Loss():
         self._set_loss_functions(model.output_names)
         self._names.insert(0, "total")

-    def _set_loss_names(self, outputs: T.List[tf.Tensor]) -> None:
+    def _set_loss_names(self, outputs: list[tf.Tensor]) -> None:
         """ Name the losses based on model output.

         This is used for correct naming in the state file, for display purposes only.
@@ -173,7 +169,7 @@ class Loss():
             self._names.append(f"{name}_{side}{suffix}")
         logger.debug(self._names)

-    def _get_function(self, name: str) -> T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
+    def _get_function(self, name: str) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
         """ Obtain the requested Loss function

         Parameters
@@ -191,7 +187,7 @@ class Loss():
         logger.debug("Obtained loss function `%s` (%s)", name, retval)
         return retval

-    def _set_loss_functions(self, output_names: T.List[str]):
+    def _set_loss_functions(self, output_names: list[str]):
         """ Set the loss functions and their associated weights.

         Adds the loss functions to the :attr:`functions` dictionary.
@@ -251,7 +247,7 @@ class Loss():
                                mask_channel=mask_channel)
             channel_idx += 1

-    def _get_mask_channels(self) -> T.List[int]:
+    def _get_mask_channels(self) -> list[int]:
         """ Obtain the channels from the face targets that the masks reside in from the training
         data generator.
@@ -311,8 +307,8 @@ class Optimizer():  # pylint:disable=too-few-public-methods
                              {"beta_1": 0.5, "beta_2": 0.99, "epsilon": epsilon}),
                 "rms-prop": (optimizers.RMSprop, {"epsilon": epsilon})}
         optimizer_info = valid_optimizers[optimizer]
-        self._optimizer: T.Callable = optimizer_info[0]
-        self._kwargs: T.Dict[str, T.Any] = optimizer_info[1]
+        self._optimizer: Callable = optimizer_info[0]
+        self._kwargs: dict[str, T.Any] = optimizer_info[1]
         self._configure(learning_rate, autoclip)
         logger.verbose("Using %s optimizer", optimizer.title())  # type:ignore[attr-defined]
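The Optimizer hunk above keeps each choice as a (class, kwargs) pair so the optimizer can be instantiated later, once the learning rate and clipping options are known. A minimal standalone sketch of that lookup pattern (the table below is illustrative, not faceswap's full mapping):

import typing as T
from collections.abc import Callable

from tensorflow.keras import optimizers  # assumes the TF 2.10.x pin used by this commit

epsilon = 1e-7
valid_optimizers: dict[str, tuple[Callable, dict[str, T.Any]]] = {
    "adam": (optimizers.Adam, {"beta_1": 0.5, "beta_2": 0.99, "epsilon": epsilon}),
    "rms-prop": (optimizers.RMSprop, {"epsilon": epsilon})}

optimizer_cls, kwargs = valid_optimizers["adam"]
optimizer = optimizer_cls(learning_rate=5e-5, **kwargs)  # build with the stored kwargs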
@@ -411,7 +407,7 @@ class Settings():
         return mixedprecision.LossScaleOptimizer(optimizer)  # pylint:disable=no-member

     @classmethod
-    def _set_tf_settings(cls, allow_growth: bool, exclude_devices: T.List[int]) -> None:
+    def _set_tf_settings(cls, allow_growth: bool, exclude_devices: list[int]) -> None:
         """ Specify Devices to place operations on and Allow TensorFlow to manage VRAM growth.

         Enables the Tensorflow allow_growth option if requested in the command line arguments
@@ -480,8 +476,8 @@ class Settings():
         return True

     def _get_strategy(self,
-                      strategy: Literal["default", "central-storage", "mirrored"]
-                      ) -> T.Optional[tf.distribute.Strategy]:
+                      strategy: T.Literal["default", "central-storage", "mirrored"]
+                      ) -> tf.distribute.Strategy | None:
         """ If we are running on Nvidia backend and the strategy is not ``None`` then return
         the correct tensorflow distribution strategy, otherwise return ``None``.
@@ -565,7 +561,7 @@ class Settings():
         return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0")

-    def _get_mixed_precision_layers(self, layers: T.List[dict]) -> T.List[str]:
+    def _get_mixed_precision_layers(self, layers: list[dict]) -> list[str]:
         """ Obtain the names of the layers in a mixed precision model that have their dtype policy
         explicitly set to mixed-float16.
@@ -595,7 +591,7 @@ class Settings():
                 logger.debug("Skipping unsupported layer: %s %s", layer["name"], dtype)
         return retval

-    def _switch_precision(self, layers: T.List[dict], compatible: T.List[str]) -> None:
+    def _switch_precision(self, layers: list[dict], compatible: list[str]) -> None:
         """ Switch a model's datatype between mixed-float16 and float32.

         Parameters
@@ -624,9 +620,9 @@ class Settings():
             config["dtype"] = policy

     def get_mixed_precision_layers(self,
-                                   build_func: T.Callable[[T.List[tf.keras.layers.Layer]],
-                                                          tf.keras.models.Model],
-                                   inputs: T.List[tf.keras.layers.Layer]) -> T.List[str]:
+                                   build_func: Callable[[list[tf.keras.layers.Layer]],
+                                                        tf.keras.models.Model],
+                                   inputs: list[tf.keras.layers.Layer]) -> list[str]:
         """ Get and store the mixed precision layers from a full precision enabled model.

         Parameters
@@ -699,7 +695,7 @@ class Settings():
         del model
         return new_model

-    def strategy_scope(self) -> T.ContextManager:
+    def strategy_scope(self) -> ContextManager:
         """ Return the strategy scope if we have set a strategy, otherwise return a null
         context.
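Every hunk above follows the same migration: `T.List`, `T.Dict`, `T.Tuple`, `T.Optional` and `T.Union` give way to PEP 585 built-in generics and PEP 604 `X | None` unions, with `Callable` now sourced from `collections.abc`. A before/after sketch of the pattern (the function name is illustrative):

import typing as T
from collections.abc import Callable

# Before: pre-PEP 585 spelling, as removed throughout this commit
def get_losses_old(names: T.List[str]) -> T.Optional[T.Dict[str, T.Callable]]:
    return None

# After: built-in generics and | unions; valid at runtime from Python 3.10,
# in line with the version floor this commit sets
def get_losses_new(names: list[str]) -> dict[str, Callable] | None:
    return None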


@@ -4,7 +4,6 @@
 # pylint: disable=too-many-lines
 from __future__ import annotations
 import logging
-import sys
 import typing as T

 from dataclasses import dataclass
@@ -27,16 +26,10 @@ from lib.utils import get_tf_version, FaceswapError
 from ._base import ModelBase, get_all_sub_models

-if sys.version_info < (3, 8):
-    from typing_extensions import Literal
-else:
-    from typing import Literal
-
 if T.TYPE_CHECKING:
     from tensorflow import keras
     from tensorflow import Tensor

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -65,14 +58,14 @@ class _EncoderInfo:
     """
     keras_name: str
     default_size: int
-    tf_min: T.Tuple[int, int] = (2, 0)
-    scaling: T.Tuple[int, int] = (0, 1)
+    tf_min: tuple[int, int] = (2, 0)
+    scaling: tuple[int, int] = (0, 1)
     min_size: int = 32
     enforce_for_weights: bool = False
-    color_order: Literal["bgr", "rgb"] = "rgb"
+    color_order: T.Literal["bgr", "rgb"] = "rgb"
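The `_MODEL_MAPPING` entries that follow are plain instances of the `_EncoderInfo` dataclass above, so any field not passed explicitly falls back to the defaults shown in the hunk. A self-contained sketch mirroring the densenet121 entry:

import typing as T
from dataclasses import dataclass

@dataclass
class _EncoderInfo:
    keras_name: str
    default_size: int
    tf_min: tuple[int, int] = (2, 0)
    scaling: tuple[int, int] = (0, 1)
    min_size: int = 32
    enforce_for_weights: bool = False
    color_order: T.Literal["bgr", "rgb"] = "rgb"

info = _EncoderInfo(keras_name="DenseNet121", default_size=224)
assert info.color_order == "rgb"  # unset fields take the dataclass defaults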
-_MODEL_MAPPING: T.Dict[str, _EncoderInfo] = {
+_MODEL_MAPPING: dict[str, _EncoderInfo] = {
     "densenet121": _EncoderInfo(
         keras_name="DenseNet121", default_size=224),
     "densenet169": _EncoderInfo(
@@ -238,7 +231,7 @@ class Model(ModelBase):
             model = new_model
         return model

-    def _select_freeze_layers(self) -> T.List[str]:
+    def _select_freeze_layers(self) -> list[str]:
         """ Process the selected frozen layers and replace the `keras_encoder` option with the
         actual keras model name
@@ -262,7 +255,7 @@ class Model(ModelBase):
             logger.debug("Removing 'keras_encoder' for '%s'", arch)
         return retval

-    def _get_input_shape(self) -> T.Tuple[int, int, int]:
+    def _get_input_shape(self) -> tuple[int, int, int]:
         """ Obtain the input shape for the model.

         Input shape is calculated from the selected Encoder's input size, scaled to the user
@@ -316,7 +309,7 @@ class Model(ModelBase):
                 f"minimum version required is {tf_min} whilst you have version "
                 f"{tf_ver} installed.")

-    def build_model(self, inputs: T.List[Tensor]) -> keras.models.Model:
+    def build_model(self, inputs: list[Tensor]) -> keras.models.Model:
         """ Create the model's structure.

         Parameters
@@ -341,7 +334,7 @@ class Model(ModelBase):
         autoencoder = KModel(inputs, outputs, name=self.model_name)
         return autoencoder

-    def _build_encoders(self, inputs: T.List[Tensor]) -> T.Dict[str, keras.models.Model]:
+    def _build_encoders(self, inputs: list[Tensor]) -> dict[str, keras.models.Model]:
         """ Build the encoders for Phaze-A

         Parameters
@@ -362,7 +355,7 @@ class Model(ModelBase):
     def _build_fully_connected(
             self,
-            inputs: T.Dict[str, keras.models.Model]) -> T.Dict[str, T.List[keras.models.Model]]:
+            inputs: dict[str, keras.models.Model]) -> dict[str, list[keras.models.Model]]:
         """ Build the fully connected layers for Phaze-A

         Parameters
@@ -407,8 +400,8 @@ class Model(ModelBase):
     def _build_g_blocks(
             self,
-            inputs: T.Dict[str, T.List[keras.models.Model]]
-            ) -> T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]:
+            inputs: dict[str, list[keras.models.Model]]
+            ) -> dict[str, list[keras.models.Model] | keras.models.Model]:
         """ Build the g-block layers for Phaze-A.

         If a g-block has not been selected for this model, then the original `inters` models are
@@ -440,10 +433,9 @@ class Model(ModelBase):
         logger.debug("G-Blocks: %s", retval)
         return retval

-    def _build_decoders(
-            self,
-            inputs: T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]
-            ) -> T.Dict[str, keras.models.Model]:
+    def _build_decoders(self,
+                        inputs: dict[str, list[keras.models.Model] | keras.models.Model]
+                        ) -> dict[str, keras.models.Model]:
         """ Build the decoders for Phaze-A

         Parameters
@@ -519,12 +511,12 @@ def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str)
     return var_x

-def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny", "upscale_fast",
-                                       "upscale_hybrid", "upsample2d"],
+def _get_upscale_layer(method: T.Literal["resize_images", "subpixel", "upscale_dny",
+                                         "upscale_fast", "upscale_hybrid", "upsample2d"],
                        filters: int,
-                       activation: T.Optional[str] = None,
-                       upsamples: T.Optional[int] = None,
-                       interpolation: T.Optional[str] = None) -> keras.layers.Layer:
+                       activation: str | None = None,
+                       upsamples: int | None = None,
+                       interpolation: str | None = None) -> keras.layers.Layer:
     """ Obtain an instance of the requested upscale method.

     Parameters
@@ -550,7 +542,7 @@ def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny
         The selected configured upscale layer
     """
     if method == "upsample2d":
-        kwargs: T.Dict[str, T.Union[str, int]] = {}
+        kwargs: dict[str, str | int] = {}
         if upsamples:
             kwargs["size"] = upsamples
         if interpolation:
@@ -571,7 +563,7 @@ def _get_curve(start_y: int,
                end_y: int,
                num_points: int,
                scale: float,
-               mode: Literal["full", "cap_max", "cap_min"] = "full") -> T.List[int]:
+               mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]:
     """ Obtain a curve.

     For the given start and end y values, return the y co-ordinates of a curve for the given
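`_get_curve`'s body sits outside this diff, so the exact easing and the semantics of the cap modes are not visible here. A hedged sketch of one plausible reading, interpolating `num_points` values from `start_y` to `end_y` and bending the ramp by `scale`:

import typing as T
import numpy as np

def get_curve(start_y: int,
              end_y: int,
              num_points: int,
              scale: float,
              mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]:
    # Assumption: `scale` bends an otherwise linear ramp, and the cap modes
    # clamp the bent curve against the straight line. Illustrative only.
    linear = np.linspace(start_y, end_y, num_points)
    eased = start_y + (end_y - start_y) * np.linspace(0., 1., num_points) ** (1.0 + scale)
    curve = np.rint(eased).astype(int)
    if mode == "cap_max":
        curve = np.minimum(curve, np.rint(linear).astype(int))
    elif mode == "cap_min":
        curve = np.maximum(curve, np.rint(linear).astype(int))
    return curve.tolist()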
@@ -660,13 +652,13 @@ class Encoder():  # pylint:disable=too-few-public-methods
     config: dict
         The model configuration options
     """
-    def __init__(self, input_shape: T.Tuple[int, ...], config: dict) -> None:
+    def __init__(self, input_shape: tuple[int, ...], config: dict) -> None:
         self.input_shape = input_shape
         self._config = config
         self._input_shape = input_shape

     @property
-    def _model_kwargs(self) -> T.Dict[str, T.Dict[str, T.Union[str, bool]]]:
+    def _model_kwargs(self) -> dict[str, dict[str, str | bool]]:
         """ dict: Configuration option for architecture mapped to optional kwargs. """
         return {"mobilenet": {"alpha": self._config["mobilenet_width"],
                               "depth_multiplier": self._config["mobilenet_depth"],
@@ -677,7 +669,7 @@ class Encoder():  # pylint:disable=too-few-public-methods
                               "include_preprocessing": False}}

     @property
-    def _selected_model(self) -> T.Tuple[_EncoderInfo, dict]:
+    def _selected_model(self) -> tuple[_EncoderInfo, dict]:
         """ tuple(dict, :class:`_EncoderInfo`): The selected encoder model and its associated
         keyword arguments """
         arch = self._config["enc_architecture"]
@@ -832,7 +824,7 @@ class FullyConnected():  # pylint:disable=too-few-public-methods
         The user configuration dictionary
     """
     def __init__(self,
-                 side: Literal["a", "b", "both", "gblock", "shared"],
+                 side: T.Literal["a", "b", "both", "gblock", "shared"],
                  input_shape: tuple,
                  config: dict) -> None:
         logger.debug("Initializing: %s (side: %s, input_shape: %s)",
@@ -992,12 +984,12 @@ class UpscaleBlocks():  # pylint: disable=too-few-public-methods
         and the Decoder. ``None`` will generate the full Upscale chain. An end index of -1 will
         generate the layers from the starting index to the final upscale. Default: ``None``
     """
-    _filters: T.List[int] = []
+    _filters: list[int] = []

     def __init__(self,
-                 side: Literal["a", "b", "both", "shared"],
+                 side: T.Literal["a", "b", "both", "shared"],
                  config: dict,
-                 layer_indicies: T.Optional[T.Tuple[int, int]] = None) -> None:
+                 layer_indicies: tuple[int, int] | None = None) -> None:
         logger.debug("Initializing: %s (side: %s, layer_indicies: %s)",
                      self.__class__.__name__, side, layer_indicies)
         self._side = side
@@ -1126,7 +1118,7 @@ class UpscaleBlocks():  # pylint: disable=too-few-public-methods
                                relu_alpha=0.2)(var_x)
         return var_x

-    def __call__(self, inputs: T.Union[Tensor, T.List[Tensor]]) -> T.Union[Tensor, T.List[Tensor]]:
+    def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]:
         """ Upscale Network.

         Parameters
@@ -1203,8 +1195,8 @@ class GBlock():  # pylint:disable=too-few-public-methods
         The user configuration dictionary
     """
     def __init__(self,
-                 side: Literal["a", "b", "both"],
-                 input_shapes: T.Union[list, tuple],
+                 side: T.Literal["a", "b", "both"],
+                 input_shapes: list | tuple,
                  config: dict) -> None:
         logger.debug("Initializing: %s (side: %s, input_shapes: %s)",
                      self.__class__.__name__, side, input_shapes)
@@ -1284,8 +1276,8 @@ class Decoder():  # pylint:disable=too-few-public-methods
         The user configuration dictionary
     """
     def __init__(self,
-                 side: Literal["a", "b", "both"],
-                 input_shape: T.Tuple[int, int, int],
+                 side: T.Literal["a", "b", "both"],
+                 input_shape: tuple[int, int, int],
                  config: dict) -> None:
         logger.debug("Initializing: %s (side: %s, input_shape: %s)",
                      self.__class__.__name__, side, input_shape)


@@ -39,8 +39,6 @@
 " the value saved in the state file with the updated value in config. If not
 " provided this will default to True.
 """
-from typing import List
-
 _HELPTEXT: str = (
     "Phaze-A Model by TorzDF, with thanks to BirbFakes.\n"
@@ -48,7 +46,7 @@ _HELPTEXT: str = (
     "inspiration from Nvidia's StyleGAN for the Decoder. It is highly recommended to research to "
     "understand the parameters better.")

-_ENCODERS: List[str] = sorted([
+_ENCODERS: list[str] = sorted([
     "densenet121", "densenet169", "densenet201", "efficientnet_b0", "efficientnet_b1",
     "efficientnet_b2", "efficientnet_b3", "efficientnet_b4", "efficientnet_b5", "efficientnet_b6",
     "efficientnet_b7", "efficientnet_v2_b0", "efficientnet_v2_b1", "efficientnet_v2_b2",


@@ -9,7 +9,6 @@ with "original" unique code split out to the original plugin.
 from __future__ import annotations
 import logging
 import os
-import sys
 import time
 import typing as T
@@ -23,23 +22,19 @@ from tensorflow.python.framework import (  # pylint:disable=no-name-in-module
 from lib.image import hex_to_rgb
 from lib.training import PreviewDataGenerator, TrainingDataGenerator
 from lib.training.generator import BatchType, DataGenerator
-from lib.utils import FaceswapError, get_folder, get_image_paths, get_tf_version
+from lib.utils import FaceswapError, get_folder, get_image_paths
 from plugins.train._config import Config

 if T.TYPE_CHECKING:
+    from collections.abc import Callable, Generator
     from plugins.train.model._base import ModelBase
     from lib.config import ConfigValueType

-if sys.version_info < (3, 8):
-    from typing_extensions import get_args, Literal
-else:
-    from typing import get_args, Literal
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 def _get_config(plugin_name: str,
-                configfile: T.Optional[str] = None) -> T.Dict[str, ConfigValueType]:
+                configfile: str | None = None) -> dict[str, ConfigValueType]:
     """ Return the configuration for the requested trainer.

     Parameters
@@ -80,9 +75,9 @@ class TrainerBase():
     def __init__(self,
                  model: ModelBase,
-                 images: T.Dict[Literal["a", "b"], T.List[str]],
+                 images: dict[T.Literal["a", "b"], list[str]],
                  batch_size: int,
-                 configfile: T.Optional[str]) -> None:
+                 configfile: str | None) -> None:
         logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
                      self.__class__.__name__, model, batch_size)
         self._model = model
@@ -111,7 +106,7 @@ class TrainerBase():
                                     self._images)
         logger.debug("Initialized %s", self.__class__.__name__)

-    def _get_config(self, configfile: T.Optional[str]) -> T.Dict[str, ConfigValueType]:
+    def _get_config(self, configfile: str | None) -> dict[str, ConfigValueType]:
         """ Get the saved training config options. Override any global settings with the setting
         provided from the model's saved config.
@@ -173,10 +168,9 @@ class TrainerBase():
         self._samples.toggle_mask_display()

     def train_one_step(self,
-                       viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
-                       timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
-                                                                   "input_b",
-                                                                   "output"], str]]) -> None:
+                       viewer: Callable[[np.ndarray, str], None] | None,
+                       timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
+                                              str] | None) -> None:
         """ Running training on a batch of images for each side.

         Triggered from the training cycle in :class:`scripts.train.Train`.
@@ -217,7 +211,7 @@ class TrainerBase():
         model_inputs, model_targets = self._feeder.get_batch()

         try:
-            loss: T.List[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
+            loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
         except tf_errors.ResourceExhaustedError as err:
             msg = ("You do not have enough GPU memory available to train the selected model at "
                    "the selected settings. You can try a number of things:"
@@ -236,7 +230,7 @@ class TrainerBase():
             self._model.snapshot()
         self._update_viewers(viewer, timelapse_kwargs)

-    def _log_tensorboard(self, loss: T.List[float]) -> None:
+    def _log_tensorboard(self, loss: list[float]) -> None:
         """ Log current loss to Tensorboard log files

         Parameters
@@ -250,19 +244,18 @@ class TrainerBase():
         logs = {log[0]: log[1]
                 for log in zip(self._model.state.loss_names, loss)}

-        if get_tf_version() > (2, 7):
-            # Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
-            # ref: https://github.com/keras-team/keras/issues/16173
-            with tf.summary.record_if(True), self._tensorboard._train_writer.as_default():  # noqa pylint:disable=protected-access,not-context-manager
-                for name, value in logs.items():
-                    tf.summary.scalar(
-                        "batch_" + name,
-                        value,
-                        step=self._tensorboard._train_step)  # pylint:disable=protected-access
-        else:
-            self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
+        # Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
+        # ref: https://github.com/keras-team/keras/issues/16173
+        with tf.summary.record_if(True), self._tensorboard._train_writer.as_default():  # noqa:E501 pylint:disable=protected-access,not-context-manager
+            for name, value in logs.items():
+                tf.summary.scalar(
+                    "batch_" + name,
+                    value,
+                    step=self._tensorboard._train_step)  # pylint:disable=protected-access
+        # TODO revert this code if fixed in tensorflow
+        # self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
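The rework above removes the `get_tf_version()` gate and always takes the direct `tf.summary` path, writing batch scalars through the callback's private writer to sidestep the Keras 2.8-2.10 batch-recording bug referenced in the comment. The underlying `tf.summary` pattern in isolation (the log directory and loss values here are illustrative):

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/fs_logs")  # hypothetical log location
logs = {"total": 0.031, "face_a": 0.015, "face_b": 0.016}

with tf.summary.record_if(True), writer.as_default():
    for name, value in logs.items():
        tf.summary.scalar("batch_" + name, value, step=42)  # step = global iteration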
-    def _collate_and_store_loss(self, loss: T.List[float]) -> T.List[float]:
+    def _collate_and_store_loss(self, loss: list[float]) -> list[float]:
         """ Collate the loss into totals for each side.

         The losses are summed into a total for each side. Loss totals are added to
@@ -298,7 +291,7 @@ class TrainerBase():
         logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss)  # type: ignore
         return combined_loss

-    def _print_loss(self, loss: T.List[float]) -> None:
+    def _print_loss(self, loss: list[float]) -> None:
         """ Outputs the loss for the current iteration to the console.

         Parameters
@@ -318,10 +311,9 @@ class TrainerBase():
                            "line: %s, error: %s", output, str(err))

     def _update_viewers(self,
-                        viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
-                        timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
-                                                                    "input_b",
-                                                                    "output"], str]]) -> None:
+                        viewer: Callable[[np.ndarray, str], None] | None,
+                        timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
+                                               str] | None) -> None:
         """ Update the preview viewer and timelapse output

         Parameters
@@ -371,10 +363,10 @@ class _Feeder():
         The configuration for this trainer
     """
     def __init__(self,
-                 images: T.Dict[Literal["a", "b"], T.List[str]],
+                 images: dict[T.Literal["a", "b"], list[str]],
                  model: ModelBase,
                  batch_size: int,
-                 config: T.Dict[str, ConfigValueType]) -> None:
+                 config: dict[str, ConfigValueType]) -> None:
         logger.debug("Initializing %s: num_images: %s, batch_size: %s, config: %s)",
                      self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size,
                      config)
@@ -383,16 +375,16 @@ class _Feeder():
         self._batch_size = batch_size
         self._config = config
         self._feeds = {side: self._load_generator(side, False).minibatch_ab()
-                       for side in get_args(Literal["a", "b"])}
+                       for side in T.get_args(T.Literal["a", "b"])}

         self._display_feeds = {"preview": self._set_preview_feed(), "timelapse": {}}
         logger.debug("Initialized %s:", self.__class__.__name__)
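Note the idiom these hunks introduce: instead of hard-coding ("a", "b"), the iteration keys are recovered from the Literal annotation itself with `T.get_args`, so the runtime loop and the type annotation cannot drift apart. In isolation:

import typing as T

Side = T.Literal["a", "b"]
feeds = {side: f"feed_{side}" for side in T.get_args(Side)}
assert feeds == {"a": "feed_a", "b": "feed_b"}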
     def _load_generator(self,
-                        side: Literal["a", "b"],
+                        side: T.Literal["a", "b"],
                         is_display: bool,
-                        batch_size: T.Optional[int] = None,
-                        images: T.Optional[T.List[str]] = None) -> DataGenerator:
+                        batch_size: int | None = None,
+                        images: list[str] | None = None) -> DataGenerator:
         """ Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder.

         Parameters
@@ -424,7 +416,7 @@ class _Feeder():
                          self._batch_size if batch_size is None else batch_size)
         return retval

-    def _set_preview_feed(self) -> T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]]:
+    def _set_preview_feed(self) -> dict[T.Literal["a", "b"], Generator[BatchType, None, None]]:
         """ Set the preview feed for this feeder.

         Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically
@@ -436,10 +428,10 @@ class _Feeder():
             The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as
             value.
         """
-        retval: T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]] = {}
+        retval: dict[T.Literal["a", "b"], Generator[BatchType, None, None]] = {}
         num_images = self._config.get("preview_images", 14)
         assert isinstance(num_images, int)
-        for side in get_args(Literal["a", "b"]):
+        for side in T.get_args(T.Literal["a", "b"]):
             logger.debug("Setting preview feed: (side: '%s')", side)
             preview_images = min(max(num_images, 2), 16)
             batchsize = min(len(self._images[side]), preview_images)
@@ -448,7 +440,7 @@ class _Feeder():
                                                  batch_size=batchsize).minibatch_ab()
         return retval

-    def get_batch(self) -> T.Tuple[T.List[T.List[np.ndarray]], ...]:
+    def get_batch(self) -> tuple[list[list[np.ndarray]], ...]:
         """ Get the feed data and the targets for each training side for feeding into the model's
         train function.

@@ -459,8 +451,8 @@ class _Feeder():
         model_targets: list
             The targets for the model for each side A and B
         """
-        model_inputs: T.List[T.List[np.ndarray]] = []
-        model_targets: T.List[T.List[np.ndarray]] = []
+        model_inputs: list[list[np.ndarray]] = []
+        model_targets: list[list[np.ndarray]] = []
         for side in ("a", "b"):
             side_feed, side_targets = next(self._feeds[side])
             if self._model.config["learn_mask"]:  # Add the face mask as its own target
@@ -473,7 +465,7 @@ class _Feeder():
         return model_inputs, model_targets

     def generate_preview(self, is_timelapse: bool = False
-                         ) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
+                         ) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
         """ Generate the images for preview window or timelapse

         Parameters
@@ -490,15 +482,15 @@ class _Feeder():
         """
         logger.debug("Generating preview (is_timelapse: %s)", is_timelapse)

-        batchsizes: T.List[int] = []
-        feed: T.Dict[Literal["a", "b"], np.ndarray] = {}
-        samples: T.Dict[Literal["a", "b"], np.ndarray] = {}
-        masks: T.Dict[Literal["a", "b"], np.ndarray] = {}
+        batchsizes: list[int] = []
+        feed: dict[T.Literal["a", "b"], np.ndarray] = {}
+        samples: dict[T.Literal["a", "b"], np.ndarray] = {}
+        masks: dict[T.Literal["a", "b"], np.ndarray] = {}

         # MyPy can't recurse into nested dicts to get the type :(
-        iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
+        iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
                           self._display_feeds["timelapse" if is_timelapse else "preview"])
-        for side in get_args(Literal["a", "b"]):
+        for side in T.get_args(T.Literal["a", "b"]):
             side_feed, side_samples = next(iterator[side])
             batchsizes.append(len(side_samples[0]))
             samples[side] = side_samples[0]
@@ -513,10 +505,10 @@ class _Feeder():

     def compile_sample(self,
                        image_count: int,
-                       feed: T.Dict[Literal["a", "b"], np.ndarray],
-                       samples: T.Dict[Literal["a", "b"], np.ndarray],
-                       masks: T.Dict[Literal["a", "b"], np.ndarray]
-                       ) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
+                       feed: dict[T.Literal["a", "b"], np.ndarray],
+                       samples: dict[T.Literal["a", "b"], np.ndarray],
+                       masks: dict[T.Literal["a", "b"], np.ndarray]
+                       ) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
         """ Compile the preview samples for display.

         Parameters
@@ -542,8 +534,8 @@ class _Feeder():
         num_images = self._config.get("preview_images", 14)
         assert isinstance(num_images, int)
         num_images = min(image_count, num_images)
-        retval: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
-        for side in get_args(Literal["a", "b"]):
+        retval: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
+        for side in T.get_args(T.Literal["a", "b"]):
             logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images)
             retval[side] = [feed[side][0:num_images],
                             samples[side][0:num_images],
@@ -552,7 +544,7 @@ class _Feeder():
         return retval

     def set_timelapse_feed(self,
-                           images: T.Dict[Literal["a", "b"], T.List[str]],
+                           images: dict[T.Literal["a", "b"], list[str]],
                            batch_size: int) -> None:
         """ Set the time-lapse feed for this feeder.

@@ -570,10 +562,10 @@ class _Feeder():
                      images, batch_size)

         # MyPy can't recurse into nested dicts to get the type :(
-        iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
+        iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
                           self._display_feeds["timelapse"])
-        for side in get_args(Literal["a", "b"]):
+        for side in T.get_args(T.Literal["a", "b"]):
             imgs = images[side]
             logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs))
@@ -615,7 +607,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
                      self.__class__.__name__, model, coverage_ratio, mask_opacity, mask_color)
         self._model = model
         self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"]
-        self.images: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
+        self.images: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
         self._coverage_ratio = coverage_ratio
         self._mask_opacity = mask_opacity / 100.0
         self._mask_color = np.array(hex_to_rgb(mask_color))[..., 2::-1] / 255.
@@ -639,8 +631,8 @@ class _Samples():  # pylint:disable=too-few-public-methods
             A compiled preview image ready for display or saving
         """
         logger.debug("Showing sample")
-        feeds: T.Dict[Literal["a", "b"], np.ndarray] = {}
-        for idx, side in enumerate(get_args(Literal["a", "b"])):
+        feeds: dict[T.Literal["a", "b"], np.ndarray] = {}
+        for idx, side in enumerate(T.get_args(T.Literal["a", "b"])):
             feed = self.images[side][0]
             input_shape = self._model.model.input_shape[idx][1:]
             if input_shape[0] / feed.shape[1] != 1.0:
@@ -653,7 +645,7 @@ class _Samples():  # pylint:disable=too-few-public-methods

     @classmethod
     def _resize_sample(cls,
-                       side: Literal["a", "b"],
+                       side: T.Literal["a", "b"],
                        sample: np.ndarray,
                        target_size: int) -> np.ndarray:
         """ Resize a given image to the target size.
@@ -684,7 +676,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
         logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
         return retval

-    def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> T.Dict[str, np.ndarray]:
+    def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> dict[str, np.ndarray]:
         """ Feed the samples to the model and return predictions

         Parameters
@@ -700,7 +692,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
             List of :class:`numpy.ndarray` of predictions received from the model
         """
         logger.debug("Getting Predictions")
-        preds: T.Dict[str, np.ndarray] = {}
+        preds: dict[str, np.ndarray] = {}
         standard = self._model.model.predict([feed_a, feed_b], verbose=0)
         swapped = self._model.model.predict([feed_b, feed_a], verbose=0)
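The preview pipeline calls the autoencoder twice, once with the feeds in their natural order and once swapped, so both reconstructions and both swap directions are available to display. A runnable sketch of the pairing with a stand-in predict function (the key names are illustrative, not faceswap's):

import numpy as np

def predict(feed_1: np.ndarray, feed_2: np.ndarray) -> list[np.ndarray]:
    return [feed_1, feed_2]  # hypothetical stand-in for model.predict([...])

feed_a = np.zeros((4, 64, 64, 3), dtype="float32")
feed_b = np.ones((4, 64, 64, 3), dtype="float32")

preds: dict[str, np.ndarray] = {}
standard = predict(feed_a, feed_b)  # reconstructions: a->a, b->b
swapped = predict(feed_b, feed_a)   # swaps: b->a, a->b
preds["a_a"], preds["b_b"] = standard
preds["b_a"], preds["a_b"] = swapped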
@@ -719,7 +711,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
         logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
         return preds

-    def _compile_preview(self, predictions: T.Dict[str, np.ndarray]) -> np.ndarray:
+    def _compile_preview(self, predictions: dict[str, np.ndarray]) -> np.ndarray:
         """ Compile predictions and images into the final preview image.

         Parameters
@@ -732,8 +724,8 @@ class _Samples():  # pylint:disable=too-few-public-methods
         :class:`numpy.ndarray`
             A compiled preview image ready for display or saving
         """
-        figures: T.Dict[Literal["a", "b"], np.ndarray] = {}
-        headers: T.Dict[Literal["a", "b"], np.ndarray] = {}
+        figures: dict[T.Literal["a", "b"], np.ndarray] = {}
+        headers: dict[T.Literal["a", "b"], np.ndarray] = {}

         for side, samples in self.images.items():
             other_side = "a" if side == "b" else "b"
@@ -761,9 +753,9 @@ class _Samples():  # pylint:disable=too-few-public-methods
         return np.clip(figure * 255, 0, 255).astype('uint8')

     def _to_full_frame(self,
-                       side: Literal["a", "b"],
-                       samples: T.List[np.ndarray],
-                       predictions: T.List[np.ndarray]) -> T.List[np.ndarray]:
+                       side: T.Literal["a", "b"],
+                       samples: list[np.ndarray],
+                       predictions: list[np.ndarray]) -> list[np.ndarray]:
         """ Patch targets and prediction images into images of model output size.

         Parameters
@@ -803,10 +795,10 @@ class _Samples():  # pylint:disable=too-few-public-methods
         return images

     def _process_full(self,
-                      side: Literal["a", "b"],
+                      side: T.Literal["a", "b"],
                       images: np.ndarray,
                       prediction_size: int,
-                      color: T.Tuple[float, float, float]) -> np.ndarray:
+                      color: tuple[float, float, float]) -> np.ndarray:
         """ Add a frame overlay to preview images indicating the region of interest.

         This applies the red border that appears in the preview images.
@@ -847,7 +839,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
         logger.debug("Overlayed background. Shape: %s", images.shape)
         return images

-    def _compile_masked(self, faces: T.List[np.ndarray], masks: np.ndarray) -> T.List[np.ndarray]:
+    def _compile_masked(self, faces: list[np.ndarray], masks: np.ndarray) -> list[np.ndarray]:
         """ Add the mask to the faces for masked preview.

         Places an opaque red layer over areas of the face that are masked out.
@@ -866,7 +858,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
             List of :class:`numpy.ndarray` faces with the opaque mask layer applied
         """
         orig_masks = 1 - np.rint(masks)
-        masks3: T.Union[T.List[np.ndarray], np.ndarray] = []
+        masks3: list[np.ndarray] | np.ndarray = []

         if faces[-1].shape[-1] == 4:  # Mask contained in alpha channel of predictions
             pred_masks = [1 - np.rint(face[..., -1])[..., None] for face in faces[-2:]]
@@ -875,7 +867,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
         else:
             masks3 = np.repeat(np.expand_dims(orig_masks, axis=0), 3, axis=0)

-        retval: T.List[np.ndarray] = []
+        retval: list[np.ndarray] = []
         alpha = 1.0 - self._mask_opacity
         for previews, compiled_masks in zip(faces, masks3):
             overlays = previews.copy()
@@ -910,7 +902,7 @@ class _Samples():  # pylint:disable=too-few-public-methods
         return backgrounds

     @classmethod
-    def _get_headers(cls, side: Literal["a", "b"], width: int) -> np.ndarray:
+    def _get_headers(cls, side: T.Literal["a", "b"], width: int) -> np.ndarray:
         """ Set header row for the final preview frame

         Parameters
@@ -958,8 +950,8 @@ class _Samples():  # pylint:disable=too-few-public-methods

     @classmethod
     def _duplicate_headers(cls,
-                           headers: T.Dict[Literal["a", "b"], np.ndarray],
-                           columns: int) -> T.Dict[Literal["a", "b"], np.ndarray]:
+                           headers: dict[T.Literal["a", "b"], np.ndarray],
+                           columns: int) -> dict[T.Literal["a", "b"], np.ndarray]:
         """ Duplicate headers for the number of columns displayed for each side.

         Parameters
@@ -1008,7 +1000,7 @@ class _Timelapse():  # pylint:disable=too-few-public-methods
                  mask_opacity: int,
                  mask_color: str,
                  feeder: _Feeder,
-                 image_paths: T.Dict[Literal["a", "b"], T.List[str]]) -> None:
+                 image_paths: dict[T.Literal["a", "b"], list[str]]) -> None:
         logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, "
                      "mask_opacity: %s, mask_color: %s, feeder: %s, image_paths: %s)",
                      self.__class__.__name__, model, coverage_ratio, image_count, mask_opacity,
@@ -1042,8 +1034,8 @@ class _Timelapse():  # pylint:disable=too-few-public-methods
         logger.debug("Time-lapse output set to '%s'", self._output_file)

         # Rewrite paths to pull from the training images so mask and face data can be accessed
-        images: T.Dict[Literal["a", "b"], T.List[str]] = {}
-        for side, input_ in zip(get_args(Literal["a", "b"]), (input_a, input_b)):
+        images: dict[T.Literal["a", "b"], list[str]] = {}
+        for side, input_ in zip(T.get_args(T.Literal["a", "b"]), (input_a, input_b)):
             training_path = os.path.dirname(self._image_paths[side][0])
             images[side] = [os.path.join(training_path, os.path.basename(pth))
                             for pth in get_image_paths(input_)]
@@ -1054,7 +1046,7 @@ class _Timelapse():  # pylint:disable=too-few-public-methods
         self._feeder.set_timelapse_feed(images, batchsize)
         logger.debug("Set up time-lapse")

-    def output_timelapse(self, timelapse_kwargs: T.Dict[Literal["input_a",
-                                                                "input_b",
-                                                                "output"], str]) -> None:
+    def output_timelapse(self, timelapse_kwargs: dict[T.Literal["input_a",
+                                                                "input_b",
+                                                                "output"], str]) -> None:
         """ Generate the time-lapse samples and output the created time-lapse to the specified
@@ -1068,7 +1060,7 @@ class _Timelapse():  # pylint:disable=too-few-public-methods
         """
         logger.debug("Outputting time-lapse")
         if not self._output_file:
-            self._setup(**T.cast(T.Dict[str, str], timelapse_kwargs))
+            self._setup(**T.cast(dict[str, str], timelapse_kwargs))

         logger.debug("Getting time-lapse samples")
         self._samples.images = self._feeder.generate_preview(is_timelapse=True)


@@ -1,19 +1,14 @@
-tqdm>=4.64
+# TESTED WITH PY3.10
+tqdm>=4.65
 psutil>=5.9.0
-numexpr>=2.7.3; python_version < '3.9' # >=2.8.0 conflicts in Conda
-numexpr>=2.8.3; python_version >= '3.9'
-opencv-python>=4.6.0.0
-pillow>=9.2.0
-scikit-learn==1.0.2; python_version < '3.9' # AMD needs version 1.0.2 and 1.1.0 not available in Python 3.7
-scikit-learn>=1.1.0; python_version >= '3.9'
+numexpr>=2.8.4
+numpy>=1.25.0
+opencv-python>=4.7.0.0
+pillow>=9.4.0
+scikit-learn>=1.2.2
 fastcluster>=1.2.6
-matplotlib>=3.4.3,<3.6.0; python_version < '3.9' # >=3.5.0 conflicts in Conda
-matplotlib>=3.5.1,<3.6.0; python_version >= '3.9'
-imageio>=2.19.3
-imageio-ffmpeg>=0.4.7
+matplotlib>=3.7.1
+imageio>=2.26.0
+imageio-ffmpeg>=0.4.8
 ffmpy>=0.3.0
-# Exclude badly numbered Python2 version of nvidia-ml-py
-nvidia-ml-py>=11.515,<300
-tensorflow-probability<0.17
-typing-extensions>=4.0.0
 pywin32>=228 ; sys_platform == "win32"
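With the `python_version` pins gone, the one conditional requirement left relies on a PEP 508 environment marker (the `; sys_platform == "win32"` suffix above). Such markers can be evaluated programmatically, for example with the `packaging` library:

from packaging.markers import Marker

marker = Marker('sys_platform == "win32"')
print(marker.evaluate())  # True only when resolving on Windows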


@ -1,11 +1,7 @@
protobuf>= 3.19.0,<3.20.0 # TF has started pulling in incompatible protobuf -r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24 tensorflow-macos>=2.10.0,<2.11.0
numpy>=1.21.0,<1.24.0; python_version < '3.8' tensorflow-deps>=2.10.0,<2.11.0
numpy>=1.22.0,<1.24.0; python_version >= '3.8' tensorflow-metal>=0.6.0,<0.7.0
tensorflow-macos>=2.8.0,<2.11.0
tensorflow-deps>=2.8.0,<2.11.0
tensorflow-metal>=0.4.0,<0.7.0
libblas # Conda only
# These next 2 should have been installed, but some users complain of errors # These next 2 should have been installed, but some users complain of errors
decorator decorator
cloudpickle cloudpickle


@ -1,5 +1,2 @@
-r _requirements_base.txt -r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24 tensorflow-cpu>=2.10.0,<2.11.0
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.7.0,<2.11.0


@ -1,7 +1,4 @@
-r _requirements_base.txt -r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.10.0,<2.11.0 tensorflow-cpu>=2.10.0,<2.11.0
tensorflow-directml-plugin tensorflow-directml-plugin
comtypes comtypes


@ -1,6 +1,5 @@
-r _requirements_base.txt -r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24 # Exclude badly numbered Python2 version of nvidia-ml-py
numpy>=1.21.0,<1.24.0; python_version < '3.8' nvidia-ml-py>=11.525,<300
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-gpu>=2.7.0,<2.11.0
pynvx==1.0.0 ; sys_platform == "darwin" pynvx==1.0.0 ; sys_platform == "darwin"
tensorflow>=2.10.0,<2.11.0


@ -1,5 +1,2 @@
-r _requirements_base.txt -r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-rocm>=2.10.0,<2.11.0 tensorflow-rocm>=2.10.0,<2.11.0


@@ -26,13 +26,9 @@ from lib.utils import FaceswapError, get_folder, get_image_paths
 from plugins.extract.pipeline import Extractor, ExtractMedia
 from plugins.plugin_loader import PluginLoader

-if sys.version_info < (3, 8):
-    from typing_extensions import get_args, Literal
-else:
-    from typing import get_args, Literal
-
 if T.TYPE_CHECKING:
     from argparse import Namespace
+    from collections.abc import Callable
     from plugins.convert.writer._base import Output
     from plugins.train.model._base import ModelBase
     from lib.align.aligned_face import CenteringType
@@ -61,8 +57,8 @@ class ConvertItem:
         The swapped faces returned from the model's predict function
     """
     inbound: ExtractMedia
-    feed_faces: T.List[AlignedFace] = field(default_factory=list)
-    reference_faces: T.List[AlignedFace] = field(default_factory=list)
+    feed_faces: list[AlignedFace] = field(default_factory=list)
+    reference_faces: list[AlignedFace] = field(default_factory=list)
     swapped_faces: np.ndarray = np.array([])
@@ -307,8 +303,8 @@ class DiskIO():
         # Extractor for on the fly detection
         self._extractor = self._load_extractor()

-        self._queues: T.Dict[Literal["load", "save"], EventQueue] = {}
-        self._threads: T.Dict[Literal["load", "save"], MultiThread] = {}
+        self._queues: dict[T.Literal["load", "save"], EventQueue] = {}
+        self._threads: dict[T.Literal["load", "save"], MultiThread] = {}
         self._init_threads()
         logger.debug("Initialized %s", self.__class__.__name__)
@@ -324,13 +320,13 @@ class DiskIO():
         return self._writer.config.get("draw_transparent", False)

     @property
-    def pre_encode(self) -> T.Optional[T.Callable[[np.ndarray], T.List[bytes]]]:
+    def pre_encode(self) -> Callable[[np.ndarray], list[bytes]] | None:
         """ python function: Selected writer's pre-encode function, if it has one,
         otherwise ``None`` """
         dummy = np.zeros((20, 20, 3), dtype="uint8")
         test = self._writer.pre_encode(dummy)
-        retval: T.Optional[T.Callable[[np.ndarray],
-                                      T.List[bytes]]] = None if test is None else self._writer.pre_encode
+        retval: Callable[[np.ndarray],
+                         list[bytes]] | None = None if test is None else self._writer.pre_encode
         logger.debug("Writer pre_encode function: %s", retval)
         return retval
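The property above probes the writer once with a dummy frame to learn whether it implements a pre-encode step, returning either the bound method or ``None``. The probe pattern with a hypothetical writer stub:

from collections.abc import Callable
import numpy as np

class Writer:  # stand-in for a faceswap writer plugin
    def pre_encode(self, image: np.ndarray) -> list[bytes] | None:
        return None  # this writer performs no pre-encoding

writer = Writer()
dummy = np.zeros((20, 20, 3), dtype="uint8")
test = writer.pre_encode(dummy)
pre_encode: Callable[[np.ndarray], list[bytes]] | None = (
    None if test is None else writer.pre_encode)
assert pre_encode is None  # no pre-encode step for this writer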
@@ -384,7 +380,7 @@ class DiskIO():
         return PluginLoader.get_converter("writer", self._args.writer)(*args,
                                                                        configfile=configfile)

-    def _get_frame_ranges(self) -> T.Optional[T.List[T.Tuple[int, int]]]:
+    def _get_frame_ranges(self) -> list[tuple[int, int]] | None:
         """ Obtain the frame ranges that are to be converted.

         If frame ranges have been specified, then split the command line formatted arguments into
@@ -422,7 +418,7 @@ class DiskIO():
         logger.debug("frame ranges: %s", retval)
         return retval

-    def _load_extractor(self) -> T.Optional[Extractor]:
+    def _load_extractor(self) -> Extractor | None:
         """ Load the CV2-DNN Face Extractor Chain.

         For On-The-Fly conversion we use a CPU based extractor to avoid stacking the GPU.
@@ -467,12 +463,12 @@ class DiskIO():
         Creates the load and save queues and the load and save threads. Starts the threads.
         """
         logger.debug("Initializing DiskIO Threads")
-        for task in get_args(Literal["load", "save"]):
+        for task in T.get_args(T.Literal["load", "save"]):
             self._add_queue(task)
             self._start_thread(task)
         logger.debug("Initialized DiskIO Threads")

-    def _add_queue(self, task: Literal["load", "save"]) -> None:
+    def _add_queue(self, task: T.Literal["load", "save"]) -> None:
         """ Add the queue to queue_manager and to :attr:`self._queues` for the given task.

         Parameters
@@ -490,7 +486,7 @@ class DiskIO():
         self._queues[task] = queue_manager.get_queue(q_name)
         logger.debug("Added queue for task: '%s'", task)

-    def _start_thread(self, task: Literal["load", "save"]) -> None:
+    def _start_thread(self, task: T.Literal["load", "save"]) -> None:
         """ Create the thread for the given task, add it to :attr:`self._threads` and start it.

         Parameters
@@ -571,7 +567,7 @@ class DiskIO():
         logger.trace("idx: %s, skipframe: %s", idx, skipframe)  # type: ignore
         return skipframe

-    def _get_detected_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
+    def _get_detected_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
         """ Return the detected faces for the given image.

         If we have an alignments file, then the detected faces are created from that file. If
@@ -597,7 +593,7 @@ class DiskIO():
         logger.trace("Got %s faces for: '%s'", len(detected_faces), filename)  # type:ignore
         return detected_faces

-    def _alignments_faces(self, frame_name: str, image: np.ndarray) -> T.List[DetectedFace]:
+    def _alignments_faces(self, frame_name: str, image: np.ndarray) -> list[DetectedFace]:
         """ Return detected faces from an alignments file.

         Parameters
@@ -644,7 +640,7 @@ class DiskIO():
             tqdm.write(f"No alignment found for {frame_name}, skipping")
         return have_alignments

-    def _detect_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
+    def _detect_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
         """ Extract the face from a frame for On-The-Fly conversion.

         Pulls detected faces out of the Extraction pipeline.
@@ -779,7 +775,7 @@ class Predict():
         """ int: The size in pixels of the Faceswap model output. """
         return self._sizes["output"]

-    def _get_io_sizes(self) -> T.Dict[str, int]:
+    def _get_io_sizes(self) -> dict[str, int]:
         """ Obtain the input size and output size of the model.

         Returns
@@ -896,9 +892,9 @@ class Predict():
         """
         faces_seen = 0
         consecutive_no_faces = 0
-        batch: T.List[ConvertItem] = []
+        batch: list[ConvertItem] = []
         while True:
-            item: T.Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
+            item: T.Literal["EOF"] | ConvertItem = self._in_queue.get()
             if item == "EOF":
                 logger.debug("EOF Received")
                 if batch:  # Process out any remaining items
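The load loop consumes a queue whose items are either work units or the literal string "EOF", now typed with a PEP 604 union. A minimal sketch of the sentinel-terminated loop (frame names are illustrative):

import queue
import typing as T

in_queue: queue.Queue = queue.Queue()
for frame in ("frame_0001.png", "frame_0002.png", "EOF"):
    in_queue.put(frame)

batch: list[str] = []
while True:
    item: T.Literal["EOF"] | str = in_queue.get()
    if item == "EOF":
        if batch:
            print("flushing remaining batch:", batch)  # process leftovers
        break
    batch.append(item)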
@@ -938,7 +934,7 @@ class Predict():
         self._out_queue.put("EOF")
         logger.debug("Load queue complete")

-    def _process_batch(self, batch: T.List[ConvertItem], faces_seen: int):
+    def _process_batch(self, batch: list[ConvertItem], faces_seen: int):
         """ Predict faces on the given batch of images and queue out to patch thread

         Parameters
@@ -1001,7 +997,7 @@ class Predict():
         logger.trace("Loaded aligned faces: '%s'", item.inbound.filename)  # type:ignore

     @staticmethod
-    def _compile_feed_faces(feed_faces: T.List[AlignedFace]) -> np.ndarray:
+    def _compile_feed_faces(feed_faces: list[AlignedFace]) -> np.ndarray:
         """ Compile a batch of faces for feeding into the Predictor.

         Parameters
@@ -1020,7 +1016,7 @@ class Predict():
         logger.trace("Compiled Feed faces. Shape: %s", retval.shape)  # type:ignore
         return retval

-    def _predict(self, feed_faces: np.ndarray, batch_size: T.Optional[int] = None) -> np.ndarray:
+    def _predict(self, feed_faces: np.ndarray, batch_size: int | None = None) -> np.ndarray:
         """ Run the Faceswap models' prediction function.

         Parameters
@@ -1045,7 +1041,7 @@ class Predict():
         logger.trace("Input shape(s): %s", [item.shape for item in feed])  # type:ignore
         inbound = self._model.model.predict(feed, verbose=0, batch_size=batch_size)
-        predicted: T.List[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]
+        predicted: list[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]

         if self._model.color_order.lower() == "rgb":
             predicted[0] = predicted[0][..., ::-1]
@@ -1062,7 +1058,7 @@ class Predict():
         logger.trace("Final shape: %s", retval.shape)  # type:ignore
         return retval

-    def _queue_out_frames(self, batch: T.List[ConvertItem], swapped_faces: np.ndarray) -> None:
+    def _queue_out_frames(self, batch: list[ConvertItem], swapped_faces: np.ndarray) -> None:
         """ Compile the batch back to original frames and put to the Out Queue.

         For batching, faces are split away from their frames. This compiles all detected faces
@@ -1108,7 +1104,7 @@ class OptionalActions():  # pylint:disable=too-few-public-methods
     """
     def __init__(self,
                  arguments: Namespace,
-                 input_images: T.List[np.ndarray],
+                 input_images: list[np.ndarray],
                  alignments: Alignments) -> None:
         logger.debug("Initializing %s", self.__class__.__name__)
         self._args = arguments
@@ -1131,7 +1127,7 @@ class OptionalActions():  # pylint:disable=too-few-public-methods
         self._alignments.filter_faces(accept_dict, filter_out=False)
         logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count)
def _get_face_metadata(self) -> T.Dict[str, T.List[int]]: def _get_face_metadata(self) -> dict[str, list[int]]:
""" Check for the existence of an aligned directory for identifying which faces in the """ Check for the existence of an aligned directory for identifying which faces in the
target frames should be swapped. If it exists, scan the folder for face's metadata target frames should be swapped. If it exists, scan the folder for face's metadata
@ -1140,7 +1136,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
dict dict
Dictionary of source frame names with a list of associated face indices to be skipped Dictionary of source frame names with a list of associated face indices to be skipped
""" """
retval: T.Dict[str, T.List[int]] = {} retval: dict[str, list[int]] = {}
input_aligned_dir = self._args.input_aligned_dir input_aligned_dir = self._args.input_aligned_dir
if input_aligned_dir is None: if input_aligned_dir is None:
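Every annotation change in the hunks above applies the same two conventions: PEP 585 builtin generics (list, dict, tuple) replace typing.List, typing.Dict and typing.Tuple, and PEP 604 unions (X | None) replace typing.Optional and typing.Union, with the remaining typing-only names reached through a single `import typing as T` namespace. A minimal sketch of the target style follows; the names are illustrative only, not code from the faceswap repository, and it runs as-is on the Python 3.10 minimum this commit sets:

import typing as T

Task = T.Literal["load", "save"]  # typing-only names stay behind the T namespace


def io_sizes(paths: list[str] | None = None) -> dict[str, tuple[int, int]]:
    """ PEP 585 builtin generics and a PEP 604 union: no typing.List or typing.Optional """
    return {path: (64, 64) for path in (paths or [])}


def start(task: Task) -> None:
    print(task, io_sizes(["face_a.png"]))


start("load")

Where `from __future__ import annotations` is in force the new spellings were already deferred for type checkers; raising the interpreter floor to 3.10 makes them valid in runtime-evaluated positions too.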

View file

@@ -2,13 +2,13 @@
 """ Main entry point to the extract process of FaceSwap """
 from __future__ import annotations
 
 import logging
 import os
 import sys
+import typing as T
 from argparse import Namespace
 from multiprocessing import Process
-from typing import List, Dict, Optional, Tuple, TYPE_CHECKING, Union
 
 import numpy as np
 from tqdm import tqdm
@@ -20,7 +20,7 @@ from lib.utils import get_folder, _image_extensions, _video_extensions
 from plugins.extract.pipeline import Extractor, ExtractMedia
 from scripts.fsmedia import Alignments, PostProcess, finalize
 
-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
     from lib.align.alignments import PNGHeaderAlignmentsDict
 
 # tqdm.monitor_interval = 0  # workaround for TqdmSynchronisationWarning  # TODO?
@@ -75,7 +75,7 @@ class Extract():  # pylint:disable=too-few-public-methods
                               self._args.nfilter,
                               self._extractor)
 
-    def _get_input_locations(self) -> List[str]:
+    def _get_input_locations(self) -> list[str]:
         """ Obtain the full path to input locations. Will be a list of locations if batch mode is
         selected, or a containing a single location if batch mode is not selected.
@@ -194,8 +194,8 @@ class Filter():
     """
     def __init__(self,
                  threshold: float,
-                 filter_files: Optional[List[str]],
-                 nfilter_files: Optional[List[str]],
+                 filter_files: list[str] | None,
+                 nfilter_files: list[str] | None,
                  extractor: Extractor) -> None:
         logger.debug("Initializing %s: (threshold: %s, filter_files: %s, nfilter_files: %s "
                      "extractor: %s)", self.__class__.__name__, threshold, filter_files,
@@ -208,8 +208,8 @@ class Filter():
             logger.debug("Filter not selected. Exiting %s", self.__class__.__name__)
             return
 
-        self._embeddings: List[np.ndarray] = [np.array([]) for _ in self._filter_files]
-        self._nembeddings: List[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
+        self._embeddings: list[np.ndarray] = [np.array([]) for _ in self._filter_files]
+        self._nembeddings: list[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
         self._extractor = extractor
         self._get_embeddings()
@@ -243,7 +243,7 @@ class Filter():
         return retval
 
     @classmethod
-    def _files_from_folder(cls, input_location: List[str]) -> List[str]:
+    def _files_from_folder(cls, input_location: list[str]) -> list[str]:
         """ Test whether the input location is a folder and if so, return the list of contained
         image files, otherwise return the original input location
@@ -274,8 +274,8 @@ class Filter():
         return retval
 
     def _validate_inputs(self,
-                         filter_files: Optional[List[str]],
-                         nfilter_files: Optional[List[str]]) -> Tuple[List[str], List[str]]:
+                         filter_files: list[str] | None,
+                         nfilter_files: list[str] | None) -> tuple[list[str], list[str]]:
         """ Validates that the given filter/nfilter files exist, are image files and are unique
 
         Parameters
@@ -293,7 +293,7 @@ class Filter():
             List of full paths to nfilter files
         """
         error = False
-        retval: List[List[str]] = []
+        retval: list[list[str]] = []
         for files in (filter_files, nfilter_files):
             filt_files = [] if files is None else self._files_from_folder(files)
@@ -322,7 +322,7 @@ class Filter():
         return filters, nfilters
 
     @classmethod
-    def _identity_from_extracted(cls, filename) -> Tuple[np.ndarray, bool]:
+    def _identity_from_extracted(cls, filename) -> tuple[np.ndarray, bool]:
         """ Test whether the given image is a faceswap extracted face and contains identity
         information. If so, return the identity embedding
@@ -404,7 +404,7 @@ class Filter():
             embeddings[idx] = identities
             return
 
-    def _identity_from_extractor(self, file_list: List[str], aligned: List[str]) -> None:
+    def _identity_from_extractor(self, file_list: list[str], aligned: list[str]) -> None:
         """ Obtain the identity embeddings from the extraction pipeline
 
         Parameters
@@ -425,7 +425,7 @@ class Filter():
         for phase in range(self._extractor.passes):
             is_final = self._extractor.final_pass
-            detected_faces: Dict[str, ExtractMedia] = {}
+            detected_faces: dict[str, ExtractMedia] = {}
             self._extractor.launch()
             desc = "Obtaining reference face Identity"
             if self._extractor.passes > 1:
@@ -450,8 +450,8 @@ class Filter():
     def _get_embeddings(self) -> None:
         """ Obtain the embeddings for the given filter lists """
-        needs_extraction: List[str] = []
-        aligned: List[str] = []
+        needs_extraction: list[str] = []
+        aligned: list[str] = []
 
         for files, embed in zip((self._filter_files, self._nfilter_files),
                                 (self._embeddings, self._nembeddings)):
@@ -494,14 +494,14 @@ class PipelineLoader():
         image files that exist in :attr:`path` that are aligned faceswap images
     """
     def __init__(self,
-                 path: Union[str, List[str]],
+                 path: str | list[str],
                  extractor: Extractor,
-                 aligned_filenames: Optional[List[str]] = None) -> None:
+                 aligned_filenames: list[str] | None = None) -> None:
         logger.debug("Initializing %s: (path: %s, extractor: %s, aligned_filenames: %s)",
                      self.__class__.__name__, path, extractor, aligned_filenames)
         self._images = ImagesLoader(path, fast_count=True)
         self._extractor = extractor
-        self._threads: List[MultiThread] = []
+        self._threads: list[MultiThread] = []
         self._aligned_filenames = [] if aligned_filenames is None else aligned_filenames
         logger.debug("Initialized %s", self.__class__.__name__)
@@ -512,7 +512,7 @@ class PipelineLoader():
         return self._images.is_video
 
     @property
-    def file_list(self) -> List[str]:
+    def file_list(self) -> list[str]:
         """ list: A full list of files in the source location. If the input is a video
         then this is a list of dummy filenames as corresponding to an alignments file """
         return self._images.file_list
@@ -523,7 +523,7 @@ class PipelineLoader():
         items that are to be skipped from the :attr:`skip_list`)"""
         return self._images.process_count
 
-    def add_skip_list(self, skip_list: List[int]) -> None:
+    def add_skip_list(self, skip_list: list[int]) -> None:
         """ Add a skip list to the :class:`ImagesLoader`
 
         Parameters
@@ -538,7 +538,7 @@ class PipelineLoader():
         """ Launch the image loading pipeline """
         self._threaded_redirector("load")
 
-    def reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
+    def reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
         """ Reload images for multiple pipeline passes """
         self._threaded_redirector("reload", (detected_faces, ))
@@ -552,7 +552,7 @@ class PipelineLoader():
         for thread in self._threads:
             thread.join()
 
-    def _threaded_redirector(self, task: str, io_args: Optional[tuple] = None) -> None:
+    def _threaded_redirector(self, task: str, io_args: tuple | None = None) -> None:
         """ Redirect image input/output tasks to relevant queues in background thread
 
         Parameters
@@ -587,7 +587,7 @@ class PipelineLoader():
         load_queue.put("EOF")
         logger.debug("Load Images: Complete")
 
-    def _reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
+    def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
         """ Reload the images and pair to detected face
 
         When the extraction pipeline is running in serial mode, images are reloaded from disk,
@@ -652,7 +652,7 @@ class _Extract():  # pylint:disable=too-few-public-methods
         logger.debug("Initialized %s", self.__class__.__name__)
 
     @property
-    def _save_interval(self) -> Optional[int]:
+    def _save_interval(self) -> int | None:
         """ int: The number of frames to be processed between each saving of the alignments file if
         it has been provided, otherwise ``None`` """
         if hasattr(self._args, "save_interval"):
@@ -718,7 +718,7 @@ class _Extract():  # pylint:disable=too-few-public-methods
                                                                       as_bytes=True)
         for phase in range(self._extractor.passes):
             is_final = self._extractor.final_pass
-            detected_faces: Dict[str, ExtractMedia] = {}
+            detected_faces: dict[str, ExtractMedia] = {}
             self._extractor.launch()
             self._loader.check_thread_error()
             ph_desc = "Extraction" if self._extractor.passes == 1 else self._extractor.phase_text
@@ -774,7 +774,7 @@ class _Extract():  # pylint:disable=too-few-public-methods
         if not self._verify_output and faces_count > 1:
             self._verify_output = True
 
-    def _output_faces(self, saver: Optional[ImagesSaver], extract_media: ExtractMedia) -> None:
+    def _output_faces(self, saver: ImagesSaver | None, extract_media: ExtractMedia) -> None:
         """ Output faces to save thread
 
         Set the face filename based on the frame name and put the face to the
@@ -798,14 +798,14 @@ class _Extract():  # pylint:disable=too-few-public-methods
             output_filename = f"{filename}_{real_face_id}.png"
             aligned = face.aligned.face
             assert aligned is not None
-            meta: PNGHeaderDict = dict(
-                alignments=face.to_png_meta(),
-                source=dict(alignments_version=self._alignments.version,
-                            original_filename=output_filename,
-                            face_index=real_face_id,
-                            source_filename=os.path.basename(extract_media.filename),
-                            source_is_video=self._loader.is_video,
-                            source_frame_dims=extract_media.image_size))
+            meta: PNGHeaderDict = {
+                "alignments": face.to_png_meta(),
+                "source": {"alignments_version": self._alignments.version,
+                           "original_filename": output_filename,
+                           "face_index": real_face_id,
+                           "source_filename": os.path.basename(extract_media.filename),
+                           "source_is_video": self._loader.is_video,
+                           "source_frame_dims": extract_media.image_size}}
             image = encode_image(aligned, ".png", metadata=meta)
 
             sub_folder = extract_media.sub_folders[face_id]
@@ -820,6 +820,6 @@ class _Extract():  # pylint:disable=too-few-public-methods
                 continue
             final_faces.append(face.to_alignment())
-        self._alignments.data[os.path.basename(extract_media.filename)] = dict(faces=final_faces,
-                                                                               video_meta={})
+        self._alignments.data[os.path.basename(extract_media.filename)] = {"faces": final_faces,
+                                                                           "video_meta": {}}
         del extract_media
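The last two hunks make a separate style change: `dict()` keyword-argument calls become dict literals. For a TypedDict target such as PNGHeaderDict this is more than cosmetic, since type checkers like mypy match literal keys against the declared TypedDict fields, whereas a plain `dict(...)` call is treated as an ordinary dict that does not satisfy the TypedDict annotation. A self-contained sketch of the trade-off follows; SourceDict is a hypothetical stand-in, not faceswap's actual PNGHeaderDict definition:

import typing as T


class SourceDict(T.TypedDict):
    """ Hypothetical stand-in for the PNG header source metadata """
    original_filename: str
    face_index: int


# The keyword-call style (old) and the literal style (new) build equal dicts at
# runtime, but only the literal is key-checked against SourceDict by a type checker
old = dict(original_filename="frame_0.png", face_index=0)
new: SourceDict = {"original_filename": "frame_0.png", "face_index": 0}
assert old == new

Dict literals also accept keys that are not valid Python identifiers, which the keyword-call form cannot express at all.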

Some files were not shown because too many files have changed in this diff.