Mirror of https://github.com/deepfakes/faceswap (synced 2025-06-06 17:45:56 -04:00)
Rebase code (#1326)
* Remove tensorflow_probability requirement
* setup.py - fix progress bars
* requirements.txt: Remove pre python 3.9 packages
* update apple requirements.txt
* update INSTALL.md
* Remove python<3.9 code
* setup.py - fix Windows Installer
* typing: python3.9 compliant
* Update pytest and readthedocs python versions
* typing fixes
* Python Version updates
  - Reduce max version to 3.10
  - Default to 3.10 in installers
  - Remove incompatible 3.11 tests
* Update dependencies
* Downgrade imageio dep for Windows
* typing: merge optional unions and fixes
* Updates
  - min python version 3.10
  - typing to python 3.10 spec
  - remove pre-tf2.10 code
  - Add conda tests
* train: re-enable optimizer saving
* Update dockerfiles
* Update setup.py
  - Apple Conda deps to setup.py
  - Better Cuda + dependency handling
* bugfix: Patch logging to prevent Autograph errors
* Update dockerfiles
* Setup.py - stdout to utf-8
* Add more OSes to github Actions
* suppress mac-os end to end test
Parent: e4ba12ad2a
Commit: 6a3b674bef
130 changed files with 3035 additions and 3028 deletions
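Much of the diff below is a mechanical migration of the type hints to the Python 3.10 style called out in the commit message: builtin generics, `X | None` unions, and a single `import typing as T` alias for the remaining `typing` names. A minimal before/after sketch of the pattern — the function and names here are illustrative, not taken from the repository:

```python
# Pre-3.10 style, removed throughout this commit
from typing import Dict, Optional, Tuple

def bbox_lookup(data: Dict[str, Tuple[int, int]], key: str) -> Optional[Tuple[int, int]]:
    """Return the (width, height) stored for key, or None if absent."""
    return data.get(key)

# Python 3.10 style, added throughout this commit
def bbox_lookup_new(data: dict[str, tuple[int, int]], key: str) -> tuple[int, int] | None:
    """Same behaviour, using builtin generics and the X | None union syntax."""
    return data.get(key)
```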
87 .github/workflows/pytest.yml vendored
@@ -8,17 +8,85 @@ on:
- "**/README.md"

jobs:
build_linux:
build_conda:
name: conda (${{ matrix.os }}, ${{ matrix.backend }})
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash -el {0}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
backend: ["nvidia", "cpu"]
include:
- os: "ubuntu-latest"
backend: "rocm"
- os: "windows-latest"
backend: "directml"
exclude:
# pynvx does not currently build for Python3.10 and without CUDA it may not build at all
- os: "macos-latest"
backend: "nvidia"
steps:
- uses: actions/checkout@v3
- name: Set cache date
run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV
- name: Cache conda
uses: actions/cache@v3
env:
# Increase this value to manually reset cache
CACHE_NUMBER: 1
REQ_FILE: ./requirements/requirements_${{ matrix.backend }}.txt
with:
path: ~/conda_pkgs_dir
key: ${{ runner.os }}-${{ matrix.backend }}-conda-${{ env.CACHE_NUMBER }}-${{ env.DATE }}-${{ hashFiles('./requirements/requirements.txt', env.REQ_FILE) }}
- name: Set up Conda
uses: conda-incubator/setup-miniconda@v2
with:
python-version: "3.10"
auto-update-conda: true
activate-environment: faceswap
- name: Conda info
run: conda info && conda list
- name: Install
run: |
python setup.py --installer --${{ matrix.backend }}
pip install flake8 pylint mypy pytest pytest-mock wheel pytest-xvfb
pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --select=E9,F63,F7,F82 --show-source
flake8 . --exit-zero
- name: MyPy Typing
continue-on-error: true
run: |
mypy .
- name: SysInfo
run: python -c "from lib.sysinfo import sysinfo ; print(sysinfo)"
- name: Simple Tests
# These backends will fail as GPU drivers not available
if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml'
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
- name: End to End Tests
# These backends will fail as GPU drivers not available
# macOS fails on first extract test with 'died with <Signals.SIGSEGV: 11>'
if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml' && matrix.os != 'macos-latest'
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;

build_linux:
name: "pip (ubuntu-latest, ${{ matrix.backend }})"
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.10"]
backend: ["cpu"]
include:
- kbackend: "tensorflow"
backend: "cpu"
- backend: "cpu"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -33,6 +101,8 @@ jobs:
pip install flake8 pylint mypy pytest pytest-mock pytest-xvfb wheel
pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
- name: List installed packages
run: pip freeze
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -45,17 +115,18 @@ jobs:
mypy .
- name: Simple Tests
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" py.test -v tests/;
FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/;
- name: End to End Tests
run: |
FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" python tests/simple_tests.py;
FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py;

build_windows:
name: "pip (windows-latest, ${{ matrix.backend }})"
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9"]
python-version: ["3.10"]
backend: ["cpu", "directml"]
include:
- backend: "cpu"
@@ -74,6 +145,8 @@ jobs:
pip install flake8 pylint mypy pytest pytest-mock wheel
pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools
pip install -r ./requirements/requirements_${{ matrix.backend }}.txt
- name: List installed packages
run: pip freeze
- name: Set Backend EnvVar
run: echo "FACESWAP_BACKEND=${{ matrix.backend }}" | Out-File -FilePath $env:GITHUB_ENV -Append
- name: Lint with flake8
@@ -12,7 +12,7 @@ DIR_CONDA="$HOME/miniconda3"
CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda"
CONDA_TO_PATH=false
ENV_NAME="faceswap"
PYENV_VERSION="3.9"
PYENV_VERSION="3.10"

DIR_FACESWAP="$HOME/faceswap"
VERSION="nvidia"
@@ -363,7 +363,7 @@ delete_env() {
}

create_env() {
# Create Python 3.8 env for faceswap
# Create Python 3.10 env for faceswap
delete_env
info "Creating Conda Virtual Environment..."
yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -q python="$PYENV_VERSION" -y

@@ -22,7 +22,7 @@ InstallDir $PROFILE\faceswap
# Install cli flags
!define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$PROFILE\MiniConda3"
!define flagsRepo "--depth 1 --no-single-branch ${wwwRepo}"
!define flagsEnv "-y python=3.9"
!define flagsEnv "-y python=3.10"

# Folders
Var ProgramData

@@ -7,9 +7,9 @@ version: 2

# Set the version of Python and other tools you might need
build:
os: ubuntu-20.04
os: ubuntu-22.04
tools:
python: "3.8"
python: "3.10"

# Build documentation in the docs/ directory with Sphinx
sphinx:

@@ -1,19 +1,19 @@
FROM tensorflow/tensorflow:2.8.2
FROM ubuntu:22.04

# To disable tzdata and others from asking for input
ENV DEBIAN_FRONTEND noninteractive
ENV FACESWAP_BACKEND cpu

RUN apt-get update -qq -y \
&& apt-get install -y software-properties-common \
&& add-apt-repository -y ppa:jonathonf/ffmpeg-4 \
&& apt-get update -qq -y \
&& apt-get install -y libsm6 libxrender1 libxext-dev python3-tk ffmpeg git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update -qq -y
RUN apt-get upgrade -y
RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git

COPY ./requirements/_requirements_base.txt /opt/
RUN pip3 install --upgrade pip
RUN pip3 --no-cache-dir install -r /opt/_requirements_base.txt && rm /opt/_requirements_base.txt
RUN ln -s $(which python3) /usr/local/bin/python

RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
WORKDIR "/faceswap"

RUN python -m pip install --upgrade pip
RUN python -m pip --no-cache-dir install -r ./requirements/requirements_cpu.txt

WORKDIR "/srv"
CMD ["/bin/bash"]

@@ -1,29 +1,19 @@
FROM nvidia/cuda:11.7.0-runtime-ubuntu18.04
ARG DEBIAN_FRONTEND=noninteractive
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

#install python3.8
RUN apt-get update
RUN apt-get install software-properties-common -y
RUN add-apt-repository ppa:deadsnakes/ppa -y
RUN apt-get update
RUN apt-get install python3.8 -y
RUN apt-get install python3.8-distutils -y
RUN apt-get install python3.8-tk -y
RUN apt-get install curl -y
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
RUN python3.8 get-pip.py
RUN rm get-pip.py
ENV DEBIAN_FRONTEND=noninteractive
ENV FACESWAP_BACKEND nvidia

# install requirements
RUN apt-get install ffmpeg git -y
COPY ./requirements/_requirements_base.txt /opt/
COPY ./requirements/requirements_nvidia.txt /opt/
RUN python3.8 -m pip --no-cache-dir install -r /opt/requirements_nvidia.txt && rm /opt/_requirements_base.txt && rm /opt/requirements_nvidia.txt
RUN apt-get update -qq -y
RUN apt-get upgrade -y
RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git

RUN python3.8 -m pip install jupyter matplotlib tqdm
RUN python3.8 -m pip install jupyter_http_over_ws
RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN alias python=python3.8
RUN echo "alias python=python3.8" >> /root/.bashrc
WORKDIR "/notebooks"
CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"]
RUN ln -s $(which python3) /usr/local/bin/python

RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git
WORKDIR "/faceswap"

RUN python -m pip install --upgrade pip
RUN python -m pip install --upgrade pip
RUN python -m pip --no-cache-dir install -r ./requirements/requirements_nvidia.txt

CMD ["/bin/bash"]
217 INSTALL.md
@@ -39,12 +39,9 @@
- [Setup](#setup-2)
- [About some of the options](#about-some-of-the-options)
- [Docker Install Guide](#docker-install-guide)
- [Docker General](#docker-general)
- [CUDA with Docker in 20 minutes.](#cuda-with-docker-in-20-minutes)
- [CUDA with Docker on Arch Linux](#cuda-with-docker-on-arch-linux)
- [Install docker](#install-docker)
- [A successful setup log, without docker.](#a-successful-setup-log-without-docker)
- [Run the project](#run-the-project)
- [Docker CPU](#docker-cpu)
- [Docker Nvidia](#docker-nvidia)
- [Run the project](#run-the-project)
- [Notes](#notes)

# Prerequisites

@@ -115,7 +112,7 @@ Reboot your PC, so that everything you have just installed gets registered.
- Select "Create" at the bottom
- In the pop up:
- Give it the name: faceswap
- **IMPORTANT**: Select python version 3.8
- **IMPORTANT**: Select python version 3.10
- Hit "Create" (NB: This may take a while as it will need to download Python)

@@ -195,7 +192,7 @@ $ source ~/miniforge3/bin/activate
## Setup
### Create and Activate the Environment
```sh
$ conda create --name faceswap python=3.9
$ conda create --name faceswap python=3.10
$ conda activate faceswap
```

@@ -225,7 +222,7 @@ Obtain git for your distribution from the [git website](https://git-scm.com/down
The recommended install method is to use a Conda3 Environment as this will handle the installation of Nvidia's CUDA and cuDNN straight into your Conda Environment. This is by far the easiest and most reliable way to setup the project.
- MiniConda3 is recommended: [MiniConda3](https://docs.conda.io/en/latest/miniconda.html)

Alternatively you can install Python (>= 3.7-3.9 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn). for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
Alternatively you can install Python (3.10 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn). for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu).
- Python distributions:
- apt/yum install python3 (Linux)
- [Installer](https://www.python.org/downloads/release/python-368/) (Windows)

@@ -260,153 +257,83 @@ If setup fails for any reason you can still manually install the packages listed

# Docker Install Guide

## Docker General
<details>
<summary>Click to expand!</summary>
This Faceswap repo contains Docker build scripts for CPU and Nvidia backends. The scripts will set up a Docker container for you and install the latest version of the Faceswap software.

### CUDA with Docker in 20 minutes.

1. Install Docker
https://www.docker.com/community-edition
You must first ensure that Docker is installed and running on your system. Follow the guide for downloading and installing Docker from their website:

2. Install Nvidia-Docker & Restart Docker Service
https://github.com/NVIDIA/nvidia-docker
- https://www.docker.com/get-started

3. Build Docker Image For faceswap

```bash
docker build -t deepfakes-gpu -f Dockerfile.gpu .
```
Once Docker is installed and running, follow the relevant steps for your chosen backend
## Docker CPU
To run the CPU version of Faceswap follow these steps:

4. Mount faceswap volume and Run it
a). without `gui.tools.py` gui not working.

```bash
nvidia-docker run --rm -it -p 8888:8888 \
--hostname faceswap-gpu --name faceswap-gpu \
-v /opt/faceswap:/srv \
deepfakes-gpu
```

b). with gui. tools.py gui working.

Enable local access to X11 server

```bash
xhost +local:
1. Build the Docker image For faceswap:
```

Enable nvidia device if working under bumblebee

```bash
echo ON > /proc/acpi/bbswitch
docker build \
-t faceswap-cpu \
https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.cpu
```
2. Launch and enter the Faceswap container:

Create container
```bash
nvidia-docker run -p 8888:8888 \
--hostname faceswap-gpu --name faceswap-gpu \
-v /opt/faceswap:/srv \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=unix$DISPLAY \
-e AUDIO_GID=`getent group audio | cut -d: -f3` \
-e VIDEO_GID=`getent group video | cut -d: -f3` \
-e GID=`id -g` \
-e UID=`id -u` \
deepfakes-gpu
a. For the **headless/command line** version of Faceswap run:
```
docker run --rm -it faceswap-cpu
```
You can then execute faceswap the standard way:
```
python faceswap.py --help
```
b. For the **GUI** version of Faceswap run:
```
xhost +local: && \
docker run --rm -it \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=${DISPLAY} \
faceswap-cpu
```
You can then launch the GUI with
```
python faceswap.py gui
```
## Docker Nvidia
To build the NVIDIA GPU version of Faceswap, follow these steps:

1. Nvidia Docker builds need extra resources to provide the Docker container with access to your GPU.

a. Follow the instructions to install and apply the `Nvidia Container Toolkit` for your distribution from:
- https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

b. If Docker is already running, restart it to pick up the changes made by the Nvidia Container Toolkit.

2. Build the Docker image For faceswap
```

Open a new terminal to interact with the project

```bash
docker exec -it deepfakes-gpu /bin/bash
docker build \
-t faceswap-gpu \
https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.gpu
```
1. Launch and enter the Faceswap container:

Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)

```bash
python3.8 /srv/faceswap.py gui
```
</details>

## CUDA with Docker on Arch Linux

<details>
<summary>Click to expand!</summary>

### Install docker

```bash
sudo pacman -S docker
```

The steps are same but Arch linux doesn't use nvidia-docker

create container

```bash
docker run -p 8888:8888 --gpus all --privileged -v /dev:/dev \
--hostname faceswap-gpu --name faceswap-gpu \
-v /mnt/hdd2/faceswap:/srv \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=unix$DISPLAY \
-e AUDIO_GID=`getent group audio | cut -d: -f3` \
-e VIDEO_GID=`getent group video | cut -d: -f3` \
-e GID=`id -g` \
-e UID=`id -u` \
deepfakes-gpu
```

Open a new terminal to interact with the project

```bash
docker exec -it deepfakes-gpu /bin/bash
```

Launch deepfakes gui (Answer 3 for NVIDIA at the prompt)

**With `gui.tools.py` gui working.**
Enable local access to X11 server

```bash
xhost +local:
```

```bash
python3.8 /srv/faceswap.py gui
```

</details>

---
## A successful setup log, without docker.
```
INFO The tool provides tips for installation
and installs required python packages
INFO Setup in Linux 4.14.39-1-MANJARO
INFO Installed Python: 3.7.5 64bit
INFO Installed PIP: 10.0.1
Enable Docker? [Y/n] n
INFO Docker Disabled
Enable CUDA? [Y/n]
INFO CUDA Enabled
INFO CUDA version: 9.1
INFO cuDNN version: 7
WARNING Tensorflow has no official prebuild for CUDA 9.1 currently.
To continue, You have to build your own tensorflow-gpu.
Help: https://www.tensorflow.org/install/install_sources
Are System Dependencies met? [y/N] y
INFO Installing Missing Python Packages...
INFO Installing tensorflow-gpu
......
INFO Installing tqdm
INFO Installing matplotlib
INFO All python3 dependencies are met.
You are good to go.
```

## Run the project
a. For the **headless/command line** version of Faceswap run:
```
docker run --runtime=nvidia --rm -it faceswap-gpu
```
You can then execute faceswap the standard way:
```
python faceswap.py --help
```
b. For the **GUI** version of Faceswap run:
```
xhost +local: && \
docker run --runtime=nvidia --rm -it \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=${DISPLAY} \
faceswap-gpu
```
You can then launch the GUI with
```
python faceswap.py gui
```
# Run the project
Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options.

```bash
@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath('../'))
sys.setrecursionlimit(1500)


MOCK_MODULES = ["plaidml", "pynvx", "ctypes.windll", "comtypes"]
MOCK_MODULES = ["pynvx", "ctypes.windll", "comtypes"]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()

@@ -1,25 +1,21 @@
# NB Do not install from this requirements file
# It is for documentation purposes only

sphinx==5.0.2
sphinx_rtd_theme==1.0.0
tqdm==4.64
psutil==5.8.0
numexpr>=2.8.3
numpy>=1.18.0
opencv-python>=4.5.5.0
pillow==8.3.1
scikit-learn>=1.0.2
fastcluster>=1.2.4
matplotlib==3.5.1
numexpr
imageio==2.9.0
imageio-ffmpeg==0.4.7
ffmpy==0.2.3
nvidia-ml-py<11.515
plaidml==0.7.0
sphinx==7.0.1
sphinx_rtd_theme==1.2.2
tqdm==4.65
psutil==5.9.0
numexpr>=2.8.4
numpy>=1.25.0
opencv-python>=4.7.0.0
pillow==9.4.0
scikit-learn>=1.2.2
fastcluster>=1.2.6
matplotlib==3.7.1
imageio==2.31.1
imageio-ffmpeg==0.4.8
ffmpy==0.3.0
nvidia-ml-py==11.525
pytest==7.2.0
pytest-mock==3.10.0
tensorflow>=2.8.0,<2.9.0
tensorflow_probability<0.17
typing-extensions>=4.0.0
tensorflow>=2.10.0,<2.11.0

@@ -16,9 +16,8 @@ from lib.config import generate_configs # pylint:disable=wrong-import-position
_LANG = gettext.translation("faceswap", localedir="locales", fallback=True)
_ = _LANG.gettext

if sys.version_info < (3, 7):
raise ValueError("This program requires at least python3.7")

if sys.version_info < (3, 10):
raise ValueError("This program requires at least python 3.10")

_PARSER = cli_args.FullHelpArgumentParser()
@@ -3,22 +3,14 @@

from dataclasses import dataclass, field
import logging
import sys
import typing as T
from threading import Lock
from typing import cast, Dict, Optional, Tuple


import cv2
import numpy as np

if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

CenteringType = Literal["face", "head", "legacy"]
CenteringType = T.Literal["face", "head", "legacy"]

_MEAN_FACE = np.array([[0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748],
[0.300643, 0.034489], [0.403270, 0.077391], [0.596729, 0.077391],
@@ -65,10 +57,10 @@ _MEAN_FACE_3D = np.array([[4.056931, -11.432347, 1.636229], # 8 chin LL
[0.0, -8.601736, 6.097667], # 45 mouth bottom C
[0.589441, -8.443925, 6.109526]]) # 44 mouth bottom L

_EXTRACT_RATIOS = dict(legacy=0.375, face=0.5, head=0.625)
_EXTRACT_RATIOS = {"legacy": 0.375, "face": 0.5, "head": 0.625}


def get_matrix_scaling(matrix: np.ndarray) -> Tuple[int, int]:
def get_matrix_scaling(matrix: np.ndarray) -> tuple[int, int]:
""" Given a matrix, return the cv2 Interpolation method and inverse interpolation method for
applying the matrix on an image.

@@ -213,6 +205,149 @@ def get_centered_size(source_centering: CenteringType,
return retval


class PoseEstimate():
""" Estimates pose from a generic 3D head model for the given 2D face landmarks.

Parameters
----------
landmarks: :class:`numpy.ndarry`
The original 68 point landmarks aligned to 0.0 - 1.0 range

References
----------
Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
"""
def __init__(self, landmarks: np.ndarray) -> None:
self._distortion_coefficients = np.zeros((4, 1)) # Assuming no lens distortion
self._xyz_2d: np.ndarray | None = None

self._camera_matrix = self._get_camera_matrix()
self._rotation, self._translation = self._solve_pnp(landmarks)
self._offset = self._get_offset()
self._pitch_yaw_roll: tuple[float, float, float] = (0, 0, 0)

@property
def xyz_2d(self) -> np.ndarray:
""" :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
if self._xyz_2d is None:
xyz = cv2.projectPoints(np.array([[6., 0., -2.3],
[0., 6., -2.3],
[0., 0., 3.7]]).astype("float32"),
self._rotation,
self._translation,
self._camera_matrix,
self._distortion_coefficients)[0].squeeze()
self._xyz_2d = xyz - self._offset["head"]
return self._xyz_2d

@property
def offset(self) -> dict[CenteringType, np.ndarray]:
""" dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a
from the center of the face (between the eyes) or center of the head (middle of skull)
rather than the nose area. """
return self._offset

@property
def pitch(self) -> float:
""" float: The pitch of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[0]

@property
def yaw(self) -> float:
""" float: The yaw of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[1]

@property
def roll(self) -> float:
""" float: The roll of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[2]

def _get_pitch_yaw_roll(self) -> None:
""" Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """
proj_matrix = np.zeros((3, 4), dtype="float32")
proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
self._pitch_yaw_roll = T.cast(tuple[float, float, float], tuple(euler.squeeze()))
logger.trace("yaw_pitch: %s", self._pitch_yaw_roll) # type: ignore

@classmethod
def _get_camera_matrix(cls) -> np.ndarray:
""" Obtain an estimate of the camera matrix based off the original frame dimensions.

Returns
-------
:class:`numpy.ndarray`
An estimated camera matrix
"""
focal_length = 4
camera_matrix = np.array([[focal_length, 0, 0.5],
[0, focal_length, 0.5],
[0, 0, 1]], dtype="double")
logger.trace("camera_matrix: %s", camera_matrix) # type: ignore
return camera_matrix

def _solve_pnp(self, landmarks: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
""" Solve the Perspective-n-Point for the given landmarks.

Takes 2D landmarks in world space and estimates the rotation and translation vectors
in 3D space.

Parameters
----------
landmarks: :class:`numpy.ndarry`
The original 68 point landmark co-ordinates relating to the original frame

Returns
-------
rotation: :class:`numpy.ndarray`
The solved rotation vector
translation: :class:`numpy.ndarray`
The solved translation vector
"""
points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
_, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
points,
self._camera_matrix,
self._distortion_coefficients,
flags=cv2.SOLVEPNP_ITERATIVE)
logger.trace("points: %s, rotation: %s, translation: %s", # type: ignore
points, rotation, translation)
return rotation, translation

def _get_offset(self) -> dict[CenteringType, np.ndarray]:
""" Obtain the offset between the original center of the extracted face to the new center
of the head in 2D space.

Returns
-------
:class:`numpy.ndarray`
The x, y offset of the new center from the old center.
"""
offset: dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])}
points: dict[T.Literal["face", "head"], tuple[float, ...]] = {"head": (0.0, 0.0, -2.3),
"face": (0.0, -1.5, 4.2)}

for key, pnts in points.items():
center = cv2.projectPoints(np.array([pnts]).astype("float32"),
self._rotation,
self._translation,
self._camera_matrix,
self._distortion_coefficients)[0].squeeze()
logger.trace("center %s: %s", key, center) # type: ignore
offset[key] = center - (0.5, 0.5)
logger.trace("offset: %s", offset) # type: ignore
return offset

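A minimal usage sketch for the `PoseEstimate` class added above (the landmark array here is a random placeholder and the import path is assumed from the file being patched, `lib/align/aligned_face.py`; real callers pass the 68-point landmarks normalised to the 0.0 - 1.0 range):

```python
import numpy as np
# from lib.align.aligned_face import PoseEstimate  # import path assumed

landmarks = np.random.rand(68, 2).astype("float32")  # placeholder 68 (x, y) points in 0-1 space
pose = PoseEstimate(landmarks)
print(pose.pitch, pose.yaw, pose.roll)   # Euler angles recovered via cv2.solvePnP / Rodrigues
print(pose.offset["face"])               # 2D offset from nose-centred to face-centred cropping
```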
@dataclass
class _FaceCache: # pylint:disable=too-many-instance-attributes
""" Cache for storing items related to a single aligned face.
@@ -251,19 +386,19 @@ class _FaceCache: # pylint:disable=too-many-instance-attributes
cropped_slices: dict, optional
The slices for an input full head image and output cropped image. Default: `{}`
"""
pose: Optional["PoseEstimate"] = None
original_roi: Optional[np.ndarray] = None
landmarks: Optional[np.ndarray] = None
landmarks_normalized: Optional[np.ndarray] = None
pose: PoseEstimate | None = None
original_roi: np.ndarray | None = None
landmarks: np.ndarray | None = None
landmarks_normalized: np.ndarray | None = None
average_distance: float = 0.0
relative_eye_mouth_position: float = 0.0
adjusted_matrix: Optional[np.ndarray] = None
interpolators: Tuple[int, int] = (0, 0)
cropped_roi: Dict[CenteringType, np.ndarray] = field(default_factory=dict)
cropped_slices: Dict[CenteringType, Dict[Literal["in", "out"],
Tuple[slice, slice]]] = field(default_factory=dict)
adjusted_matrix: np.ndarray | None = None
interpolators: tuple[int, int] = (0, 0)
cropped_roi: dict[CenteringType, np.ndarray] = field(default_factory=dict)
cropped_slices: dict[CenteringType, dict[T.Literal["in", "out"],
tuple[slice, slice]]] = field(default_factory=dict)

_locks: Dict[str, Lock] = field(default_factory=dict)
_locks: dict[str, Lock] = field(default_factory=dict)

def __post_init__(self):
""" Initialize the locks for the class parameters """
@@ -322,11 +457,11 @@ class AlignedFace():
"""
def __init__(self,
landmarks: np.ndarray,
image: Optional[np.ndarray] = None,
image: np.ndarray | None = None,
centering: CenteringType = "face",
size: int = 64,
coverage_ratio: float = 1.0,
dtype: Optional[str] = None,
dtype: str | None = None,
is_aligned: bool = False,
is_legacy: bool = False) -> None:
logger.trace("Initializing: %s (image shape: %s, centering: '%s', " # type: ignore
@@ -340,9 +475,9 @@ class AlignedFace():
self._dtype = dtype
self._is_aligned = is_aligned
self._source_centering: CenteringType = "legacy" if is_legacy and is_aligned else "head"
self._matrices = dict(legacy=_umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
face=np.array([]),
head=np.array([]))
self._matrices = {"legacy": _umeyama(landmarks[17:], _MEAN_FACE, True)[0:2],
"face": np.array([]),
"head": np.array([])}
self._padding = self._padding_from_coverage(size, coverage_ratio)

self._cache = _FaceCache()
@@ -353,7 +488,7 @@ class AlignedFace():
self._face if self._face is None else self._face.shape)

@property
def centering(self) -> Literal["legacy", "head", "face"]:
def centering(self) -> T.Literal["legacy", "head", "face"]:
""" str: The centering of the Aligned Face. One of `"legacy"`, `"head"`, `"face"`. """
return self._centering

@@ -382,7 +517,7 @@ class AlignedFace():
return self._matrices[self._centering]

@property
def pose(self) -> "PoseEstimate":
def pose(self) -> PoseEstimate:
""" :class:`lib.align.PoseEstimate`: The estimated pose in 3D space. """
with self._cache.lock("pose"):
if self._cache.pose is None:
@@ -405,7 +540,7 @@ class AlignedFace():
return self._cache.adjusted_matrix

@property
def face(self) -> Optional[np.ndarray]:
def face(self) -> np.ndarray | None:
""" :class:`numpy.ndarray`: The aligned face at the given :attr:`size` at the specified
:attr:`coverage` in the given :attr:`dtype`. If an :attr:`image` has not been provided
then an the attribute will return ``None``. """
@@ -450,7 +585,7 @@ class AlignedFace():
return self._cache.landmarks_normalized

@property
def interpolators(self) -> Tuple[int, int]:
def interpolators(self) -> tuple[int, int]:
""" tuple: (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. """
with self._cache.lock("interpolators"):
if not any(self._cache.interpolators):
@@ -487,7 +622,7 @@ class AlignedFace():
return self._cache.relative_eye_mouth_position

@classmethod
def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> Dict[CenteringType, int]:
def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> dict[CenteringType, int]:
""" Return the image padding for a face from coverage_ratio set against a
pre-padded training image.

@@ -504,7 +639,7 @@ class AlignedFace():
The padding required, in pixels for 'head', 'face' and 'legacy' face types
"""
retval = {_type: round((size * (coverage_ratio - (1 - _EXTRACT_RATIOS[_type]))) / 2)
for _type in get_args(Literal["legacy", "face", "head"])}
for _type in T.get_args(T.Literal["legacy", "face", "head"])}
logger.trace(retval) # type: ignore
return retval

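A worked example of the `_padding_from_coverage` calculation above, using the `_EXTRACT_RATIOS` defined earlier in this file (the size and coverage values are illustrative):

```python
size, coverage_ratio = 512, 1.0
extract_ratios = {"legacy": 0.375, "face": 0.5, "head": 0.625}   # mirrors _EXTRACT_RATIOS
padding = {key: round((size * (coverage_ratio - (1 - ratio))) / 2)
           for key, ratio in extract_ratios.items()}
print(padding)   # {'legacy': 96, 'face': 128, 'head': 160}
```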
@@ -532,7 +667,7 @@ class AlignedFace():
invert, points, retval)
return retval

def extract_face(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]:
def extract_face(self, image: np.ndarray | None) -> np.ndarray | None:
""" Extract the face from a source image and populate :attr:`face`. If an image is not
provided then ``None`` is returned.

@@ -605,7 +740,7 @@ class AlignedFace():
def _get_cropped_slices(self,
image_size: int,
target_size: int,
) -> Dict[Literal["in", "out"], Tuple[slice, slice]]:
) -> dict[T.Literal["in", "out"], tuple[slice, slice]]:
""" Obtain the slices to turn a full head extract into an alternatively centered extract.

Parameters
@@ -676,149 +811,6 @@ class AlignedFace():
return self._cache.cropped_roi[centering]


class PoseEstimate():
""" Estimates pose from a generic 3D head model for the given 2D face landmarks.

Parameters
----------
landmarks: :class:`numpy.ndarry`
The original 68 point landmarks aligned to 0.0 - 1.0 range

References
----------
Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/
3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp
"""
def __init__(self, landmarks: np.ndarray) -> None:
self._distortion_coefficients = np.zeros((4, 1)) # Assuming no lens distortion
self._xyz_2d: Optional[np.ndarray] = None

self._camera_matrix = self._get_camera_matrix()
self._rotation, self._translation = self._solve_pnp(landmarks)
self._offset = self._get_offset()
self._pitch_yaw_roll: Tuple[float, float, float] = (0, 0, 0)

@property
def xyz_2d(self) -> np.ndarray:
""" :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a
constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """
if self._xyz_2d is None:
xyz = cv2.projectPoints(np.array([[6., 0., -2.3],
[0., 6., -2.3],
[0., 0., 3.7]]).astype("float32"),
self._rotation,
self._translation,
self._camera_matrix,
self._distortion_coefficients)[0].squeeze()
self._xyz_2d = xyz - self._offset["head"]
return self._xyz_2d

@property
def offset(self) -> Dict[CenteringType, np.ndarray]:
""" dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a
from the center of the face (between the eyes) or center of the head (middle of skull)
rather than the nose area. """
return self._offset

@property
def pitch(self) -> float:
""" float: The pitch of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[0]

@property
def yaw(self) -> float:
""" float: The yaw of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[1]

@property
def roll(self) -> float:
""" float: The roll of the aligned face in eular angles """
if not any(self._pitch_yaw_roll):
self._get_pitch_yaw_roll()
return self._pitch_yaw_roll[2]

def _get_pitch_yaw_roll(self) -> None:
""" Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """
proj_matrix = np.zeros((3, 4), dtype="float32")
proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0]
euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1]
self._pitch_yaw_roll = cast(Tuple[float, float, float], tuple(euler.squeeze()))
logger.trace("yaw_pitch: %s", self._pitch_yaw_roll) # type: ignore

@classmethod
def _get_camera_matrix(cls) -> np.ndarray:
""" Obtain an estimate of the camera matrix based off the original frame dimensions.

Returns
-------
:class:`numpy.ndarray`
An estimated camera matrix
"""
focal_length = 4
camera_matrix = np.array([[focal_length, 0, 0.5],
[0, focal_length, 0.5],
[0, 0, 1]], dtype="double")
logger.trace("camera_matrix: %s", camera_matrix) # type: ignore
return camera_matrix

def _solve_pnp(self, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
""" Solve the Perspective-n-Point for the given landmarks.

Takes 2D landmarks in world space and estimates the rotation and translation vectors
in 3D space.

Parameters
----------
landmarks: :class:`numpy.ndarry`
The original 68 point landmark co-ordinates relating to the original frame

Returns
-------
rotation: :class:`numpy.ndarray`
The solved rotation vector
translation: :class:`numpy.ndarray`
The solved translation vector
"""
points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34,
35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]]
_, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D,
points,
self._camera_matrix,
self._distortion_coefficients,
flags=cv2.SOLVEPNP_ITERATIVE)
logger.trace("points: %s, rotation: %s, translation: %s", # type: ignore
points, rotation, translation)
return rotation, translation

def _get_offset(self) -> Dict[CenteringType, np.ndarray]:
""" Obtain the offset between the original center of the extracted face to the new center
of the head in 2D space.

Returns
-------
:class:`numpy.ndarray`
The x, y offset of the new center from the old center.
"""
offset: Dict[CenteringType, np.ndarray] = dict(legacy=np.array([0.0, 0.0]))
points: Dict[Literal["face", "head"], Tuple[float, ...]] = dict(head=(0.0, 0.0, -2.3),
face=(0.0, -1.5, 4.2))

for key, pnts in points.items():
center = cv2.projectPoints(np.array([pnts]).astype("float32"),
self._rotation,
self._translation,
self._camera_matrix,
self._distortion_coefficients)[0].squeeze()
logger.trace("center %s: %s", key, center) # type: ignore
offset[key] = center - (0.5, 0.5)
logger.trace("offset: %s", offset) # type: ignore
return offset


def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray:
"""Estimate N-D similarity transformation with or without scaling.


@@ -866,24 +858,24 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
if np.linalg.det(A) < 0:
d[dim - 1] = -1

T = np.eye(dim + 1, dtype=np.double)
retval = np.eye(dim + 1, dtype=np.double)

U, S, V = np.linalg.svd(A)

# Eq. (40) and (43).
rank = np.linalg.matrix_rank(A)
if rank == 0:
return np.nan * T
return np.nan * retval
if rank == dim - 1:
if np.linalg.det(U) * np.linalg.det(V) > 0:
T[:dim, :dim] = U @ V
retval[:dim, :dim] = U @ V
else:
s = d[dim - 1]
d[dim - 1] = -1
T[:dim, :dim] = U @ np.diag(d) @ V
retval[:dim, :dim] = U @ np.diag(d) @ V
d[dim - 1] = s
else:
T[:dim, :dim] = U @ np.diag(d) @ V
retval[:dim, :dim] = U @ np.diag(d) @ V

if estimate_scale:
# Eq. (41) and (42).
@@ -891,7 +883,7 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool)
else:
scale = 1.0

T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
T[:dim, :dim] *= scale
retval[:dim, dim] = dst_mean - scale * (retval[:dim, :dim] @ src_mean.T)
retval[:dim, :dim] *= scale

return T
return retval
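The `T` -> `retval` rename in `_umeyama` above presumably avoids shadowing the new module-level `import typing as T` alias. A small self-check sketch of the helper, which estimates the similarity transform mapping `source` points onto `destination` points (the point sets here are made up, and `_umeyama` is a module-private function of the file being patched):

```python
import numpy as np

src = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
dst = src * 2.0 + 0.25                               # scaled and shifted copy of the source points
matrix = _umeyama(src, dst, estimate_scale=True)     # 3x3 homogeneous similarity transform
aligned = (matrix[:2, :2] @ src.T).T + matrix[:2, 2]
print(np.allclose(aligned, dst))                     # should recover the exact transform: True
```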
@@ -1,24 +1,19 @@
#!/usr/bin/env python3
""" Alignments file functions for reading, writing and manipulating the data stored in a
serialized alignments file. """

from __future__ import annotations
import logging
import os
import sys
import typing as T
from datetime import datetime
from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

from lib.serializer import get_serializer, get_serializer_from_filename
from lib.utils import FaceswapError

if sys.version_info < (3, 8):
from typing_extensions import TypedDict
else:
from typing import TypedDict

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from .aligned_face import CenteringType

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -35,49 +30,49 @@ _VERSION = 2.3


# TODO Convert these to Dataclasses
class MaskAlignmentsFileDict(TypedDict):
class MaskAlignmentsFileDict(T.TypedDict):
""" Typed Dictionary for storing Masks. """
mask: bytes
affine_matrix: Union[List[float], np.ndarray]
affine_matrix: list[float] | np.ndarray
interpolator: int
stored_size: int
stored_centering: "CenteringType"
stored_centering: CenteringType


class PNGHeaderAlignmentsDict(TypedDict):
class PNGHeaderAlignmentsDict(T.TypedDict):
""" Base Dictionary for storing a single faces' Alignment Information in Alignments files and
PNG Headers. """
x: int
y: int
w: int
h: int
landmarks_xy: Union[List[float], np.ndarray]
mask: Dict[str, MaskAlignmentsFileDict]
identity: Dict[str, List[float]]
landmarks_xy: list[float] | np.ndarray
mask: dict[str, MaskAlignmentsFileDict]
identity: dict[str, list[float]]


class AlignmentFileDict(PNGHeaderAlignmentsDict):
""" Typed Dictionary for storing a single faces' Alignment Information in alignments files. """
thumb: Optional[np.ndarray]
thumb: np.ndarray | None


class PNGHeaderSourceDict(TypedDict):
class PNGHeaderSourceDict(T.TypedDict):
""" Dictionary for storing additional meta information in PNG headers """
alignments_version: float
original_filename: str
face_index: int
source_filename: str
source_is_video: bool
source_frame_dims: Optional[Tuple[int, int]]
source_frame_dims: tuple[int, int] | None


class AlignmentDict(TypedDict):
class AlignmentDict(T.TypedDict):
""" Dictionary for holding all of the alignment information within a single alignment file """
faces: List[AlignmentFileDict]
video_meta: Dict[str, Union[float, int]]
faces: list[AlignmentFileDict]
video_meta: dict[str, float | int]


class PNGHeaderDict(TypedDict):
class PNGHeaderDict(T.TypedDict):
""" Dictionary for storing all alignment and meta information in PNG Headers """
alignments: PNGHeaderAlignmentsDict
source: PNGHeaderSourceDict
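An illustrative, hand-written instance of the per-face structure described by the TypedDicts above (the values are made up; in practice this data is produced by the extraction pipeline):

```python
import numpy as np

face: AlignmentFileDict = {
    "x": 142, "y": 58, "w": 256, "h": 256,   # detected face bounding box in frame pixels
    "landmarks_xy": np.zeros((68, 2)),        # 68 (x, y) landmark points
    "mask": {},                               # mask name -> MaskAlignmentsFileDict
    "identity": {},                           # identity backend name -> embedding floats
    "thumb": None,                            # optional thumbnail array
}
```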
@ -135,7 +130,7 @@ class Alignments():
|
|||
return self._io.file
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, AlignmentDict]:
|
||||
def data(self) -> dict[str, AlignmentDict]:
|
||||
""" dict: The loaded alignments :attr:`file` in dictionary form. """
|
||||
return self._data
|
||||
|
||||
|
@ -146,7 +141,7 @@ class Alignments():
|
|||
return self._io.have_alignments_file
|
||||
|
||||
@property
|
||||
def hashes_to_frame(self) -> Dict[str, Dict[str, int]]:
|
||||
def hashes_to_frame(self) -> dict[str, dict[str, int]]:
|
||||
""" dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
|
||||
that the hash corresponds to.
|
||||
|
||||
|
@ -158,7 +153,7 @@ class Alignments():
|
|||
return self._legacy.hashes_to_frame
|
||||
|
||||
@property
|
||||
def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]:
|
||||
def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
|
||||
""" dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
|
||||
corresponds to. The structure of the dictionary is:
|
||||
|
||||
|
@ -170,10 +165,10 @@ class Alignments():
|
|||
return self._legacy.hashes_to_alignment
|
||||
|
||||
@property
|
||||
def mask_summary(self) -> Dict[str, int]:
|
||||
def mask_summary(self) -> dict[str, int]:
|
||||
""" dict: The mask type names stored in the alignments :attr:`data` as key with the number
|
||||
of faces which possess the mask type as value. """
|
||||
masks: Dict[str, int] = {}
|
||||
masks: dict[str, int] = {}
|
||||
for val in self._data.values():
|
||||
for face in val["faces"]:
|
||||
if face.get("mask", None) is None:
|
||||
|
@ -183,21 +178,20 @@ class Alignments():
|
|||
return masks
|
||||
|
||||
@property
|
||||
def video_meta_data(self) -> Dict[str, Optional[Union[List[int], List[float]]]]:
|
||||
def video_meta_data(self) -> dict[str, list[int] | list[float] | None]:
|
||||
""" dict: The frame meta data stored in the alignments file. If data does not exist in the
|
||||
alignments file then ``None`` is returned for each Key """
|
||||
retval: Dict[str, Optional[Union[List[int],
|
||||
List[float]]]] = dict(pts_time=None, keyframes=None)
|
||||
pts_time: List[float] = []
|
||||
keyframes: List[int] = []
|
||||
retval: dict[str, list[int] | list[float] | None] = {"pts_time": None, "keyframes": None}
|
||||
pts_time: list[float] = []
|
||||
keyframes: list[int] = []
|
||||
for idx, key in enumerate(sorted(self.data)):
|
||||
if not self.data[key].get("video_meta", {}):
|
||||
return retval
|
||||
meta = self.data[key]["video_meta"]
|
||||
pts_time.append(cast(float, meta["pts_time"]))
|
||||
pts_time.append(T.cast(float, meta["pts_time"]))
|
||||
if meta["keyframe"]:
|
||||
keyframes.append(idx)
|
||||
retval = dict(pts_time=pts_time, keyframes=keyframes)
|
||||
retval = {"pts_time": pts_time, "keyframes": keyframes}
|
||||
return retval
|
||||
|
||||
@property
|
||||
|
@ -211,7 +205,7 @@ class Alignments():
|
|||
""" float: The alignments file version number. """
|
||||
return self._io.version
|
||||
|
||||
def _load(self) -> Dict[str, AlignmentDict]:
|
||||
def _load(self) -> dict[str, AlignmentDict]:
|
||||
""" Load the alignments data from the serialized alignments :attr:`file`.
|
||||
|
||||
Populates :attr:`_version` with the alignment file's loaded version as well as returning
|
||||
|
@ -238,7 +232,7 @@ class Alignments():
|
|||
"""
|
||||
return self._io.backup()
|
||||
|
||||
def save_video_meta_data(self, pts_time: List[float], keyframes: List[int]) -> None:
|
||||
def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None:
|
||||
""" Save video meta data to the alignments file.
|
||||
|
||||
If the alignments file does not have an entry for every frame (e.g. if Extract Every N
|
||||
|
@ -262,10 +256,10 @@ class Alignments():
|
|||
logger.info("Saving video meta information to Alignments file")
|
||||
|
||||
for idx, pts in enumerate(pts_time):
|
||||
meta: Dict[str, Union[float, int]] = dict(pts_time=pts, keyframe=idx in keyframes)
|
||||
meta: dict[str, float | int] = {"pts_time": pts, "keyframe": idx in keyframes}
|
||||
key = f"{basename}_{idx + 1:06d}.png"
|
||||
if key not in self.data:
|
||||
self.data[key] = dict(video_meta=meta, faces=[])
|
||||
self.data[key] = {"video_meta": meta, "faces": []}
|
||||
else:
|
||||
self.data[key]["video_meta"] = meta
|
||||
|
||||
|
@ -285,8 +279,8 @@ class Alignments():
|
|||
self._io.save()
|
||||
|
||||
@classmethod
|
||||
def _pad_leading_frames(cls, pts_time: List[float], keyframes: List[int]) -> Tuple[List[float],
|
||||
List[int]]:
|
||||
def _pad_leading_frames(cls, pts_time: list[float], keyframes: list[int]) -> tuple[list[float],
|
||||
list[int]]:
|
||||
""" Calculate the number of frames to pad the video by when the first frame is not
|
||||
a key frame.
|
||||
|
||||
|
@ -310,7 +304,7 @@ class Alignments():
|
|||
"""
|
||||
start_pts = pts_time[0]
|
||||
logger.debug("Video not cut on keyframe. Start pts: %s", start_pts)
|
||||
gaps: List[float] = []
|
||||
gaps: list[float] = []
|
||||
prev_time = None
|
||||
for item in pts_time:
|
||||
if prev_time is not None:
|
||||
|
@ -360,7 +354,7 @@ class Alignments():
|
|||
``True`` if the given frame_name exists within the alignments :attr:`data` and has at
|
||||
least 1 face associated with it, otherwise ``False``
|
||||
"""
|
||||
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
|
||||
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
|
||||
retval = bool(frame_data.get("faces", []))
|
||||
logger.trace("'%s': %s", frame_name, retval) # type:ignore
|
||||
return retval
|
||||
|
@ -384,7 +378,7 @@ class Alignments():
|
|||
if not frame_name:
|
||||
retval = False
|
||||
else:
|
||||
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
|
||||
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
|
||||
retval = bool(len(frame_data.get("faces", [])) > 1)
logger.trace("'%s': %s", frame_name, retval) # type:ignore
return retval

@@ -414,7 +408,7 @@ class Alignments():
return retval

# << DATA >> #
def get_faces_in_frame(self, frame_name: str) -> List[AlignmentFileDict]:
def get_faces_in_frame(self, frame_name: str) -> list[AlignmentFileDict]:
""" Obtain the faces from :attr:`data` associated with a given frame_name.

Parameters

@@ -429,8 +423,8 @@ class Alignments():
The list of face dictionaries that appear within the requested frame_name
"""
logger.trace("Getting faces for frame_name: '%s'", frame_name) # type:ignore
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
return frame_data.get("faces", cast(List[AlignmentFileDict], []))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
return frame_data.get("faces", T.cast(list[AlignmentFileDict], []))

def _count_faces_in_frame(self, frame_name: str) -> int:
""" Return number of faces that appear within :attr:`data` for the given frame_name.

@@ -446,7 +440,7 @@ class Alignments():
int
The number of faces that appear in the given frame_name
"""
frame_data = self._data.get(frame_name, cast(AlignmentDict, {}))
frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {}))
retval = len(frame_data.get("faces", []))
logger.trace(retval) # type:ignore
return retval

@@ -497,7 +491,7 @@ class Alignments():
"""
logger.debug("Adding face to frame_name: '%s'", frame_name)
if frame_name not in self._data:
self._data[frame_name] = dict(faces=[], video_meta={})
self._data[frame_name] = {"faces": [], "video_meta": {}}
self._data[frame_name]["faces"].append(face)
retval = self._count_faces_in_frame(frame_name) - 1
logger.debug("Returning new face index: %s", retval)

@@ -520,7 +514,7 @@ class Alignments():
logger.debug("Updating face %s for frame_name '%s'", face_index, frame_name)
self._data[frame_name]["faces"][face_index] = face

def filter_faces(self, filter_dict: Dict[str, List[int]], filter_out: bool = False) -> None:
def filter_faces(self, filter_dict: dict[str, list[int]], filter_out: bool = False) -> None:
""" Remove faces from :attr:`data` based on a given filter list.

Parameters

@@ -549,7 +543,7 @@ class Alignments():
del frame_data["faces"][face_idx]

# << GENERATORS >> #
def yield_faces(self) -> Generator[Tuple[str, List[AlignmentFileDict], int, str], None, None]:
def yield_faces(self) -> Generator[tuple[str, list[AlignmentFileDict], int, str], None, None]:
""" Generator to obtain all faces with meta information from :attr:`data`. The results
are yielded by frame.

@@ -715,7 +709,7 @@ class _IO():
logger.info("Updating alignments file to version %s", self._version)
self.save()

def load(self) -> Dict[str, AlignmentDict]:
def load(self) -> dict[str, AlignmentDict]:
""" Load the alignments data from the serialized alignments :attr:`file`.

Populates :attr:`_version` with the alignment file's loaded version as well as returning

@@ -732,7 +726,7 @@ class _IO():
logger.info("Reading alignments from: '%s'", self._file)
data = self._serializer.load(self._file)
meta = data.get("__meta__", dict(version=1.0))
meta = data.get("__meta__", {"version": 1.0})
self._version = meta["version"]
data = data.get("__data__", data)
logger.debug("Loaded alignments")

@@ -743,8 +737,8 @@ class _IO():
the location :attr:`file`. """
logger.debug("Saving alignments")
logger.info("Writing alignments to: '%s'", self._file)
data = dict(__meta__=dict(version=self._version),
__data__=self._alignments.data)
data = {"__meta__": {"version": self._version},
"__data__": self._alignments.data}
self._serializer.save(self._file, data)
logger.debug("Saved alignments")
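As an aside on the hunks above: ``save`` wraps the alignment data in a two-key structure that ``load`` unpicks again, falling back to version 1.0 when a legacy file has no ``__meta__`` block. A minimal sketch of that payload, assuming a placeholder frame name and an example version number:

    # Illustrative only: the shape of the serialized alignments payload implied by the
    # save()/load() hunks above (frame name and version value are example placeholders)
    payload = {
        "__meta__": {"version": 2.3},
        "__data__": {
            "frame_0001.png": {"faces": [], "video_meta": {}},
        },
    }
    # load() falls back to {"version": 1.0} when "__meta__" is absent (legacy files)
    version = payload.get("__meta__", {"version": 1.0})["version"]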
@@ -928,7 +922,7 @@ class _FileStructure(_Updater):
for key, val in self._alignments.data.items():
if not isinstance(val, list):
continue
self._alignments.data[key] = dict(faces=val)
self._alignments.data[key] = {"faces": val}
updated += 1
return updated

@@ -1078,11 +1072,11 @@ class _Legacy():
"""
def __init__(self, alignments: Alignments) -> None:
self._alignments = alignments
self._hashes_to_frame: Dict[str, Dict[str, int]] = {}
self._hashes_to_alignment: Dict[str, AlignmentFileDict] = {}
self._hashes_to_frame: dict[str, dict[str, int]] = {}
self._hashes_to_alignment: dict[str, AlignmentFileDict] = {}

@property
def hashes_to_frame(self) -> Dict[str, Dict[str, int]]:
def hashes_to_frame(self) -> dict[str, dict[str, int]]:
""" dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame
that the hash corresponds to. The structure of the dictionary is:

@@ -1105,7 +1099,7 @@ class _Legacy():
return self._hashes_to_frame

@property
def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]:
def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]:
""" dict: The SHA1 hash of the face mapped to the alignment for the face that the hash
corresponds to. The structure of the dictionary is:
@@ -1,12 +1,11 @@
#!/usr/bin python3
""" Face and landmarks detection for faceswap.py """

from __future__ import annotations
import logging
import sys
import os
import typing as T

from hashlib import sha1
from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from zlib import compress, decompress

import cv2

@@ -18,14 +17,10 @@ from .alignments import (Alignments, AlignmentFileDict, MaskAlignmentsFileDict,
PNGHeaderAlignmentsDict, PNGHeaderDict, PNGHeaderSourceDict)
from . import AlignedFace, get_adjusted_center, get_centered_size

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Callable
from .aligned_face import CenteringType

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
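The import changes above capture the pattern applied throughout this commit: the pre-3.10 ``typing`` aliases and the ``typing_extensions`` fallback are dropped in favour of builtin generics and PEP 604 unions, with the remaining helpers reached through a single ``typing as T`` namespace. A minimal sketch of that style, using hypothetical names rather than code from this repository:

    # A minimal sketch of the typing style this commit moves to (illustrative names only)
    from __future__ import annotations  # added by the commit; on Python 3.10+ the new
                                        # annotation syntax is also valid without it
    import typing as T

    # Old style (pre Python 3.10):
    #   from typing import Dict, List, Optional
    #   def faces_for(frame: str) -> Optional[List[Dict[str, int]]]: ...

    # New style used throughout the diff: builtin generics plus PEP 604 unions
    def faces_for(frame: str) -> list[dict[str, int]] | None:
        """ Placeholder: only the annotations matter here. """
        return None

    # Helpers that still live in ``typing`` are reached via the single ``T`` namespace
    Zone = T.Literal["eye", "face", "mouth"]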
@@ -85,14 +80,14 @@ class DetectedFace():
dict of {**name** (`str`): :class:`Mask`}.
"""
def __init__(self,
image: Optional[np.ndarray] = None,
left: Optional[int] = None,
width: Optional[int] = None,
top: Optional[int] = None,
height: Optional[int] = None,
landmarks_xy: Optional[np.ndarray] = None,
mask: Optional[Dict[str, "Mask"]] = None,
filename: Optional[str] = None) -> None:
image: np.ndarray | None = None,
left: int | None = None,
width: int | None = None,
top: int | None = None,
height: int | None = None,
landmarks_xy: np.ndarray | None = None,
mask: dict[str, "Mask"] | None = None,
filename: str | None = None) -> None:
logger.trace("Initializing %s: (image: %s, left: %s, width: %s, top: %s, " # type: ignore
"height: %s, landmarks_xy: %s, mask: %s, filename: %s)",
self.__class__.__name__,

@@ -104,12 +99,12 @@ class DetectedFace():
self.top = top
self.height = height
self._landmarks_xy = landmarks_xy
self._identity: Dict[str, np.ndarray] = {}
self.thumbnail: Optional[np.ndarray] = None
self._identity: dict[str, np.ndarray] = {}
self.thumbnail: np.ndarray | None = None
self.mask = {} if mask is None else mask
self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None
self._training_masks: tuple[bytes, tuple[int, int, int]] | None = None

self._aligned: Optional[AlignedFace] = None
self._aligned: AlignedFace | None = None
logger.trace("Initialized %s", self.__class__.__name__) # type: ignore

@property

@@ -137,7 +132,7 @@ class DetectedFace():
return self.top + self.height

@property
def identity(self) -> Dict[str, np.ndarray]:
def identity(self) -> dict[str, np.ndarray]:
""" dict: Identity mechanism as key, identity embedding as value. """
return self._identity

@@ -147,7 +142,7 @@ class DetectedFace():
affine_matrix: np.ndarray,
interpolator: int,
storage_size: int = 128,
storage_centering: "CenteringType" = "face") -> None:
storage_centering: CenteringType = "face") -> None:
""" Add a :class:`Mask` to this detected face

The mask should be the original output from :mod:`plugins.extract.mask`

@@ -209,7 +204,7 @@ class DetectedFace():
self._identity[name] = embedding

def get_landmark_mask(self,
area: Literal["eye", "face", "mouth"],
area: T.Literal["eye", "face", "mouth"],
blur_kernel: int,
dilation: int) -> np.ndarray:
""" Add a :class:`LandmarksMask` to this detected face

@@ -235,7 +230,7 @@ class DetectedFace():
"""
# TODO Face mask generation from landmarks
logger.trace("area: %s, dilation: %s", area, dilation) # type: ignore
areas = dict(mouth=[slice(48, 60)], eye=[slice(36, 42), slice(42, 48)])
areas = {"mouth": [slice(48, 60)], "eye": [slice(36, 42), slice(42, 48)]}
points = [self.aligned.landmarks[zone]
for zone in areas[area]]
@@ -250,7 +245,7 @@ class DetectedFace():
return lmmask.mask

def store_training_masks(self,
masks: List[Optional[np.ndarray]],
masks: list[np.ndarray | None],
delete_masks: bool = False) -> None:
""" Concatenate and compress the given training masks and store for retrieval.

@@ -273,7 +268,7 @@ class DetectedFace():
combined = np.concatenate(valid, axis=-1)
self._training_masks = (compress(combined), combined.shape)

def get_training_masks(self) -> Optional[np.ndarray]:
def get_training_masks(self) -> np.ndarray | None:
""" Obtain the decompressed combined training masks.

Returns

@@ -312,7 +307,7 @@ class DetectedFace():
return alignment

def from_alignment(self, alignment: AlignmentFileDict,
image: Optional[np.ndarray] = None, with_thumb: bool = False) -> None:
image: np.ndarray | None = None, with_thumb: bool = False) -> None:
""" Set the attributes of this class from an alignments file and optionally load the face
into the ``image`` attribute.

@@ -342,7 +337,7 @@ class DetectedFace():
landmarks = alignment["landmarks_xy"]
if not isinstance(landmarks, np.ndarray):
landmarks = np.array(landmarks, dtype="float32")
self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32")
self._identity = {T.cast(T.Literal["vggface2"], k): np.array(v, dtype="float32")
for k, v in alignment.get("identity", {}).items()}
self._landmarks_xy = landmarks.copy()

@@ -403,7 +398,7 @@ class DetectedFace():
self._identity = {}
for key, val in alignment.get("identity", {}).items():
assert key in ["vggface2"]
self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32")
self._identity[T.cast(T.Literal["vggface2"], key)] = np.array(val, dtype="float32")
logger.trace("Created from png exif header: (left: %s, width: %s, top: %s " # type: ignore
" height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width,
self.top, self.height, self.landmarks_xy, self.mask,

@@ -417,10 +412,10 @@ class DetectedFace():

# <<< Aligned Face methods and properties >>> #
def load_aligned(self,
image: Optional[np.ndarray],
image: np.ndarray | None,
size: int = 256,
dtype: Optional[str] = None,
centering: "CenteringType" = "head",
dtype: str | None = None,
centering: CenteringType = "head",
coverage_ratio: float = 1.0,
force: bool = False,
is_aligned: bool = False,
@@ -507,22 +502,22 @@ class Mask():
"""
def __init__(self,
storage_size: int = 128,
storage_centering: "CenteringType" = "face") -> None:
storage_centering: CenteringType = "face") -> None:
logger.trace("Initializing: %s (storage_size: %s, storage_centering: %s)", # type: ignore
self.__class__.__name__, storage_size, storage_centering)
self.stored_size = storage_size
self.stored_centering = storage_centering

self._mask: Optional[bytes] = None
self._affine_matrix: Optional[np.ndarray] = None
self._interpolator: Optional[int] = None
self._mask: bytes | None = None
self._affine_matrix: np.ndarray | None = None
self._interpolator: int | None = None

self._blur_type: Optional[Literal["gaussian", "normalized"]] = None
self._blur_type: T.Literal["gaussian", "normalized"] | None = None
self._blur_passes: int = 0
self._blur_kernel: Union[float, int] = 0
self._blur_kernel: float | int = 0
self._threshold = 0.0
self._sub_crop_size = 0
self._sub_crop_slices: Dict[Literal["in", "out"], List[slice]] = {}
self._sub_crop_slices: dict[T.Literal["in", "out"], list[slice]] = {}

self.set_blur_and_threshold()
logger.trace("Initialized: %s", self.__class__.__name__) # type: ignore

@@ -648,7 +643,7 @@ class Mask():

def set_blur_and_threshold(self,
blur_kernel: int = 0,
blur_type: Optional[Literal["gaussian", "normalized"]] = "gaussian",
blur_type: T.Literal["gaussian", "normalized"] | None = "gaussian",
blur_passes: int = 1,
threshold: int = 0) -> None:
""" Set the internal blur kernel and threshold amount for returned masks

@@ -679,7 +674,7 @@ class Mask():
def set_sub_crop(self,
source_offset: np.ndarray,
target_offset: np.ndarray,
centering: "CenteringType",
centering: CenteringType,
coverage_ratio: float = 1.0) -> None:
""" Set the internal crop area of the mask to be returned.

@@ -831,9 +826,9 @@ class LandmarksMask(Mask):
The amount of dilation to apply to the mask. `0` for none. Default: `0`
"""
def __init__(self,
points: List[np.ndarray],
points: list[np.ndarray],
storage_size: int = 128,
storage_centering: "CenteringType" = "face",
storage_centering: CenteringType = "face",
dilation: int = 0) -> None:
super().__init__(storage_size=storage_size, storage_centering=storage_centering)
self._points = points

@@ -907,9 +902,9 @@ class BlurMask(): # pylint:disable=too-few-public-methods
(128, 128, 1)
"""
def __init__(self,
blur_type: Literal["gaussian", "normalized"],
blur_type: T.Literal["gaussian", "normalized"],
mask: np.ndarray,
kernel: Union[int, float],
kernel: int | float,
is_ratio: bool = False,
passes: int = 1) -> None:
logger.trace("Initializing %s: (blur_type: '%s', mask_shape: %s, " # type: ignore
@@ -943,33 +938,30 @@ class BlurMask(): # pylint:disable=too-few-public-methods
def _multipass_factor(self) -> float:
""" For multiple passes the kernel must be scaled down. This value is
different for box filter and gaussian """
factor = dict(gaussian=0.8, normalized=0.5)
factor = {"gaussian": 0.8, "normalized": 0.5}
return factor[self._blur_type]

@property
def _sigma(self) -> Literal[0]:
def _sigma(self) -> T.Literal[0]:
""" int: The Sigma for Gaussian Blur. Returns 0 to force calculation from kernel size. """
return 0

@property
def _func_mapping(self) -> Dict[Literal["gaussian", "normalized"], Callable]:
def _func_mapping(self) -> dict[T.Literal["gaussian", "normalized"], Callable]:
""" dict: :attr:`_blur_type` mapped to cv2 Function name. """
return dict(gaussian=cv2.GaussianBlur, # pylint: disable = no-member
normalized=cv2.blur) # pylint: disable = no-member
return {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur}

@property
def _kwarg_requirements(self) -> Dict[Literal["gaussian", "normalized"], List[str]]:
def _kwarg_requirements(self) -> dict[T.Literal["gaussian", "normalized"], list[str]]:
""" dict: :attr:`_blur_type` mapped to cv2 Function required keyword arguments. """
return dict(gaussian=["ksize", "sigmaX"],
normalized=["ksize"])
return {"gaussian": ['ksize', 'sigmaX'], "normalized": ['ksize']}

@property
def _kwarg_mapping(self) -> Dict[str, Union[int, Tuple[int, int]]]:
def _kwarg_mapping(self) -> dict[str, int | tuple[int, int]]:
""" dict: cv2 function keyword arguments mapped to their parameters. """
return dict(ksize=self._kernel_size,
sigmaX=self._sigma)
return {"ksize": self._kernel_size, "sigmaX": self._sigma}

def _get_kernel_size(self, kernel: Union[int, float], is_ratio: bool) -> int:
def _get_kernel_size(self, kernel: int | float, is_ratio: bool) -> int:
""" Set the kernel size to absolute value.

If :attr:`is_ratio` is ``True`` then the kernel size is calculated from the given ratio and

@@ -999,7 +991,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods
return kernel_size

@staticmethod
def _get_kernel_tuple(kernel_size: int) -> Tuple[int, int]:
def _get_kernel_tuple(kernel_size: int) -> tuple[int, int]:
""" Make sure kernel_size is odd and return it as a tuple.

Parameters

@@ -1017,7 +1009,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods
logger.trace(retval) # type: ignore
return retval

def _get_kwargs(self) -> Dict[str, Union[int, Tuple[int, int]]]:
def _get_kwargs(self) -> dict[str, int | tuple[int, int]]:
""" dict: the valid keyword arguments for the requested :attr:`_blur_type` """
retval = {kword: self._kwarg_mapping[kword]
for kword in self._kwarg_requirements[self._blur_type]}

@@ -1025,11 +1017,11 @@ class BlurMask(): # pylint:disable=too-few-public-methods
return retval

_HASHES_SEEN: Dict[str, Dict[str, int]] = {}
_HASHES_SEEN: dict[str, dict[str, int]] = {}

def update_legacy_png_header(filename: str, alignments: Alignments
) -> Optional[PNGHeaderDict]:
) -> PNGHeaderDict | None:
""" Update a legacy extracted face from pre v2.1 alignments by placing the alignment data for
the face in the png exif header for the given filename with the given alignment data.
@@ -7,7 +7,7 @@ as well as adding a mechanism for indicating to the GUI how specific options sho

import argparse
import os
from typing import Any, List, Optional, Tuple, Union
import typing as T

# << FILE HANDLING >>

@@ -69,7 +69,7 @@ class FileFullPaths(_FullPaths):
>>> filetypes="video))"
"""
# pylint: disable=too-few-public-methods
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.filetypes = filetypes

@@ -111,7 +111,7 @@ class FilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods
>>> filetypes="image",
>>> nargs="+"))
"""
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None:
if kwargs.get("nargs", None) is None:
opt = kwargs["option_strings"]
raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}")

@@ -250,8 +250,8 @@ class ContextFullPaths(FileFullPaths):
# pylint: disable=too-few-public-methods, too-many-arguments
def __init__(self,
*args,
filetypes: Optional[str] = None,
action_option: Optional[str] = None,
filetypes: str | None = None,
action_option: str | None = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:

@@ -263,7 +263,7 @@ class ContextFullPaths(FileFullPaths):
super().__init__(*args, filetypes=filetypes, **kwargs)
self.action_option = action_option

def _get_kwargs(self) -> List[Tuple[str, Any]]:
def _get_kwargs(self) -> list[tuple[str, T.Any]]:
names = ["option_strings",
"dest",
"nargs",

@@ -382,8 +382,8 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods
"""
def __init__(self,
*args,
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
rounding: Optional[int] = None,
min_max: tuple[int, int] | tuple[float, float] | None = None,
rounding: int | None = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:

@@ -401,7 +401,7 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods
self.min_max = min_max
self.rounding = rounding

def _get_kwargs(self) -> List[Tuple[str, Any]]:
def _get_kwargs(self) -> list[tuple[str, T.Any]]:
names = ["option_strings",
"dest",
"nargs",
@ -8,8 +8,7 @@ import logging
|
|||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from typing import Any, Dict, List, NoReturn, Optional
|
||||
import typing as T
|
||||
|
||||
from lib.utils import get_backend
|
||||
from lib.gpu_stats import GPUStats
|
||||
|
@ -30,7 +29,7 @@ _ = _LANG.gettext
|
|||
|
||||
class FullHelpArgumentParser(argparse.ArgumentParser):
|
||||
""" Extends :class:`argparse.ArgumentParser` to output full help on bad arguments. """
|
||||
def error(self, message: str) -> NoReturn:
|
||||
def error(self, message: str) -> T.NoReturn:
|
||||
self.print_help(sys.stderr)
|
||||
self.exit(2, f"{self.prog}: error: {message}\n")
|
||||
|
||||
|
@ -51,11 +50,11 @@ class SmartFormatter(argparse.HelpFormatter):
|
|||
prog: str,
|
||||
indent_increment: int = 2,
|
||||
max_help_position: int = 24,
|
||||
width: Optional[int] = None) -> None:
|
||||
width: int | None = None) -> None:
|
||||
super().__init__(prog, indent_increment, max_help_position, width)
|
||||
self._whitespace_matcher_limited = re.compile(r'[ \r\f\v]+', re.ASCII)
|
||||
|
||||
def _split_lines(self, text: str, width: int) -> List[str]:
|
||||
def _split_lines(self, text: str, width: int) -> list[str]:
|
||||
""" Split the given text by the given display width.
|
||||
|
||||
If the text is not prefixed with "R|" then the standard
|
||||
|
@ -138,7 +137,7 @@ class FaceSwapArgs():
|
|||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_argument_list() -> List[Dict[str, Any]]:
|
||||
def get_argument_list() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list for the current command.
|
||||
|
||||
The argument list should be a list of dictionaries pertaining to each option for a command.
|
||||
|
@ -152,11 +151,11 @@ class FaceSwapArgs():
|
|||
list
|
||||
The list of command line options for the given command
|
||||
"""
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
return argument_list
|
||||
|
||||
@staticmethod
|
||||
def get_optional_arguments() -> List[Dict[str, Any]]:
|
||||
def get_optional_arguments() -> list[dict[str, T.Any]]:
|
||||
""" Returns the optional argument list for the current command.
|
||||
|
||||
The optional arguments list is not always required, but is used when there are shared
|
||||
|
@ -167,11 +166,11 @@ class FaceSwapArgs():
|
|||
list
|
||||
The list of optional command line options for the given command
|
||||
"""
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
return argument_list
|
||||
|
||||
@staticmethod
|
||||
def _get_global_arguments() -> List[Dict[str, Any]]:
|
||||
def _get_global_arguments() -> list[dict[str, T.Any]]:
|
||||
""" Returns the global Arguments list that are required for ALL commands in Faceswap.
|
||||
|
||||
This method should NOT be overridden.
|
||||
|
@ -181,7 +180,7 @@ class FaceSwapArgs():
|
|||
list
|
||||
The list of global command line options for all Faceswap commands.
|
||||
"""
|
||||
global_args: List[Dict[str, Any]] = []
|
||||
global_args: list[dict[str, T.Any]] = []
|
||||
if _GPUS:
|
||||
global_args.append(dict(
|
||||
opts=("-X", "--exclude-gpus"),
|
||||
|
@ -302,7 +301,7 @@ class ExtractConvertArgs(FaceSwapArgs):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_argument_list() -> List[Dict[str, Any]]:
|
||||
def get_argument_list() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list for shared Extract and Convert arguments.
|
||||
|
||||
Returns
|
||||
|
@ -310,7 +309,7 @@ class ExtractConvertArgs(FaceSwapArgs):
|
|||
list
|
||||
The list of command line options for the given Extract and Convert
|
||||
"""
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
argument_list.append(dict(
|
||||
opts=("-i", "--input-dir"),
|
||||
action=DirOrFileFullPaths,
|
||||
|
@ -362,7 +361,7 @@ class ExtractArgs(ExtractConvertArgs):
|
|||
"Extraction plugins can be configured in the 'Settings' Menu")
|
||||
|
||||
@staticmethod
|
||||
def get_optional_arguments() -> List[Dict[str, Any]]:
|
||||
def get_optional_arguments() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list unique to the Extract command.
|
||||
|
||||
Returns
|
||||
|
@ -377,7 +376,7 @@ class ExtractArgs(ExtractConvertArgs):
|
|||
default_detector = "s3fd"
|
||||
default_aligner = "fan"
|
||||
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
argument_list.append(dict(
|
||||
opts=("-b", "--batch-mode"),
|
||||
action="store_true",
|
||||
|
@ -658,7 +657,7 @@ class ConvertArgs(ExtractConvertArgs):
|
|||
"Conversion plugins can be configured in the 'Settings' Menu")
|
||||
|
||||
@staticmethod
|
||||
def get_optional_arguments() -> List[Dict[str, Any]]:
|
||||
def get_optional_arguments() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list unique to the Convert command.
|
||||
|
||||
Returns
|
||||
|
@ -667,7 +666,7 @@ class ConvertArgs(ExtractConvertArgs):
|
|||
The list of optional command line options for the Convert command
|
||||
"""
|
||||
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
argument_list.append(dict(
|
||||
opts=("-ref", "--reference-video"),
|
||||
action=FileFullPaths,
|
||||
|
@ -915,7 +914,7 @@ class TrainArgs(FaceSwapArgs):
|
|||
"Model plugins can be configured in the 'Settings' Menu")
|
||||
|
||||
@staticmethod
|
||||
def get_argument_list() -> List[Dict[str, Any]]:
|
||||
def get_argument_list() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list for Train arguments.
|
||||
|
||||
Returns
|
||||
|
@ -923,7 +922,7 @@ class TrainArgs(FaceSwapArgs):
|
|||
list
|
||||
The list of command line options for training
|
||||
"""
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
argument_list.append(dict(
|
||||
opts=("-A", "--input-A"),
|
||||
action=DirFullPaths,
|
||||
|
@ -1180,7 +1179,7 @@ class GuiArgs(FaceSwapArgs):
|
|||
""" Creates the command line arguments for the GUI. """
|
||||
|
||||
@staticmethod
|
||||
def get_argument_list() -> List[Dict[str, Any]]:
|
||||
def get_argument_list() -> list[dict[str, T.Any]]:
|
||||
""" Returns the argument list for GUI arguments.
|
||||
|
||||
Returns
|
||||
|
@ -1188,7 +1187,7 @@ class GuiArgs(FaceSwapArgs):
|
|||
list
|
||||
The list of command line options for the GUI
|
||||
"""
|
||||
argument_list: List[Dict[str, Any]] = []
|
||||
argument_list: list[dict[str, T.Any]] = []
|
||||
argument_list.append(dict(
|
||||
opts=("-d", "--debug"),
|
||||
action="store_true",
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Launches the correct script with the given Command Line Arguments """
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
from lib.gpu_stats import set_exclude_devices, GPUStats
|
||||
from lib.logger import crash_log, log_setup
|
||||
from lib.utils import (FaceswapError, get_backend, get_tf_version,
|
||||
safe_shutdown, set_backend, set_system_verbosity)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
import argparse
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
@ -99,8 +101,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
|
|||
FaceswapError
|
||||
If Tensorflow is not found, or is not between versions 2.4 and 2.9
|
||||
"""
|
||||
directml_ver = rocm_ver = (2, 10)
|
||||
min_ver = (2, 7)
|
||||
min_ver = (2, 10)
|
||||
max_ver = (2, 10)
|
||||
try:
|
||||
import tensorflow as tf # noqa pylint:disable=import-outside-toplevel,unused-import
|
||||
|
@ -120,7 +121,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
|
|||
self._handle_import_error(msg)
|
||||
|
||||
tf_ver = get_tf_version()
|
||||
backend = get_backend()
|
||||
if tf_ver < min_ver:
|
||||
msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
|
||||
f"{tf_ver} installed. Please upgrade Tensorflow.")
|
||||
|
@ -129,14 +129,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
|
|||
msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
|
||||
f"{tf_ver} installed. Please downgrade Tensorflow.")
|
||||
self._handle_import_error(msg)
|
||||
if backend == "directml" and tf_ver != directml_ver:
|
||||
msg = (f"The supported Tensorflow version for DirectML cards is {directml_ver} but "
|
||||
f"you have version {tf_ver} installed. Please install the correct version.")
|
||||
self._handle_import_error(msg)
|
||||
if backend == "rocm" and tf_ver != rocm_ver:
|
||||
msg = (f"The supported Tensorflow version for ROCm cards is {rocm_ver} but "
|
||||
f"you have version {tf_ver} installed. Please install the correct version.")
|
||||
self._handle_import_error(msg)
|
||||
logger.debug("Installed Tensorflow Version: %s", tf_ver)
|
||||
|
||||
@classmethod
|
||||
|
@ -209,7 +201,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
|
|||
"See https://support.apple.com/en-gb/HT201341")
|
||||
raise FaceswapError("No display detected. GUI mode has been disabled.")
|
||||
|
||||
def execute_script(self, arguments: "argparse.Namespace") -> None:
|
||||
def execute_script(self, arguments: argparse.Namespace) -> None:
|
||||
""" Performs final set up and launches the requested :attr:`_command` with the given
|
||||
command line arguments.
|
||||
|
||||
|
@ -250,7 +242,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods
|
|||
finally:
|
||||
safe_shutdown(got_error=not success)
|
||||
|
||||
def _configure_backend(self, arguments: "argparse.Namespace") -> None:
|
||||
def _configure_backend(self, arguments: argparse.Namespace) -> None:
|
||||
""" Configure the backend.
|
||||
|
||||
Exclude any GPUs for use by Faceswap when requested.
|
||||
|
|
|
@ -13,7 +13,6 @@ from collections import OrderedDict
|
|||
from configparser import ConfigParser
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from lib.utils import full_path_split
|
||||
|
||||
|
@ -21,16 +20,11 @@ from lib.utils import full_path_split
|
|||
_LANG = gettext.translation("lib.config", localedir="locales", fallback=True)
|
||||
_ = _LANG.gettext
|
||||
|
||||
# Can't type OrderedDict fully on Python 3.8 or lower
|
||||
if sys.version_info < (3, 9):
|
||||
OrderedDictSectionType = OrderedDict
|
||||
OrderedDictItemType = OrderedDict
|
||||
else:
|
||||
OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
|
||||
OrderedDictItemType = OrderedDict[str, "ConfigItem"]
|
||||
OrderedDictSectionType = OrderedDict[str, "ConfigSection"]
|
||||
OrderedDictItemType = OrderedDict[str, "ConfigItem"]
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
ConfigValueType = Union[bool, int, float, List[str], str, None]
|
||||
ConfigValueType = bool | int | float | list[str] | str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -60,11 +54,11 @@ class ConfigItem:
|
|||
helptext: str
|
||||
datatype: type
|
||||
rounding: int
|
||||
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]]
|
||||
choices: Union[str, List[str]]
|
||||
min_max: tuple[int, int] | tuple[float, float] | None
|
||||
choices: str | list[str]
|
||||
gui_radio: bool
|
||||
fixed: bool
|
||||
group: Optional[str]
|
||||
group: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -84,7 +78,7 @@ class ConfigSection:
|
|||
|
||||
class FaceswapConfig():
|
||||
""" Config Items """
|
||||
def __init__(self, section: Optional[str], configfile: Optional[str] = None) -> None:
|
||||
def __init__(self, section: str | None, configfile: str | None = None) -> None:
|
||||
""" Init Configuration
|
||||
|
||||
Parameters
|
||||
|
@ -106,11 +100,11 @@ class FaceswapConfig():
|
|||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def changeable_items(self) -> Dict[str, ConfigValueType]:
|
||||
def changeable_items(self) -> dict[str, ConfigValueType]:
|
||||
""" Training only.
|
||||
Return a dict of config items with their set values for items
|
||||
that can be altered after the model has been created """
|
||||
retval: Dict[str, ConfigValueType] = {}
|
||||
retval: dict[str, ConfigValueType] = {}
|
||||
sections = [sect for sect in self.config.sections() if sect.startswith("global")]
|
||||
all_sections = sections if self.section is None else sections + [self.section]
|
||||
for sect in all_sections:
|
||||
|
@ -189,10 +183,10 @@ class FaceswapConfig():
|
|||
logger.debug("Added defaults: %s", section)
|
||||
|
||||
@property
|
||||
def config_dict(self) -> Dict[str, ConfigValueType]:
|
||||
def config_dict(self) -> dict[str, ConfigValueType]:
|
||||
""" dict: Collate global options and requested section into a dictionary with the correct
|
||||
data types """
|
||||
conf: Dict[str, ConfigValueType] = {}
|
||||
conf: dict[str, ConfigValueType] = {}
|
||||
sections = [sect for sect in self.config.sections() if sect.startswith("global")]
|
||||
if self.section is not None:
|
||||
sections.append(self.section)
|
||||
|
@ -240,7 +234,7 @@ class FaceswapConfig():
|
|||
logger.debug("Returning item: (type: %s, value: %s)", datatype, retval)
|
||||
return retval
|
||||
|
||||
def _parse_list(self, section: str, option: str) -> List[str]:
|
||||
def _parse_list(self, section: str, option: str) -> list[str]:
|
||||
""" Parse options that are stored as lists in the config file. These can be space or
|
||||
comma-separated items in the config file. They will be returned as a list of strings,
|
||||
regardless of what the final data type should be, so conversion from strings to other
|
||||
|
@ -268,7 +262,7 @@ class FaceswapConfig():
|
|||
raw_option, retval, section, option)
|
||||
return retval
|
||||
|
||||
def _get_config_file(self, configfile: Optional[str]) -> str:
|
||||
def _get_config_file(self, configfile: str | None) -> str:
|
||||
""" Return the config file from the calling folder or the provided file
|
||||
|
||||
Parameters
|
||||
|
@ -309,17 +303,17 @@ class FaceswapConfig():
|
|||
self.defaults[title] = ConfigSection(helptext=info, items=OrderedDict())
|
||||
|
||||
def add_item(self,
|
||||
section: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
section: str | None = None,
|
||||
title: str | None = None,
|
||||
datatype: type = str,
|
||||
default: ConfigValueType = None,
|
||||
info: Optional[str] = None,
|
||||
rounding: Optional[int] = None,
|
||||
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
|
||||
choices: Optional[Union[str, List[str]]] = None,
|
||||
info: str | None = None,
|
||||
rounding: int | None = None,
|
||||
min_max: tuple[int, int] | tuple[float, float] | None = None,
|
||||
choices: str | list[str] | None = None,
|
||||
gui_radio: bool = False,
|
||||
fixed: bool = True,
|
||||
group: Optional[str] = None) -> None:
|
||||
group: str | None = None) -> None:
|
||||
""" Add a default item to a config section
|
||||
|
||||
For int or float values, rounding and min_max must be set
|
||||
|
@ -382,10 +376,10 @@ class FaceswapConfig():
|
|||
@classmethod
|
||||
def _expand_helptext(cls,
|
||||
helptext: str,
|
||||
choices: Union[str, List[str]],
|
||||
choices: str | list[str],
|
||||
default: ConfigValueType,
|
||||
datatype: type,
|
||||
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]],
|
||||
min_max: tuple[int, int] | tuple[float, float] | None,
|
||||
fixed: bool) -> str:
|
||||
""" Add extra helptext info from parameters """
|
||||
helptext += "\n"
|
||||
|
@ -437,7 +431,7 @@ class FaceswapConfig():
|
|||
def insert_config_section(self,
|
||||
section: str,
|
||||
helptext: str,
|
||||
config: Optional[ConfigParser] = None) -> None:
|
||||
config: ConfigParser | None = None) -> None:
|
||||
""" Insert a section into the config
|
||||
|
||||
Parameters
|
||||
|
@ -464,7 +458,7 @@ class FaceswapConfig():
|
|||
item: str,
|
||||
default: ConfigValueType,
|
||||
option: ConfigItem,
|
||||
config: Optional[ConfigParser] = None) -> None:
|
||||
config: ConfigParser | None = None) -> None:
|
||||
""" Insert an item into a config section
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -1,23 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Converter for Faceswap """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, cast, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from plugins.plugin_loader import PluginLoader
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from collections.abc import Callable
|
||||
from lib.align.aligned_face import AlignedFace, CenteringType
|
||||
from lib.align.detected_face import DetectedFace
|
||||
from lib.config import FaceswapConfig
|
||||
|
@ -46,10 +41,10 @@ class Adjustments:
|
|||
sharpening: :class:`~plugins.scaling._base.Adjustment`, Optional
|
||||
The selected mask processing plugin. Default: `None`
|
||||
"""
|
||||
color: Optional["ColorAdjust"] = None
|
||||
mask: Optional["MaskAdjust"] = None
|
||||
seamless: Optional["SeamlessAdjust"] = None
|
||||
sharpening: Optional["ScalingAdjust"] = None
|
||||
color: ColorAdjust | None = None
|
||||
mask: MaskAdjust | None = None
|
||||
seamless: SeamlessAdjust | None = None
|
||||
sharpening: ScalingAdjust | None = None
|
||||
|
||||
|
||||
class Converter():
|
||||
|
@ -81,11 +76,11 @@ class Converter():
|
|||
def __init__(self,
|
||||
output_size: int,
|
||||
coverage_ratio: float,
|
||||
centering: "CenteringType",
|
||||
centering: CenteringType,
|
||||
draw_transparent: bool,
|
||||
pre_encode: Optional[Callable[[np.ndarray], List[bytes]]],
|
||||
arguments: "Namespace",
|
||||
configfile: Optional[str] = None) -> None:
|
||||
pre_encode: Callable[[np.ndarray], list[bytes]] | None,
|
||||
arguments: Namespace,
|
||||
configfile: str | None = None) -> None:
|
||||
logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, "
|
||||
"draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)",
|
||||
self.__class__.__name__, output_size, coverage_ratio, centering,
|
||||
|
@ -105,12 +100,12 @@ class Converter():
|
|||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def cli_arguments(self) -> "Namespace":
|
||||
def cli_arguments(self) -> Namespace:
|
||||
""":class:`argparse.Namespace`: The command line arguments passed to the convert
|
||||
process """
|
||||
return self._args
|
||||
|
||||
def reinitialize(self, config: "FaceswapConfig") -> None:
|
||||
def reinitialize(self, config: FaceswapConfig) -> None:
|
||||
""" Reinitialize this :class:`Converter`.
|
||||
|
||||
Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the
|
||||
|
@ -127,7 +122,7 @@ class Converter():
|
|||
logger.debug("Reinitialized converter")
|
||||
|
||||
def _load_plugins(self,
|
||||
config: Optional["FaceswapConfig"] = None,
|
||||
config: FaceswapConfig | None = None,
|
||||
disable_logging: bool = False) -> None:
|
||||
""" Load the requested adjustment plugins.
|
||||
|
||||
|
@ -169,7 +164,7 @@ class Converter():
|
|||
self._adjustments.sharpening = sharpening
|
||||
logger.debug("Loaded plugins: %s", self._adjustments)
|
||||
|
||||
def process(self, in_queue: "EventQueue", out_queue: "EventQueue"):
|
||||
def process(self, in_queue: EventQueue, out_queue: EventQueue):
|
||||
""" Main convert process.
|
||||
|
||||
Takes items from the in queue, runs the relevant adjustments, patches faces to final frame
|
||||
|
@ -188,7 +183,7 @@ class Converter():
|
|||
in_queue, out_queue)
|
||||
log_once = False
|
||||
while True:
|
||||
inbound: Union[Literal["EOF"], "ConvertItem", List["ConvertItem"]] = in_queue.get()
|
||||
inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get()
|
||||
if inbound == "EOF":
|
||||
logger.debug("EOF Received")
|
||||
logger.debug("Patch queue finished")
|
||||
|
@ -218,7 +213,7 @@ class Converter():
|
|||
out_queue.put((item.inbound.filename, image))
|
||||
logger.debug("Completed convert process")
|
||||
|
||||
def _patch_image(self, predicted: "ConvertItem") -> Union[np.ndarray, List[bytes]]:
|
||||
def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]:
|
||||
""" Patch a swapped face onto a frame.
|
||||
|
||||
Run selected adjustments and swap the faces in a frame.
|
||||
|
@ -246,15 +241,15 @@ class Converter():
|
|||
out=np.empty(patched_face.shape, dtype="uint8"),
|
||||
casting='unsafe')
|
||||
if self._writer_pre_encode is None:
|
||||
retval: Union[np.ndarray, List[bytes]] = patched_face
|
||||
retval: np.ndarray | list[bytes] = patched_face
|
||||
else:
|
||||
retval = self._writer_pre_encode(patched_face)
|
||||
logger.trace("Patched image: '%s'", predicted.inbound.filename) # type: ignore
|
||||
return retval
|
||||
|
||||
def _get_new_image(self,
|
||||
predicted: "ConvertItem",
|
||||
frame_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
predicted: ConvertItem,
|
||||
frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Get the new face from the predictor and apply pre-warp manipulations.
|
||||
|
||||
Applies any requested adjustments to the raw output of the Faceswap model
|
||||
|
@ -308,9 +303,9 @@ class Converter():
|
|||
|
||||
def _pre_warp_adjustments(self,
|
||||
new_face: np.ndarray,
|
||||
detected_face: "DetectedFace",
|
||||
reference_face: "AlignedFace",
|
||||
predicted_mask: Optional[np.ndarray]) -> np.ndarray:
|
||||
detected_face: DetectedFace,
|
||||
reference_face: AlignedFace,
|
||||
predicted_mask: np.ndarray | None) -> np.ndarray:
|
||||
""" Run any requested adjustments that can be performed on the raw output from the Faceswap
|
||||
model.
|
||||
|
||||
|
@ -337,7 +332,7 @@ class Converter():
|
|||
"""
|
||||
logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore
|
||||
new_face.shape, predicted_mask.shape if predicted_mask is not None else None)
|
||||
old_face = cast(np.ndarray, reference_face.face)[..., :3] / 255.0
|
||||
old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0
|
||||
new_face, raw_mask = self._get_image_mask(new_face,
|
||||
detected_face,
|
||||
predicted_mask,
|
||||
|
@ -351,9 +346,9 @@ class Converter():
|
|||
|
||||
def _get_image_mask(self,
|
||||
new_face: np.ndarray,
|
||||
detected_face: "DetectedFace",
|
||||
predicted_mask: Optional[np.ndarray],
|
||||
reference_face: "AlignedFace") -> Tuple[np.ndarray, np.ndarray]:
|
||||
detected_face: DetectedFace,
|
||||
predicted_mask: np.ndarray | None,
|
||||
reference_face: AlignedFace) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Return any selected image mask
|
||||
|
||||
Places the requested mask into the new face's Alpha channel.
|
||||
|
|
|
@ -5,11 +5,10 @@ from the :class:`_GPUStats` class contained here. """
|
|||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from lib.utils import get_backend
|
||||
|
||||
_EXCLUDE_DEVICES: List[int] = []
|
||||
_EXCLUDE_DEVICES: list[int] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -29,11 +28,11 @@ class GPUInfo():
|
|||
devices_active: list[int]
|
||||
List of integers representing the indices of the active GPU devices.
|
||||
"""
|
||||
vram: List[int]
|
||||
vram_free: List[int]
|
||||
vram: list[int]
|
||||
vram_free: list[int]
|
||||
driver: str
|
||||
devices: List[str]
|
||||
devices_active: List[int]
|
||||
devices: list[str]
|
||||
devices_active: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -57,7 +56,7 @@ class BiggestGPUInfo():
|
|||
total: float
|
||||
|
||||
|
||||
def set_exclude_devices(devices: List[int]) -> None:
|
||||
def set_exclude_devices(devices: list[int]) -> None:
|
||||
""" Add any explicitly selected GPU devices to the global list of devices to be excluded
|
||||
from use by Faceswap.
|
||||
|
||||
|
@ -89,19 +88,19 @@ class _GPUStats():
|
|||
def __init__(self, log: bool = True) -> None:
|
||||
# Logger is held internally, as we don't want to log when obtaining system stats on crash
|
||||
# or when querying the backend for command line options
|
||||
self._logger: Optional[logging.Logger] = logging.getLogger(__name__) if log else None
|
||||
self._logger: logging.Logger | None = logging.getLogger(__name__) if log else None
|
||||
self._log("debug", f"Initializing {self.__class__.__name__}")
|
||||
|
||||
self._is_initialized = False
|
||||
self._initialize()
|
||||
|
||||
self._device_count: int = self._get_device_count()
|
||||
self._active_devices: List[int] = self._get_active_devices()
|
||||
self._active_devices: list[int] = self._get_active_devices()
|
||||
self._handles: list = self._get_handles()
|
||||
self._driver: str = self._get_driver()
|
||||
self._device_names: List[str] = self._get_device_names()
|
||||
self._vram: List[int] = self._get_vram()
|
||||
self._vram_free: List[int] = self._get_free_vram()
|
||||
self._device_names: list[str] = self._get_device_names()
|
||||
self._vram: list[int] = self._get_vram()
|
||||
self._vram_free: list[int] = self._get_free_vram()
|
||||
|
||||
if get_backend() != "cpu" and not self._active_devices:
|
||||
self._log("warning", "No GPU detected")
|
||||
|
@ -115,7 +114,7 @@ class _GPUStats():
|
|||
return self._device_count
|
||||
|
||||
@property
|
||||
def cli_devices(self) -> List[str]:
|
||||
def cli_devices(self) -> list[str]:
|
||||
""" list[str]: Formatted index: name text string for each GPU """
|
||||
return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)]
|
||||
|
||||
|
@ -167,7 +166,7 @@ class _GPUStats():
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_active_devices(self) -> List[int]:
|
||||
def _get_active_devices(self) -> list[int]:
|
||||
""" Obtain the indices of active GPUs (those that have not been explicitly excluded in
|
||||
the command line arguments).
|
||||
|
||||
|
@ -204,7 +203,7 @@ class _GPUStats():
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_device_names(self) -> List[str]:
|
||||
def _get_device_names(self) -> list[str]:
|
||||
""" Override to obtain the names of all connected GPUs. The quality of this information
|
||||
depends on the backend and OS being used, but it should be sufficient for identifying
|
||||
cards.
|
||||
|
@ -217,7 +216,7 @@ class _GPUStats():
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_vram(self) -> List[int]:
|
||||
def _get_vram(self) -> list[int]:
|
||||
""" Override to obtain the total VRAM in Megabytes for each connected GPU.
|
||||
|
||||
Returns
|
||||
|
@ -228,7 +227,7 @@ class _GPUStats():
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_free_vram(self) -> List[int]:
|
||||
def _get_free_vram(self) -> list[int]:
|
||||
""" Override to obtain the amount of VRAM that is available, in Megabytes, for each
|
||||
connected GPU.
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """
|
||||
from typing import Any, List
|
||||
import typing as T
|
||||
|
||||
import os
|
||||
import psutil
|
||||
|
@ -35,7 +35,7 @@ class AppleSiliconStats(_GPUStats):
|
|||
"""
|
||||
def __init__(self, log: bool = True) -> None:
|
||||
# Following attribute set in :func:``_initialize``
|
||||
self._tf_devices: List[Any] = []
|
||||
self._tf_devices: list[T.Any] = []
|
||||
|
||||
super().__init__(log=log)
|
||||
|
||||
|
@ -142,7 +142,7 @@ class AppleSiliconStats(_GPUStats):
|
|||
self._log("debug", f"GPU Driver: {driver}")
|
||||
return driver
|
||||
|
||||
def _get_device_names(self) -> List[str]:
|
||||
def _get_device_names(self) -> list[str]:
|
||||
""" Obtain the list of names of available Apple Silicon SoC(s) as identified in
|
||||
:attr:`_handles`.
|
||||
|
||||
|
@ -155,7 +155,7 @@ class AppleSiliconStats(_GPUStats):
|
|||
self._log("debug", f"GPU Devices: {names}")
|
||||
return names
|
||||
|
||||
def _get_vram(self) -> List[int]:
|
||||
def _get_vram(self) -> list[int]:
|
||||
""" Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in
|
||||
:attr:`_handles`.
|
||||
|
||||
|
@ -175,7 +175,7 @@ class AppleSiliconStats(_GPUStats):
|
|||
self._log("debug", f"SoC RAM: {vram}")
|
||||
return vram
|
||||
|
||||
def _get_free_vram(self) -> List[int]:
|
||||
def _get_free_vram(self) -> list[int]:
|
||||
""" Obtain the amount of VRAM that is available, in Megabytes, for each available Apple
|
||||
Silicon SoC.
|
||||
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Dummy functions for running faceswap on CPU. """
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ._base import _GPUStats
|
||||
|
||||
|
||||
|
@ -65,7 +61,7 @@ class CPUStats(_GPUStats):
|
|||
self._log("debug", f"GPU Driver: {driver}")
|
||||
return driver
|
||||
|
||||
def _get_device_names(self) -> List[str]:
|
||||
def _get_device_names(self) -> list[str]:
|
||||
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
|
||||
|
||||
Returns
|
||||
|
@ -73,11 +69,11 @@ class CPUStats(_GPUStats):
|
|||
list
|
||||
An empty list for CPU backends
|
||||
"""
|
||||
names: List[str] = []
|
||||
names: list[str] = []
|
||||
self._log("debug", f"GPU Devices: {names}")
|
||||
return names
|
||||
|
||||
def _get_vram(self) -> List[int]:
|
||||
def _get_vram(self) -> list[int]:
|
||||
""" Obtain the RAM in Megabytes for the running system.
|
||||
|
||||
Returns
|
||||
|
@ -85,11 +81,11 @@ class CPUStats(_GPUStats):
|
|||
list
|
||||
An empty list for CPU backends
|
||||
"""
|
||||
vram: List[int] = []
|
||||
vram: list[int] = []
|
||||
self._log("debug", f"GPU VRAM: {vram}")
|
||||
return vram
|
||||
|
||||
def _get_free_vram(self) -> List[int]:
|
||||
def _get_free_vram(self) -> list[int]:
|
||||
""" Obtain the amount of RAM that is available, in Megabytes, for the running system.
|
||||
|
||||
Returns
|
||||
|
@ -97,6 +93,6 @@ class CPUStats(_GPUStats):
|
|||
list
|
||||
An empty list for CPU backends
|
||||
"""
|
||||
vram: List[int] = []
|
||||
vram: list[int] = []
|
||||
self._log("debug", f"GPU VRAM free: {vram}")
|
||||
return vram
|
||||
|
|
|
@ -1,19 +1,23 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Collects and returns Information on DirectX 12 hardware devices for DirectML. """
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
assert sys.platform == "win32"
|
||||
|
||||
import ctypes
|
||||
from ctypes import POINTER, Structure, windll
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Any, Callable, cast, List
|
||||
|
||||
from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error
|
||||
|
||||
from ._base import _GPUStats
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
# Monkey patch default ctypes.c_uint32 value to Enum ctypes property for easier tracking of types
|
||||
# We can't just subclass as the attribute will be assumed to be part of the Enumeration, so we
|
||||
# attach it directly and suck up the typing errors.
|
||||
|
@ -314,7 +318,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
self._adapters = self._get_adapters()
|
||||
self._devices = self._process_adapters()
|
||||
|
||||
self._valid_adaptors: List[Device] = []
|
||||
self._valid_adaptors: list[Device] = []
|
||||
self._log("debug", f"Initialized {self.__class__.__name__}")
|
||||
|
||||
def _get_factory(self) -> ctypes._Pointer:
|
||||
|
@ -334,12 +338,12 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
factory_func.restype = HRESULT
|
||||
handle = ctypes.c_void_p(0)
|
||||
factory_func(IDXGIFactory6._iid_, ctypes.byref(handle)) # pylint:disable=protected-access
|
||||
retval = ctypes.POINTER(IDXGIFactory6)(cast(IDXGIFactory6, handle.value))
|
||||
retval = ctypes.POINTER(IDXGIFactory6)(T.cast(IDXGIFactory6, handle.value))
|
||||
self._log("debug", f"factory: {retval}")
|
||||
return retval
|
||||
|
||||
@property
|
||||
def valid_adapters(self) -> List[Device]:
|
||||
def valid_adapters(self) -> list[Device]:
|
||||
""" list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """
|
||||
if self._valid_adaptors:
|
||||
return self._valid_adaptors
|
||||
|
@ -354,7 +358,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
self._log("debug", f"valid_adaptors: {self._valid_adaptors}")
|
||||
return self._valid_adaptors
|
||||
|
||||
def _get_adapters(self) -> List[ctypes._Pointer]:
|
||||
def _get_adapters(self) -> list[ctypes._Pointer]:
|
||||
""" Obtain DirectX 12 supporting hardware adapter objects and add a Device class for
|
||||
obtaining details
|
||||
|
||||
|
@ -376,7 +380,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
if success != 0:
|
||||
raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: "
|
||||
f"{hex(ctypes.c_ulong(success).value)}")
|
||||
adapter = POINTER(IDXGIAdapter3)(cast(IDXGIAdapter3, handle.value))
|
||||
adapter = POINTER(IDXGIAdapter3)(T.cast(IDXGIAdapter3, handle.value))
|
||||
self._log("debug", f"found adapter: {adapter}")
|
||||
retval.append(adapter)
|
||||
except COMError as err:
|
||||
|
@ -392,7 +396,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
self._log("debug", f"adapters: {retval}")
|
||||
return retval
|
||||
|
||||
def _query_adapter(self, func: Callable[[Any], Any], *args: Any) -> None:
|
||||
def _query_adapter(self, func: Callable[[T.Any], T.Any], *args: T.Any) -> None:
|
||||
""" Query an adapter function, logging if the HRESULT is not a success
|
||||
|
||||
Parameters
|
||||
|
@ -430,7 +434,7 @@ class Adapters(): # pylint:disable=too-few-public-methods
|
|||
LookupGUID.ID3D12Device)
|
||||
return success in (0, 1)
|
||||
|
||||
def _process_adapters(self) -> List[Device]:
|
||||
def _process_adapters(self) -> list[Device]:
|
||||
""" Process the adapters to add discovered information.
|
||||
|
||||
Returns
|
||||
|
@ -485,21 +489,21 @@ class DirectML(_GPUStats):
|
|||
Default: ``True``
|
||||
"""
|
||||
def __init__(self, log: bool = True) -> None:
|
||||
self._devices: List[Device] = []
|
||||
self._devices: list[Device] = []
|
||||
super().__init__(log=log)
|
||||
|
||||
@property
|
||||
def _all_vram(self) -> List[int]:
|
||||
def _all_vram(self) -> list[int]:
|
||||
""" list: The VRAM of each GPU device that the DX API has discovered. """
|
||||
return [int(device.description.DedicatedVideoMemory / (1024 * 1024))
|
||||
for device in self._devices]
|
||||
|
||||
@property
|
||||
def names(self) -> List[str]:
|
||||
def names(self) -> list[str]:
|
||||
""" list: The name of each GPU device that the DX API has discovered. """
|
||||
return [device.description.Description for device in self._devices]
|
||||
|
||||
def _get_active_devices(self) -> List[int]:
|
||||
def _get_active_devices(self) -> list[int]:
|
||||
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
|
||||
DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
|
||||
arguments).
|
||||
|
@ -517,7 +521,7 @@ class DirectML(_GPUStats):
|
|||
self._log("debug", f"Active GPU Devices: {devices}")
|
||||
return devices
|
||||
|
||||
def _get_devices(self) -> List[Device]:
|
||||
def _get_devices(self) -> list[Device]:
|
||||
""" Obtain all detected DX API devices.
|
||||
|
||||
Returns
|
||||
|
@ -582,7 +586,7 @@ class DirectML(_GPUStats):
|
|||
self._log("debug", f"GPU Drivers: {drivers}")
|
||||
return drivers
|
||||
|
||||
def _get_device_names(self) -> List[str]:
|
||||
def _get_device_names(self) -> list[str]:
|
||||
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.
|
||||
|
||||
Returns
|
||||
|
@ -594,7 +598,7 @@ class DirectML(_GPUStats):
|
|||
self._log("debug", f"GPU Devices: {names}")
|
||||
return names
|
||||
|
||||
def _get_vram(self) -> List[int]:
|
||||
def _get_vram(self) -> list[int]:
|
||||
""" Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in
|
||||
:attr:`_handles`.
|
||||
|
||||
|
@ -607,7 +611,7 @@ class DirectML(_GPUStats):
|
|||
self._log("debug", f"GPU VRAM: {vram}")
|
||||
return vram
|
||||
|
||||
def _get_free_vram(self) -> List[int]:
|
||||
def _get_free_vram(self) -> list[int]:
|
||||
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX
|
||||
12 supporting GPU.
|
||||
|
||||
|
|
|
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs. """
import os
-from typing import List

import pynvml

@@ -83,7 +82,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Device count: {retval}")
return retval

-def _get_active_devices(self) -> List[int]:
+def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments).
@@ -130,7 +129,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver

-def _get_device_names(self) -> List[str]:
+def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.

Returns
@@ -143,7 +142,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names

-def _get_vram(self) -> List[int]:
+def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.

@@ -157,7 +156,7 @@ class NvidiaStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram

-def _get_free_vram(self) -> List[int]:
+def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

@@ -1,7 +1,5 @@
#!/usr/bin/env python3
""" Collects and returns Information on available Nvidia GPUs connected to Apple Macs. """
-from typing import List
-
import pynvx

from lib.utils import FaceswapError
@@ -92,7 +90,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Driver: {driver}")
return driver

-def _get_device_names(self) -> List[str]:
+def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`.

Returns
@@ -105,7 +103,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU Devices: {names}")
return names

-def _get_vram(self) -> List[int]:
+def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in
:attr:`_handles`.

@@ -120,7 +118,7 @@ class NvidiaAppleStats(_GPUStats):
self._log("debug", f"GPU VRAM: {vram}")
return vram

-def _get_free_vram(self) -> List[int]:
+def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia
GPU.

@@ -10,7 +10,6 @@ It is a good starting point but may need to be refined over time
import os
import re
from subprocess import run
-from typing import List

from ._base import _GPUStats

@@ -221,7 +220,7 @@ class ROCm(_GPUStats):
"""
def __init__(self, log: bool = True) -> None:
self._vendor_id = "0x1002" # AMD VendorID
-self._sysfs_paths: List[str] = []
+self._sysfs_paths: list[str] = []
super().__init__(log=log)

def _from_sysfs_file(self, path: str) -> str:
@@ -249,7 +248,7 @@ class ROCm(_GPUStats):
val = ""
return val

-def _get_sysfs_paths(self) -> List[str]:
+def _get_sysfs_paths(self) -> list[str]:
""" Obtain a list of sysfs paths to AMD branded GPUs connected to the system

Returns
@@ -259,7 +258,7 @@ class ROCm(_GPUStats):
"""
base_dir = "/sys/class/drm/"

-retval: List[str] = []
+retval: list[str] = []
if not os.path.exists(base_dir):
self._log("warning", f"sysfs not found at '{base_dir}'")
return retval
@@ -347,7 +346,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU Drivers: {retval}")
return retval

-def _get_device_names(self) -> List[str]:
+def _get_device_names(self) -> list[str]:
""" Obtain the list of names of connected GPUs as identified in :attr:`_handles`.

Returns
@@ -383,7 +382,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Device names: {retval}")
return retval

-def _get_active_devices(self) -> List[int]:
+def _get_active_devices(self) -> list[int]:
""" Obtain the indices of active GPUs (those that have not been explicitly excluded by
HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
arguments).
@@ -401,7 +400,7 @@ class ROCm(_GPUStats):
self._log("debug", f"Active GPU Devices: {devices}")
return devices

-def _get_vram(self) -> List[int]:
+def _get_vram(self) -> list[int]:
""" Obtain the VRAM in Megabytes for each connected AMD GPU as identified in
:attr:`_handles`.

@@ -423,7 +422,7 @@ class ROCm(_GPUStats):
self._log("debug", f"GPU VRAM: {retval}")
return retval

-def _get_free_vram(self) -> List[int]:
+def _get_free_vram(self) -> list[int]:
""" Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD
GPU.

@ -1,13 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Handles the loading and collation of events from Tensorflow event log files. """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
import zlib
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast, Dict, Iterator, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -17,11 +16,8 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
|
|||
|
||||
from lib.serializer import get_serializer
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterator
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
@ -38,7 +34,7 @@ class EventData:
|
|||
The loss values collected for A and B sides for the event step
|
||||
"""
|
||||
timestamp: float = 0.0
|
||||
loss: List[float] = field(default_factory=list)
|
||||
loss: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class _LogFiles():
|
||||
|
@ -56,11 +52,11 @@ class _LogFiles():
|
|||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def session_ids(self) -> List[int]:
|
||||
def session_ids(self) -> list[int]:
|
||||
""" list[int]: Sorted list of `ints` of available session ids. """
|
||||
return list(sorted(self._filenames))
|
||||
|
||||
def _get_log_filenames(self) -> Dict[int, str]:
|
||||
def _get_log_filenames(self) -> dict[int, str]:
|
||||
""" Get the Tensorflow event filenames for all existing sessions.
|
||||
|
||||
Returns
|
||||
|
@ -69,7 +65,7 @@ class _LogFiles():
|
|||
The full path of each log file for each training session id that has been run
|
||||
"""
|
||||
logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
|
||||
retval: Dict[int, str] = {}
|
||||
retval: dict[int, str] = {}
|
||||
for dirpath, _, filenames in os.walk(self._logs_folder):
|
||||
if not any(filename.startswith("events.out.tfevents") for filename in filenames):
|
||||
continue
|
||||
|
@ -82,7 +78,7 @@ class _LogFiles():
|
|||
return retval
|
||||
|
||||
@classmethod
|
||||
def _get_session_id(cls, folder: str) -> Optional[int]:
|
||||
def _get_session_id(cls, folder: str) -> int | None:
|
||||
""" Obtain the session id for the given folder.
|
||||
|
||||
Parameters
|
||||
|
@ -103,7 +99,7 @@ class _LogFiles():
|
|||
return retval
|
||||
|
||||
@classmethod
|
||||
def _get_log_filename(cls, folder: str, filenames: List[str]) -> str:
|
||||
def _get_log_filename(cls, folder: str, filenames: list[str]) -> str:
|
||||
""" Obtain the session log file for the given folder. If multiple log files exist for the
|
||||
given folder, then the most recent log file is used, as earlier files are assumed to be
|
||||
obsolete.
|
||||
|
@ -161,10 +157,10 @@ class _CacheData():
|
|||
loss: :class:`np.ndarray`
|
||||
The loss values collected for A and B sides for the session
|
||||
"""
|
||||
def __init__(self, labels: List[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
|
||||
def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
|
||||
self.labels = labels
|
||||
self._loss = zlib.compress(cast(bytes, loss))
|
||||
self._timestamps = zlib.compress(cast(bytes, timestamps))
|
||||
self._loss = zlib.compress(T.cast(bytes, loss))
|
||||
self._timestamps = zlib.compress(T.cast(bytes, timestamps))
|
||||
self._timestamps_shape = timestamps.shape
|
||||
self._loss_shape = loss.shape
|
||||
|
||||
|
@ -192,8 +188,8 @@ class _CacheData():
|
|||
timestamps: :class:`numpy.ndarray`
|
||||
The latest timestamps to add to the cache
|
||||
"""
|
||||
new_buffer: List[bytes] = []
|
||||
new_shapes: List[Tuple[int, ...]] = []
|
||||
new_buffer: list[bytes] = []
|
||||
new_shapes: list[tuple[int, ...]] = []
|
||||
for data, buffer, dtype, shape in zip([timestamps, loss],
|
||||
[self._timestamps, self._loss],
|
||||
["float64", "float32"],
|
||||
|
@ -220,9 +216,9 @@ class _Cache():
|
|||
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing: %s", self.__class__.__name__)
|
||||
self._data: Dict[int, _CacheData] = {}
|
||||
self._carry_over: Dict[int, EventData] = {}
|
||||
self._loss_labels: List[str] = []
|
||||
self._data: dict[int, _CacheData] = {}
|
||||
self._carry_over: dict[int, EventData] = {}
|
||||
self._loss_labels: list[str] = []
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
def is_cached(self, session_id: int) -> bool:
|
||||
|
@ -242,8 +238,8 @@ class _Cache():
|
|||
|
||||
def cache_data(self,
|
||||
session_id: int,
|
||||
data: Dict[int, EventData],
|
||||
labels: List[str],
|
||||
data: dict[int, EventData],
|
||||
labels: list[str],
|
||||
is_live: bool = False) -> None:
|
||||
""" Add a full session's worth of event data to :attr:`_data`.
|
||||
|
||||
|
@ -278,8 +274,8 @@ class _Cache():
|
|||
self._add_latest_live(session_id, loss, timestamps)
|
||||
|
||||
def _to_numpy(self,
|
||||
data: Dict[int, EventData],
|
||||
is_live: bool) -> Tuple[np.ndarray, np.ndarray]:
|
||||
data: dict[int, EventData],
|
||||
is_live: bool) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Extract each individual step data into separate numpy arrays for loss and timestamps.
|
||||
|
||||
Timestamps are stored float64 as the extra accuracy is needed for correct timings. Arrays
|
||||
|
@ -333,7 +329,7 @@ class _Cache():
|
|||
|
||||
return n_times, n_loss
|
||||
|
||||
def _collect_carry_over(self, data: Dict[int, EventData]) -> None:
|
||||
def _collect_carry_over(self, data: dict[int, EventData]) -> None:
|
||||
""" For live data, collect carried over data from the previous update and merge into the
|
||||
current data dictionary.
|
||||
|
||||
|
@ -357,8 +353,8 @@ class _Cache():
|
|||
logger.debug("Merged carry over data: %s", update)
|
||||
|
||||
def _process_data(self,
|
||||
data: Dict[int, EventData],
|
||||
is_live: bool) -> Tuple[List[float], List[List[float]]]:
|
||||
data: dict[int, EventData],
|
||||
is_live: bool) -> tuple[list[float], list[list[float]]]:
|
||||
""" Process live update data.
|
||||
|
||||
Live data requires different processing as often we will only have partial data for the
|
||||
|
@ -383,8 +379,8 @@ class _Cache():
|
|||
timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss)
|
||||
for idx in sorted(data)])
|
||||
|
||||
l_loss: List[List[float]] = list(loss)
|
||||
l_timestamps: List[float] = list(timestamps)
|
||||
l_loss: list[list[float]] = list(loss)
|
||||
l_timestamps: list[float] = list(timestamps)
|
||||
|
||||
if len(l_loss[-1]) != len(self._loss_labels):
|
||||
logger.debug("Truncated loss found. loss count: %s", len(l_loss))
|
||||
|
@ -418,8 +414,8 @@ class _Cache():
|
|||
|
||||
self._data[session_id].add_live_data(timestamps, loss)
|
||||
|
||||
def get_data(self, session_id: int, metric: Literal["loss", "timestamps"]
|
||||
) -> Optional[Dict[int, Dict[str, Union[np.ndarray, List[str]]]]]:
|
||||
def get_data(self, session_id: int, metric: T.Literal["loss", "timestamps"]
|
||||
) -> dict[int, dict[str, np.ndarray | list[str]]] | None:
|
||||
""" Retrieve the decompressed cached data from the cache for the given session id.
|
||||
|
||||
Parameters
|
||||
|
@ -445,10 +441,10 @@ class _Cache():
|
|||
return None
|
||||
raw = {session_id: data}
|
||||
|
||||
retval: Dict[int, Dict[str, Union[np.ndarray, List[str]]]] = {}
|
||||
retval: dict[int, dict[str, np.ndarray | list[str]]] = {}
|
||||
for idx, data in raw.items():
|
||||
array = data.loss if metric == "loss" else data.timestamps
|
||||
val: Dict[str, Union[np.ndarray, List[str]]] = {str(metric): array}
|
||||
val: dict[str, np.ndarray | list[str]] = {str(metric): array}
|
||||
if metric == "loss":
|
||||
val["labels"] = data.labels
|
||||
retval[idx] = val
|
||||
|
@ -488,7 +484,7 @@ class TensorBoardLogs():
|
|||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def session_ids(self) -> List[int]:
|
||||
def session_ids(self) -> list[int]:
|
||||
""" list[int]: Sorted list of integers of available session ids. """
|
||||
return self._log_files.session_ids
|
||||
|
||||
|
@ -539,7 +535,7 @@ class TensorBoardLogs():
|
|||
parser = _EventParser(iterator, self._cache, live_data)
|
||||
parser.cache_events(session_id)
|
||||
|
||||
def _check_cache(self, session_id: Optional[int] = None) -> None:
|
||||
def _check_cache(self, session_id: int | None = None) -> None:
|
||||
""" Check if the given session_id has been cached and if not, cache it.
|
||||
|
||||
Parameters
|
||||
|
@ -557,7 +553,7 @@ class TensorBoardLogs():
|
|||
if not self._cache.is_cached(idx):
|
||||
self._cache_data(idx)
|
||||
|
||||
def get_loss(self, session_id: Optional[int] = None) -> Dict[int, Dict[str, np.ndarray]]:
|
||||
def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]:
|
||||
""" Read the loss from the TensorBoard event logs
|
||||
|
||||
Parameters
|
||||
|
@ -573,7 +569,7 @@ class TensorBoardLogs():
|
|||
and list of loss values for each step
|
||||
"""
|
||||
logger.debug("Getting loss: (session_id: %s)", session_id)
|
||||
retval: Dict[int, Dict[str, np.ndarray]] = {}
|
||||
retval: dict[int, dict[str, np.ndarray]] = {}
|
||||
for idx in [session_id] if session_id else self.session_ids:
|
||||
self._check_cache(idx)
|
||||
full_data = self._cache.get_data(idx, "loss")
|
||||
|
@ -588,7 +584,7 @@ class TensorBoardLogs():
|
|||
for key, val in retval.items()})
|
||||
return retval
|
||||
|
||||
def get_timestamps(self, session_id: Optional[int] = None) -> Dict[int, np.ndarray]:
|
||||
def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]:
|
||||
""" Read the timestamps from the TensorBoard logs.
|
||||
|
||||
As loss timestamps are slightly different for each loss, we collect the timestamp from the
|
||||
|
@ -608,7 +604,7 @@ class TensorBoardLogs():
|
|||
|
||||
logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
|
||||
session_id, self._is_training)
|
||||
retval: Dict[int, np.ndarray] = {}
|
||||
retval: dict[int, np.ndarray] = {}
|
||||
for idx in [session_id] if session_id else self.session_ids:
|
||||
self._check_cache(idx)
|
||||
data = self._cache.get_data(idx, "timestamps")
|
||||
|
@ -640,7 +636,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
|
|||
self._live_data = live_data
|
||||
self._cache = cache
|
||||
self._iterator = self._get_latest_live(iterator) if live_data else iterator
|
||||
self._loss_labels: List[str] = []
|
||||
self._loss_labels: list[str] = []
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@classmethod
|
||||
|
@ -683,7 +679,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
|
|||
The session id that the data is being cached for
|
||||
"""
|
||||
assert self._iterator is not None
|
||||
data: Dict[int, EventData] = {}
|
||||
data: dict[int, EventData] = {}
|
||||
try:
|
||||
for record in self._iterator:
|
||||
event = event_pb2.Event.FromString(record) # pylint:disable=no-member
|
||||
|
@ -743,7 +739,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Collated loss labels: %s", self._loss_labels)
|
||||
|
||||
@classmethod
|
||||
def _get_outputs(cls, model_config: Dict[str, Any]) -> np.ndarray:
|
||||
def _get_outputs(cls, model_config: dict[str, T.Any]) -> np.ndarray:
|
||||
""" Obtain the output names, instance index and output index for the given model.
|
||||
|
||||
If there is only a single output, the shape of the array is expanded to remain consistent
|
||||
|
|
|
@ -5,15 +5,15 @@ Holds the globally loaded training session. This will either be a user selected
|
|||
the analysis tab) or the currently training session.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import time
|
||||
import typing as T
|
||||
import warnings
|
||||
|
||||
from math import ceil
|
||||
from threading import Event
|
||||
from typing import Any, cast, Dict, List, Optional, overload, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -31,12 +31,12 @@ class GlobalSession():
|
|||
"""
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
self._state: Dict[str, Any] = {}
|
||||
self._state: dict[str, T.Any] = {}
|
||||
self._model_dir = ""
|
||||
self._model_name = ""
|
||||
|
||||
self._tb_logs: Optional[TensorBoardLogs] = None
|
||||
self._summary: Optional[SessionsSummary] = None
|
||||
self._tb_logs: TensorBoardLogs | None = None
|
||||
self._summary: SessionsSummary | None = None
|
||||
|
||||
self._is_training = False
|
||||
self._is_querying = Event()
|
||||
|
@ -60,7 +60,7 @@ class GlobalSession():
|
|||
return os.path.join(self._model_dir, self._model_name)
|
||||
|
||||
@property
|
||||
def batch_sizes(self) -> Dict[int, int]:
|
||||
def batch_sizes(self) -> dict[int, int]:
|
||||
""" dict: The batch sizes for each session_id for the model. """
|
||||
if not self._state:
|
||||
return {}
|
||||
|
@ -68,7 +68,7 @@ class GlobalSession():
|
|||
for sess_id, sess in self._state.get("sessions", {}).items()}
|
||||
|
||||
@property
|
||||
def full_summary(self) -> List[dict]:
|
||||
def full_summary(self) -> list[dict]:
|
||||
""" list: List of dictionaries containing summary statistics for each session id. """
|
||||
assert self._summary is not None
|
||||
return self._summary.get_summary_stats()
|
||||
|
@ -83,7 +83,7 @@ class GlobalSession():
|
|||
return self._state["sessions"][max_id]["no_logs"]
|
||||
|
||||
@property
|
||||
def session_ids(self) -> List[int]:
|
||||
def session_ids(self) -> list[int]:
|
||||
""" list: The sorted list of all existing session ids in the state file """
|
||||
if self._tb_logs is None:
|
||||
return []
|
||||
|
@ -164,7 +164,7 @@ class GlobalSession():
|
|||
|
||||
self._is_training = False
|
||||
|
||||
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]:
|
||||
def get_loss(self, session_id: int | None) -> dict[str, np.ndarray]:
|
||||
""" Obtain the loss values for the given session_id.
|
||||
|
||||
Parameters
|
||||
|
@ -186,11 +186,11 @@ class GlobalSession():
|
|||
assert self._tb_logs is not None
|
||||
loss_dict = self._tb_logs.get_loss(session_id=session_id)
|
||||
if session_id is None:
|
||||
all_loss: Dict[str, List[float]] = {}
|
||||
all_loss: dict[str, list[float]] = {}
|
||||
for key in sorted(loss_dict):
|
||||
for loss_key, loss in loss_dict[key].items():
|
||||
all_loss.setdefault(loss_key, []).extend(loss)
|
||||
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
|
||||
retval: dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
|
||||
for key, val in all_loss.items()}
|
||||
else:
|
||||
retval = loss_dict.get(session_id, {})
|
||||
|
@ -199,11 +199,11 @@ class GlobalSession():
|
|||
self._is_querying.clear()
|
||||
return retval
|
||||
|
||||
@overload
|
||||
def get_timestamps(self, session_id: None) -> Dict[int, np.ndarray]:
|
||||
@T.overload
|
||||
def get_timestamps(self, session_id: None) -> dict[int, np.ndarray]:
|
||||
...
|
||||
|
||||
@overload
|
||||
@T.overload
|
||||
def get_timestamps(self, session_id: int) -> np.ndarray:
|
||||
...
|
||||
|
||||
|
@ -247,7 +247,7 @@ class GlobalSession():
|
|||
continue
|
||||
break
|
||||
|
||||
def get_loss_keys(self, session_id: Optional[int]) -> List[str]:
|
||||
def get_loss_keys(self, session_id: int | None) -> list[str]:
|
||||
""" Obtain the loss keys for the given session_id.
|
||||
|
||||
Parameters
|
||||
|
@ -268,7 +268,7 @@ class GlobalSession():
|
|||
in self._tb_logs.get_loss(session_id=session_id).items()}
|
||||
|
||||
if session_id is None:
|
||||
retval: List[str] = list(set(loss_key
|
||||
retval: list[str] = list(set(loss_key
|
||||
for session in loss_keys.values()
|
||||
for loss_key in session))
|
||||
else:
|
||||
|
@ -293,11 +293,11 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
self._session = session
|
||||
self._state = session._state
|
||||
|
||||
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {}
|
||||
self._per_session_stats: List[Dict[str, Any]] = []
|
||||
self._time_stats: dict[int, dict[str, float | int]] = {}
|
||||
self._per_session_stats: list[dict[str, T.Any]] = []
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def get_summary_stats(self) -> List[dict]:
|
||||
def get_summary_stats(self) -> list[dict]:
|
||||
""" Compile the individual session statistics and calculate the total.
|
||||
|
||||
Format the stats for display
|
||||
|
@ -336,14 +336,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0,
|
||||
"end_time": np.max(timestamps) if np.any(timestamps) else 0,
|
||||
"iterations": timestamps.shape[0] if np.any(timestamps) else 0}
|
||||
for sess_id, timestamps in cast(Dict[int, np.ndarray],
|
||||
self._session.get_timestamps(None)).items()}
|
||||
for sess_id, timestamps in T.cast(dict[int, np.ndarray],
|
||||
self._session.get_timestamps(None)).items()}
|
||||
|
||||
elif _SESSION.is_training:
|
||||
logger.debug("Updating summary time stamps for training session")
|
||||
|
||||
session_id = _SESSION.session_ids[-1]
|
||||
latest = cast(np.ndarray, self._session.get_timestamps(session_id))
|
||||
latest = T.cast(np.ndarray, self._session.get_timestamps(session_id))
|
||||
|
||||
self._time_stats[session_id] = {
|
||||
"start_time": np.min(latest) if np.any(latest) else 0,
|
||||
|
@ -392,7 +392,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
/ stats["elapsed"] if stats["elapsed"] > 0 else 0)
|
||||
logger.debug("per_session_stats: %s", self._per_session_stats)
|
||||
|
||||
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]:
|
||||
def _collate_stats(self, session_id: int) -> dict[str, int | float]:
|
||||
""" Collate the session summary statistics for the given session ID.
|
||||
|
||||
Parameters
|
||||
|
@ -422,7 +422,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
logger.debug(retval)
|
||||
return retval
|
||||
|
||||
def _total_stats(self) -> Dict[str, Union[str, int, float]]:
|
||||
def _total_stats(self) -> dict[str, str | int | float]:
|
||||
""" Compile the Totals stats.
|
||||
Totals are fully calculated each time as they will change on the basis of the training
|
||||
session.
|
||||
|
@ -459,7 +459,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
logger.debug(totals)
|
||||
return totals
|
||||
|
||||
def _format_stats(self, compiled_stats: List[dict]) -> List[dict]:
|
||||
def _format_stats(self, compiled_stats: list[dict]) -> list[dict]:
|
||||
""" Format for the incoming list of statistics for display.
|
||||
|
||||
Parameters
|
||||
|
@ -489,7 +489,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
return retval
|
||||
|
||||
@classmethod
|
||||
def _convert_time(cls, timestamp: float) -> Tuple[str, str, str]:
|
||||
def _convert_time(cls, timestamp: float) -> tuple[str, str, str]:
|
||||
""" Convert time stamp to total hours, minutes and seconds.
|
||||
|
||||
Parameters
|
||||
|
@ -534,8 +534,8 @@ class Calculations():
|
|||
"""
|
||||
def __init__(self, session_id,
|
||||
display: str = "loss",
|
||||
loss_keys: Union[List[str], str] = "loss",
|
||||
selections: Union[List[str], str] = "raw",
|
||||
loss_keys: list[str] | str = "loss",
|
||||
selections: list[str] | str = "raw",
|
||||
avg_samples: int = 500,
|
||||
smooth_amount: float = 0.90,
|
||||
flatten_outliers: bool = False) -> None:
|
||||
|
@ -552,13 +552,13 @@ class Calculations():
|
|||
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
|
||||
self._selections = selections if isinstance(selections, list) else [selections]
|
||||
self._is_totals = session_id is None
|
||||
self._args: Dict[str, Union[int, float]] = {"avg_samples": avg_samples,
|
||||
"smooth_amount": smooth_amount,
|
||||
"flatten_outliers": flatten_outliers}
|
||||
self._args: dict[str, int | float] = {"avg_samples": avg_samples,
|
||||
"smooth_amount": smooth_amount,
|
||||
"flatten_outliers": flatten_outliers}
|
||||
self._iterations = 0
|
||||
self._limit = 0
|
||||
self._start_iteration = 0
|
||||
self._stats: Dict[str, np.ndarray] = {}
|
||||
self._stats: dict[str, np.ndarray] = {}
|
||||
self.refresh()
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
|
@ -573,11 +573,11 @@ class Calculations():
|
|||
return self._start_iteration
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, np.ndarray]:
|
||||
def stats(self) -> dict[str, np.ndarray]:
|
||||
""" dict: The final calculated statistics """
|
||||
return self._stats
|
||||
|
||||
def refresh(self) -> Optional["Calculations"]:
|
||||
def refresh(self) -> Calculations | None:
|
||||
""" Refresh the stats """
|
||||
logger.debug("Refreshing")
|
||||
if not _SESSION.is_loaded:
|
||||
|
@ -736,7 +736,8 @@ class Calculations():
|
|||
"""
|
||||
logger.debug("Calculating rate")
|
||||
batch_size = _SESSION.batch_sizes[self._session_id] * 2
|
||||
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id)))
|
||||
retval = batch_size / np.diff(T.cast(np.ndarray,
|
||||
_SESSION.get_timestamps(self._session_id)))
|
||||
logger.debug("Calculated rate: Item_count: %s", len(retval))
|
||||
return retval
|
||||
|
||||
|
@ -757,7 +758,7 @@ class Calculations():
|
|||
logger.debug("Calculating totals rate")
|
||||
batchsizes = _SESSION.batch_sizes
|
||||
total_timestamps = _SESSION.get_timestamps(None)
|
||||
rate: List[float] = []
|
||||
rate: list[float] = []
|
||||
for sess_id in sorted(total_timestamps.keys()):
|
||||
batchsize = batchsizes[sess_id]
|
||||
timestamps = total_timestamps[sess_id]
|
||||
|
@ -797,7 +798,7 @@ class Calculations():
|
|||
The moving average for the given data
|
||||
"""
|
||||
logger.debug("Calculating Average. Data points: %s", len(data))
|
||||
window = cast(int, self._args["avg_samples"])
|
||||
window = T.cast(int, self._args["avg_samples"])
|
||||
pad = ceil(window / 2)
|
||||
datapoints = data.shape[0]
|
||||
|
||||
|
@ -953,7 +954,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
|
|||
def _ewma_vectorized(self,
|
||||
data: np.ndarray,
|
||||
out: np.ndarray,
|
||||
offset: Optional[float] = None) -> None:
|
||||
offset: float | None = None) -> None:
|
||||
""" Calculates the exponential moving average over a vector. Will fail for large inputs.
|
||||
|
||||
The result is processed in place into the array passed to the `out` parameter
|
||||
|
|
|
@ -5,10 +5,10 @@ import logging
|
|||
import re
|
||||
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
from tkinter import colorchooser, ttk
|
||||
from itertools import zip_longest
|
||||
from functools import partial
|
||||
from typing import Any, Dict
|
||||
|
||||
from _tkinter import Tcl_Obj, TclError
|
||||
|
||||
|
@ -24,7 +24,9 @@ _ = _LANG.gettext
|
|||
# We store Tooltips, ContextMenus and Commands globally when they are created
|
||||
# Because we need to add them back to newly cloned widgets (they are not easily accessible from
|
||||
# original config or are prone to getting destroyed when the original widget is destroyed)
|
||||
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={})
|
||||
_RECREATE_OBJECTS: dict[str, dict[str, T.Any]] = {"tooltips": {},
|
||||
"commands": {},
|
||||
"contextmenus": {}}
|
||||
|
||||
|
||||
def _get_tooltip(widget, text=None, text_variable=None):
|
||||
|
@ -154,17 +156,17 @@ class ControlPanelOption():
|
|||
self.dtype = dtype
|
||||
self.sysbrowser = sysbrowser
|
||||
self._command = command
|
||||
self._options = dict(title=title,
|
||||
subgroup=subgroup,
|
||||
group=group,
|
||||
default=default,
|
||||
initial_value=initial_value,
|
||||
choices=choices,
|
||||
is_radio=is_radio,
|
||||
is_multi_option=is_multi_option,
|
||||
rounding=rounding,
|
||||
min_max=min_max,
|
||||
helptext=helptext)
|
||||
self._options = {"title": title,
|
||||
"subgroup": subgroup,
|
||||
"group": group,
|
||||
"default": default,
|
||||
"initial_value": initial_value,
|
||||
"choices": choices,
|
||||
"is_radio": is_radio,
|
||||
"is_multi_option": is_multi_option,
|
||||
"rounding": rounding,
|
||||
"min_max": min_max,
|
||||
"helptext": helptext}
|
||||
self.control = self.get_control()
|
||||
self.tk_var = self.get_tk_var(initial_value, track_modified)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
@ -421,7 +423,7 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
|
|||
self.group_frames = {}
|
||||
self._sub_group_frames = {}
|
||||
|
||||
canvas_kwargs = dict(bd=0, highlightthickness=0, bg=self._theme["panel_background"])
|
||||
canvas_kwargs = {"bd": 0, "highlightthickness": 0, "bg": self._theme["panel_background"]}
|
||||
|
||||
self._canvas = tk.Canvas(self, **canvas_kwargs)
|
||||
self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
|
||||
|
@ -525,8 +527,8 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors
|
|||
|
||||
group_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5, anchor=tk.NW)
|
||||
|
||||
self.group_frames[group] = dict(frame=retval,
|
||||
chkbtns=self.checkbuttons_frame(retval))
|
||||
self.group_frames[group] = {"frame": retval,
|
||||
"chkbtns": self.checkbuttons_frame(retval)}
|
||||
group_frame = self.group_frames[group]
|
||||
return group_frame
|
||||
|
||||
|
@ -720,12 +722,12 @@ class AutoFillContainer():
|
|||
"""
|
||||
retval = {}
|
||||
if widget.__class__.__name__ == "MultiOption":
|
||||
retval = dict(value=widget._value, # pylint:disable=protected-access
|
||||
variable=widget._master_variable) # pylint:disable=protected-access
|
||||
retval = {"value": widget._value, # pylint:disable=protected-access
|
||||
"variable": widget._master_variable} # pylint:disable=protected-access
|
||||
elif widget.__class__.__name__ == "ToggledFrame":
|
||||
# Toggled Frames need to have their variable tracked
|
||||
retval = dict(text=widget._text, # pylint:disable=protected-access
|
||||
toggle_var=widget._toggle_var) # pylint:disable=protected-access
|
||||
retval = {"text": widget._text, # pylint:disable=protected-access
|
||||
"toggle_var": widget._toggle_var} # pylint:disable=protected-access
|
||||
return retval
|
||||
|
||||
def get_all_children_config(self, widget, child_list):
|
||||
|
@ -988,7 +990,7 @@ class ControlBuilder():
|
|||
if self.option.control != ttk.Checkbutton:
|
||||
ctl.pack(padx=5, pady=5, fill=tk.X, expand=True)
|
||||
if self.option.helptext is not None and not self.helpset:
|
||||
tooltip_kwargs = dict(text=self.option.helptext)
|
||||
tooltip_kwargs = {"text": self.option.helptext}
|
||||
if self.option.sysbrowser is not None:
|
||||
tooltip_kwargs["text_variable"] = self.option.tk_var
|
||||
_get_tooltip(ctl, **tooltip_kwargs)
|
||||
|
@ -1071,7 +1073,7 @@ class ControlBuilder():
|
|||
"rounding: %s, min_max: %s)", self.option.name, self.option.dtype,
|
||||
self.option.rounding, self.option.min_max)
|
||||
validate = self.slider_check_int if self.option.dtype == int else self.slider_check_float
|
||||
vcmd = (self.frame.register(validate))
|
||||
vcmd = self.frame.register(validate)
|
||||
tbox = tk.Entry(self.frame,
|
||||
width=8,
|
||||
textvariable=self.option.tk_var,
|
||||
|
@ -1246,15 +1248,15 @@ class FileBrowser():
|
|||
@property
|
||||
def helptext(self):
|
||||
""" Dict containing tooltip text for buttons """
|
||||
retval = dict(folder=_("Select a folder..."),
|
||||
load=_("Select a file..."),
|
||||
load2=_("Select a file..."),
|
||||
picture=_("Select a folder of images..."),
|
||||
video=_("Select a video..."),
|
||||
model=_("Select a model folder..."),
|
||||
multi_load=_("Select one or more files..."),
|
||||
context=_("Select a file or folder..."),
|
||||
save_as=_("Select a save location..."))
|
||||
retval = {"folder": _("Select a folder..."),
|
||||
"load": _("Select a file..."),
|
||||
"load2": _("Select a file..."),
|
||||
"picture": _("Select a folder of images..."),
|
||||
"video": _("Select a video..."),
|
||||
"model": _("Select a model folder..."),
|
||||
"multi_load": _("Select one or more files..."),
|
||||
"context": _("Select a file or folder..."),
|
||||
"save_as": _("Select a save location...")}
|
||||
return retval
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -4,11 +4,10 @@ import datetime
|
|||
import gettext
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
|
||||
from tkinter import ttk
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from lib.training.preview_tk import PreviewTk
|
||||
|
||||
|
@ -19,11 +18,6 @@ from .analysis import Calculations, Session
|
|||
from .control_helper import set_slider_rounding
|
||||
from .utils import FileHandler, get_config, get_images, preview_trigger
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import get_args, Literal
|
||||
else:
|
||||
from typing import get_args, Literal
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# LOCALES
|
||||
|
@ -92,7 +86,7 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
logger.debug("Initializing %s (args: %s, kwargs: %s)",
|
||||
self.__class__.__name__, args, kwargs)
|
||||
self._preview = get_images().preview_train
|
||||
self._display: Optional[PreviewTk] = None
|
||||
self._display: PreviewTk | None = None
|
||||
super().__init__(*args, **kwargs)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
|
@ -177,9 +171,9 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
tab_name: str,
|
||||
helptext: str,
|
||||
wait_time: int,
|
||||
command: Optional[str] = None) -> None:
|
||||
self._trace_vars: Dict[Literal["smoothgraph", "display_iterations"],
|
||||
Tuple[tk.BooleanVar, str]] = {}
|
||||
command: str | None = None) -> None:
|
||||
self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
|
||||
tuple[tk.BooleanVar, str]] = {}
|
||||
super().__init__(parent, tab_name, helptext, wait_time, command)
|
||||
|
||||
def set_vars(self) -> None:
|
||||
|
@ -446,7 +440,7 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
|
||||
def _add_trace_variables(self) -> None:
|
||||
""" Add tracing for when the option sliders are updated, for updating the graph. """
|
||||
for name, action in zip(get_args(Literal["smoothgraph", "display_iterations"]),
|
||||
for name, action in zip(T.get_args(T.Literal["smoothgraph", "display_iterations"]),
|
||||
(self._smooth_amount_callback, self._iteration_limit_callback)):
|
||||
var = self.vars[name]
|
||||
if name not in self._trace_vars:
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
#!/usr/bin python3
|
||||
""" Graph functions for Display Frame area of the Faceswap GUI """
|
||||
from __future__ import annotations
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
|
||||
from tkinter import ttk
|
||||
from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING
|
||||
from math import ceil, floor
|
||||
|
||||
import numpy as np
|
||||
|
@ -20,7 +21,7 @@ from matplotlib.backend_bases import NavigationToolbar2
|
|||
from .custom_widgets import Tooltip
|
||||
from .utils import get_config, get_images, LongRunningTask
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
matplotlib.use("TkAgg")
|
||||
|
@ -49,8 +50,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self._ylabel = ylabel
|
||||
self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper",
|
||||
"summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"]
|
||||
self._lines: List["Line2D"] = []
|
||||
self._toolbar: Optional["NavigationToolbar"] = None
|
||||
self._lines: list[Line2D] = []
|
||||
self._toolbar: "NavigationToolbar" | None = None
|
||||
self._fig = Figure(figsize=(4, 4), dpi=75)
|
||||
|
||||
self._ax1 = self._fig.add_subplot(1, 1, 1)
|
||||
|
@ -129,7 +130,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self._ax1.set_ylim(0.00, 100.0)
|
||||
self._ax1.set_xlim(0, 1)
|
||||
|
||||
def _axes_limits_set(self, data: List[float]) -> None:
|
||||
def _axes_limits_set(self, data: list[float]) -> None:
|
||||
""" Set the axes limits.
|
||||
|
||||
Parameters
|
||||
|
@ -154,7 +155,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self._axes_limits_set_default()
|
||||
|
||||
@staticmethod
|
||||
def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]:
|
||||
def _axes_data_get_min_max(data: list[float]) -> tuple[float, float]:
|
||||
""" Obtain the minimum and maximum values for the y-axis from the given data points.
|
||||
|
||||
Parameters
|
||||
|
@ -188,7 +189,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
logger.debug("yscale: '%s'", scale)
|
||||
self._ax1.set_yscale(scale)
|
||||
|
||||
def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float]]]]:
|
||||
def _lines_sort(self, keys: list[str]) -> list[list[str | int | tuple[float]]]:
|
||||
""" Sort the data keys into consistent order and set line color map and line width.
|
||||
|
||||
Parameters
|
||||
|
@ -202,8 +203,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
A list of loss keys with their corresponding line formatting and color information
|
||||
"""
|
||||
logger.trace("Sorting lines") # type:ignore[attr-defined]
|
||||
raw_lines: List[List[str]] = []
|
||||
sorted_lines: List[List[str]] = []
|
||||
raw_lines: list[list[str]] = []
|
||||
sorted_lines: list[list[str]] = []
|
||||
for key in sorted(keys):
|
||||
title = key.replace("_", " ").title()
|
||||
if key.startswith("raw"):
|
||||
|
@ -217,7 +218,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int:
|
||||
def _lines_groupsize(raw_lines: list[list[str]], sorted_lines: list[list[str]]) -> int:
|
||||
""" Get the number of items in each group.
|
||||
|
||||
If raw data isn't selected, then check the length of remaining groups until something is
|
||||
|
@ -246,8 +247,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
return groupsize
|
||||
|
||||
def _lines_style(self,
|
||||
lines: List[List[str]],
|
||||
groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]:
|
||||
lines: list[list[str]],
|
||||
groupsize: int) -> list[list[str | int | tuple[float]]]:
|
||||
""" Obtain the color map and line width for each group.
|
||||
|
||||
Parameters
|
||||
|
@ -266,13 +267,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
groups = int(len(lines) / groupsize)
|
||||
colours = self._lines_create_colors(groupsize, groups)
|
||||
widths = list(range(1, groups + 1))
|
||||
retval = cast(List[List[Union[str, int, Tuple[float]]]], lines)
|
||||
retval = T.cast(list[list[str | int | tuple[float]]], lines)
|
||||
for idx, item in enumerate(retval):
|
||||
linewidth = widths[idx // groupsize]
|
||||
item.extend((linewidth, colours[idx]))
|
||||
return retval
|
||||
|
||||
def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]:
|
||||
def _lines_create_colors(self, groupsize: int, groups: int) -> list[tuple[float]]:
|
||||
""" Create the color maps.
|
||||
|
||||
Parameters
|
||||
|
@ -336,8 +337,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
|
||||
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
|
||||
super().__init__(parent, data, ylabel)
|
||||
self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask
|
||||
self._displayed_keys: List[str] = []
|
||||
self._thread: LongRunningTask | None = None # Thread for LongRunningTask
|
||||
self._displayed_keys: list[str] = []
|
||||
self._add_callback()
|
||||
|
||||
def _add_callback(self) -> None:
|
||||
|
@ -352,7 +353,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
|
||||
def refresh(self, *args) -> None: # pylint: disable=unused-argument
|
||||
""" Read the latest loss data and apply to current graph """
|
||||
refresh_var = cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
|
||||
refresh_var = T.cast(tk.BooleanVar, get_config().tk_vars.refresh_graph)
|
||||
if not refresh_var.get() and self._thread is None:
|
||||
return
|
||||
|
||||
|
@ -533,7 +534,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
text: str,
|
||||
image_file: str,
|
||||
toggle: bool,
|
||||
command) -> Union[ttk.Button, ttk.Checkbutton]:
|
||||
command) -> ttk.Button | ttk.Checkbutton:
|
||||
""" Override the default button method to use our icons and ttk widgets for
|
||||
consistent GUI layout.
|
||||
|
||||
|
@ -563,10 +564,10 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
img = get_images().icons[icon]
|
||||
|
||||
if not toggle:
|
||||
btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame,
|
||||
text=text,
|
||||
image=img,
|
||||
command=command)
|
||||
btn: ttk.Button | ttk.Checkbutton = ttk.Button(frame,
|
||||
text=text,
|
||||
image=img,
|
||||
command=command)
|
||||
else:
|
||||
var = tk.IntVar(master=frame)
|
||||
btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin python3
|
||||
""" The Menu Bars for faceswap GUI """
|
||||
|
||||
from __future__ import annotations
|
||||
import gettext
|
||||
import locale
|
||||
import logging
|
||||
|
@ -33,7 +33,7 @@ _ = _LANG.gettext
|
|||
|
||||
_WORKING_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
|
||||
_RESOURCES: T.List[T.Tuple[str, str]] = [
|
||||
_RESOURCES: list[tuple[str, str]] = [
|
||||
(_("faceswap.dev - Guides and Forum"), "https://www.faceswap.dev"),
|
||||
(_("Patreon - Support this project"), "https://www.patreon.com/faceswap"),
|
||||
(_("Discord - The FaceSwap Discord server"), "https://discord.gg/VasFUAy"),
|
||||
|
@ -48,7 +48,7 @@ class MainMenuBar(tk.Menu): # pylint:disable=too-many-ancestors
|
|||
master: :class:`tkinter.Tk`
|
||||
The root tkinter object
|
||||
"""
|
||||
def __init__(self, master: "FaceswapGui") -> None:
|
||||
def __init__(self, master: FaceswapGui) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
super().__init__(master)
|
||||
self.root = master
|
||||
|
@ -431,7 +431,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
|
|||
return True
|
||||
|
||||
@classmethod
|
||||
def _get_branches(cls) -> T.Optional[str]:
|
||||
def _get_branches(cls) -> str | None:
|
||||
""" Get the available github branches
|
||||
|
||||
Returns
|
||||
|
@ -453,7 +453,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors
|
|||
return stdout.decode(locale.getpreferredencoding(), errors="replace")
|
||||
|
||||
@classmethod
|
||||
def _filter_branches(cls, stdout: str) -> T.List[str]:
|
||||
def _filter_branches(cls, stdout: str) -> list[str]:
|
||||
""" Filter the branches, remove duplicates and the current branch and return a sorted
|
||||
list.
|
||||
|
||||
|
@ -548,7 +548,7 @@ class TaskBar(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self._section_separator()
|
||||
|
||||
@classmethod
|
||||
def _loader_and_kwargs(cls, btntype: str) -> T.Tuple[str, T.Dict[str, bool]]:
|
||||
def _loader_and_kwargs(cls, btntype: str) -> tuple[str, dict[str, bool]]:
|
||||
""" Get the loader name and key word arguments for the given button type
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin python3
|
||||
""" The pop-up window of the Faceswap GUI for the setting of configuration options. """
|
||||
|
||||
from __future__ import annotations
|
||||
from collections import OrderedDict
|
||||
from configparser import ConfigParser
|
||||
import gettext
|
||||
|
@ -9,7 +9,8 @@ import os
|
|||
import sys
|
||||
import tkinter as tk
|
||||
from tkinter import ttk
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
import typing as T
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from lib.serializer import get_serializer
|
||||
|
@ -18,7 +19,7 @@ from .control_helper import ControlPanel, ControlPanelOption
|
|||
from .custom_widgets import Tooltip
|
||||
from .utils import FileHandler, get_config, get_images, PATHCACHE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.config import FaceswapConfig
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
@ -124,7 +125,7 @@ class _ConfigurePlugins(tk.Toplevel):
|
|||
super().__init__()
|
||||
self._root = get_config().root
|
||||
self._set_geometry()
|
||||
self._tk_vars = dict(header=tk.StringVar())
|
||||
self._tk_vars = {"header": tk.StringVar()}
|
||||
|
||||
theme = {**get_config().user_theme["group_panel"],
|
||||
**get_config().user_theme["group_settings"]}
|
||||
|
@ -402,7 +403,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
|
|||
"""
|
||||
def __init__(self, top_level, parent, configurations, tree, theme):
|
||||
super().__init__(parent)
|
||||
self._configs: Dict[str, "FaceswapConfig"] = configurations
|
||||
self._configs: dict[str, FaceswapConfig] = configurations
|
||||
self._theme = theme
|
||||
self._tree = tree
|
||||
self._vars = {}
|
||||
|
@ -443,7 +444,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
|
|||
sect = section.split(".")[-1]
|
||||
# Elevate global to root
|
||||
key = plugin if sect == "global" else f"{plugin}|{category}|{sect}"
|
||||
retval[key] = dict(helptext=None, options=OrderedDict())
|
||||
retval[key] = {"helptext": None, "options": OrderedDict()}
|
||||
|
||||
retval[key]["helptext"] = conf.defaults[section].helptext
|
||||
for option, params in conf.defaults[section].items.items():
|
||||
|
@ -632,7 +633,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors
|
|||
|
||||
def _get_new_config(self,
|
||||
page_only: bool,
|
||||
config: "FaceswapConfig",
|
||||
config: FaceswapConfig,
|
||||
category: str,
|
||||
lookup: str) -> ConfigParser:
|
||||
""" Obtain a new configuration file for saving
|
||||
|
@ -812,9 +813,9 @@ class _Presets():
|
|||
return None
|
||||
|
||||
args = ("save_filename", "json") if action == "save" else ("filename", "json")
|
||||
kwargs = dict(title=f"{action.title()} Preset...",
|
||||
initial_folder=self._preset_path,
|
||||
parent=self._parent)
|
||||
kwargs = {"title": f"{action.title()} Preset...",
|
||||
"initial_folder": self._preset_path,
|
||||
"parent": self._parent}
|
||||
if action == "save":
|
||||
kwargs["initial_file"] = self._get_initial_filename()
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ import tkinter as tk
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
from tkinter import ttk
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from .control_helper import ControlBuilder, ControlPanelOption
|
||||
from .custom_widgets import Tooltip
|
||||
|
@ -66,7 +65,7 @@ class SessionTKVars:
|
|||
outliers: tk.BooleanVar
|
||||
avgiterations: tk.IntVar
|
||||
smoothamount: tk.DoubleVar
|
||||
loss_keys: Dict[str, tk.BooleanVar] = field(default_factory=dict)
|
||||
loss_keys: dict[str, tk.BooleanVar] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SessionPopUp(tk.Toplevel):
|
||||
|
@ -82,13 +81,13 @@ class SessionPopUp(tk.Toplevel):
|
|||
logger.debug("Initializing: %s: (session_id: %s, data_points: %s)",
|
||||
self.__class__.__name__, session_id, data_points)
|
||||
super().__init__()
|
||||
self._thread: Optional[LongRunningTask] = None # Thread for loading data in background
|
||||
self._thread: LongRunningTask | None = None # Thread for loading data in background
|
||||
self._default_view = "avg" if data_points > 1000 else "smoothed"
|
||||
self._session_id = None if session_id == "Total" else int(session_id)
|
||||
|
||||
self._graph_frame = ttk.Frame(self)
|
||||
self._graph: Optional[SessionGraph] = None
|
||||
self._display_data: Optional[Calculations] = None
|
||||
self._graph: SessionGraph | None = None
|
||||
self._display_data: Calculations | None = None
|
||||
|
||||
self._vars = self._set_vars()
|
||||
|
||||
|
@ -172,7 +171,7 @@ class SessionPopUp(tk.Toplevel):
|
|||
The frame that the options reside in
|
||||
"""
|
||||
logger.debug("Building Combo boxes")
|
||||
choices = dict(Display=("Loss", "Rate"), Scale=("Linear", "Log"))
|
||||
choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")}
|
||||
|
||||
for item in ["Display", "Scale"]:
|
||||
var: tk.StringVar = getattr(self._vars, item.lower())
|
||||
|
@ -273,11 +272,11 @@ class SessionPopUp(tk.Toplevel):
|
|||
logger.debug("Building Slider Controls")
|
||||
for item in ("avgiterations", "smoothamount"):
|
||||
if item == "avgiterations":
|
||||
dtype: Union[Type[int], Type[float]] = int
|
||||
dtype: type[int] | type[float] = int
|
||||
text = "Iterations to Average:"
|
||||
default: Union[int, float] = 500
|
||||
default: int | float = 500
|
||||
rounding = 25
|
||||
min_max: Tuple[int, Union[int, float]] = (25, 2500)
|
||||
min_max: tuple[int, int | float] = (25, 2500)
|
||||
elif item == "smoothamount":
|
||||
dtype = float
|
||||
text = "Smoothing Amount:"
|
||||
|
@ -404,20 +403,20 @@ class SessionPopUp(tk.Toplevel):
|
|||
str
|
||||
The help text for the given action
|
||||
"""
|
||||
lookup = dict(
|
||||
reload=_("Refresh graph"),
|
||||
save=_("Save display data to csv"),
|
||||
avgiterations=_("Number of data points to sample for rolling average"),
|
||||
smoothamount=_("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
|
||||
"smoothing"),
|
||||
outliers=_("Flatten data points that fall more than 1 standard deviation from the "
|
||||
"mean to the mean value."),
|
||||
avg=_("Display rolling average of the data"),
|
||||
smoothed=_("Smooth the data"),
|
||||
raw=_("Display raw data"),
|
||||
trend=_("Display polynormal data trend"),
|
||||
display=_("Set the data to display"),
|
||||
scale=_("Change y-axis scale"))
|
||||
lookup = {
|
||||
"reload": _("Refresh graph"),
|
||||
"save": _("Save display data to csv"),
|
||||
"avgiterations": _("Number of data points to sample for rolling average"),
|
||||
"smoothamount": _("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum "
|
||||
"smoothing"),
|
||||
"outliers": _("Flatten data points that fall more than 1 standard deviation from the "
|
||||
"mean to the mean value."),
|
||||
"avg": _("Display rolling average of the data"),
|
||||
"smoothed": _("Smooth the data"),
|
||||
"raw": _("Display raw data"),
|
||||
"trend": _("Display polynormal data trend"),
|
||||
"display": _("Set the data to display"),
|
||||
"scale": _("Change y-axis scale")}
|
||||
return lookup.get(action.lower(), "")
|
||||
|
||||
def _compile_display_data(self) -> bool:
|
||||
|
@ -446,13 +445,13 @@ class SessionPopUp(tk.Toplevel):
|
|||
self._lbl_loading.pack(fill=tk.BOTH, expand=True)
|
||||
self.update_idletasks()
|
||||
|
||||
kwargs = dict(session_id=self._session_id,
|
||||
display=self._vars.display.get(),
|
||||
loss_keys=loss_keys,
|
||||
selections=selections,
|
||||
avg_samples=self._vars.avgiterations.get(),
|
||||
smooth_amount=self._vars.smoothamount.get(),
|
||||
flatten_outliers=self._vars.outliers.get())
|
||||
kwargs = {"session_id": self._session_id,
|
||||
"display": self._vars.display.get(),
|
||||
"loss_keys": loss_keys,
|
||||
"selections": selections,
|
||||
"avg_samples": self._vars.avgiterations.get(),
|
||||
"smooth_amount": self._vars.smoothamount.get(),
|
||||
"flatten_outliers": self._vars.outliers.get()}
|
||||
self._thread = LongRunningTask(target=self._get_display_data,
|
||||
kwargs=kwargs,
|
||||
widget=self)
|
||||
|
@ -491,7 +490,7 @@ class SessionPopUp(tk.Toplevel):
|
|||
"""
|
||||
return Calculations(**kwargs)
|
||||
|
||||
def _check_valid_selection(self, loss_keys: List[str], selections: List[str]) -> bool:
|
||||
def _check_valid_selection(self, loss_keys: list[str], selections: list[str]) -> bool:
|
||||
""" Check that there will be data to display.
|
||||
|
||||
Parameters
|
||||
|
@ -530,7 +529,7 @@ class SessionPopUp(tk.Toplevel):
|
|||
return False
|
||||
return True
|
||||
|
||||
def _selections_to_list(self) -> List[str]:
|
||||
def _selections_to_list(self) -> list[str]:
|
||||
""" Compile checkbox selections to a list.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
#!/usr/bin python3
|
||||
""" Global configuration options for the Faceswap GUI """
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from lib.gui._config import Config as UserConfig
|
||||
from lib.gui.project import Project, Tasks
|
||||
from lib.gui.theme import Style
|
||||
from .file_handler import FileHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.gui.options import CliOptions
|
||||
from lib.gui.custom_widgets import StatusBar
|
||||
from lib.gui.command import CommandNotebook
|
||||
|
@ -22,12 +23,12 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
PATHCACHE = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache")
|
||||
_CONFIG: Optional["Config"] = None
|
||||
_CONFIG: Config | None = None
|
||||
|
||||
|
||||
def initialize_config(root: tk.Tk,
|
||||
cli_opts: Optional["CliOptions"],
|
||||
statusbar: Optional["StatusBar"]) -> Optional["Config"]:
|
||||
cli_opts: CliOptions | None,
|
||||
statusbar: StatusBar | None) -> Config | None:
|
||||
""" Initialize the GUI Master :class:`Config` and add to global constant.
|
||||
|
||||
This should only be called once on first GUI startup. Future access to :class:`Config`
|
||||
|
@ -145,13 +146,13 @@ class GlobalVariables():
|
|||
@dataclass
|
||||
class _GuiObjects:
|
||||
""" Data class for commonly accessed GUI Objects """
|
||||
cli_opts: Optional["CliOptions"]
|
||||
cli_opts: CliOptions | None
|
||||
tk_vars: GlobalVariables
|
||||
project: Project
|
||||
tasks: Tasks
|
||||
status_bar: Optional["StatusBar"]
|
||||
default_options: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||
command_notebook: Optional["CommandNotebook"] = None
|
||||
status_bar: StatusBar | None
|
||||
default_options: dict[str, dict[str, T.Any]] = field(default_factory=dict)
|
||||
command_notebook: CommandNotebook | None = None
|
||||
|
||||
|
||||
class Config():
|
||||
|
@ -172,15 +173,15 @@ class Config():
|
|||
"""
|
||||
def __init__(self,
|
||||
root: tk.Tk,
|
||||
cli_opts: Optional["CliOptions"],
|
||||
statusbar: Optional["StatusBar"]) -> None:
|
||||
cli_opts: CliOptions | None,
|
||||
statusbar: StatusBar | None) -> None:
|
||||
logger.debug("Initializing %s: (root %s, cli_opts: %s, statusbar: %s)",
|
||||
self.__class__.__name__, root, cli_opts, statusbar)
|
||||
self._default_font = cast(dict, tk.font.nametofont("TkDefaultFont").configure())["family"]
|
||||
self._constants = dict(
|
||||
root=root,
|
||||
scaling_factor=self._get_scaling(root),
|
||||
default_font=self._default_font)
|
||||
self._default_font = T.cast(dict,
|
||||
tk.font.nametofont("TkDefaultFont").configure())["family"]
|
||||
self._constants = {"root": root,
|
||||
"scaling_factor": self._get_scaling(root),
|
||||
"default_font": self._default_font}
|
||||
self._gui_objects = _GuiObjects(
|
||||
cli_opts=cli_opts,
|
||||
tk_vars=GlobalVariables(),
|
||||
|
@ -211,7 +212,7 @@ class Config():
|
|||
|
||||
# GUI Objects
|
||||
@property
|
||||
def cli_opts(self) -> "CliOptions":
|
||||
def cli_opts(self) -> CliOptions:
|
||||
""" :class:`lib.gui.options.CliOptions`: The command line options for this GUI Session. """
|
||||
# This should only be None when a separate tool (not main GUI) is used, at which point
|
||||
# cli_opts do not exist
|
||||
|
@ -234,12 +235,12 @@ class Config():
|
|||
return self._gui_objects.tasks
|
||||
|
||||
@property
|
||||
def default_options(self) -> Dict[str, Dict[str, Any]]:
|
||||
def default_options(self) -> dict[str, dict[str, T.Any]]:
|
||||
""" dict: The default options for all tabs """
|
||||
return self._gui_objects.default_options
|
||||
|
||||
@property
|
||||
def statusbar(self) -> "StatusBar":
|
||||
def statusbar(self) -> StatusBar:
|
||||
""" :class:`lib.gui.custom_widgets.StatusBar`: The GUI StatusBar
|
||||
:class:`tkinter.ttk.Frame`. """
|
||||
# This should only be None when a separate tool (not main GUI) is used, at which point
|
||||
|
@ -248,31 +249,31 @@ class Config():
|
|||
return self._gui_objects.status_bar
|
||||
|
||||
@property
|
||||
def command_notebook(self) -> Optional["CommandNotebook"]:
|
||||
def command_notebook(self) -> CommandNotebook | None:
|
||||
""" :class:`lib.gui.command.CommandNotebook`: The main Faceswap Command Notebook. """
|
||||
return self._gui_objects.command_notebook
|
||||
|
||||
# Convenience GUI Objects
|
||||
@property
|
||||
def tools_notebook(self) -> "ToolsNotebook":
|
||||
def tools_notebook(self) -> ToolsNotebook:
|
||||
""" :class:`lib.gui.command.ToolsNotebook`: The Faceswap Tools sub-Notebook. """
|
||||
assert self.command_notebook is not None
|
||||
return self.command_notebook.tools_notebook
|
||||
|
||||
@property
|
||||
def modified_vars(self) -> Dict[str, "tk.BooleanVar"]:
|
||||
def modified_vars(self) -> dict[str, tk.BooleanVar]:
|
||||
""" dict: The command notebook modified tkinter variables. """
|
||||
assert self.command_notebook is not None
|
||||
return self.command_notebook.modified_vars
|
||||
|
||||
@property
|
||||
def _command_tabs(self) -> Dict[str, int]:
|
||||
def _command_tabs(self) -> dict[str, int]:
|
||||
""" dict: Command tab titles with their IDs. """
|
||||
assert self.command_notebook is not None
|
||||
return self.command_notebook.tab_names
|
||||
|
||||
@property
|
||||
def _tools_tabs(self) -> Dict[str, int]:
|
||||
def _tools_tabs(self) -> dict[str, int]:
|
||||
""" dict: Tools command tab titles with their IDs. """
|
||||
assert self.command_notebook is not None
|
||||
return self.command_notebook.tools_tab_names
|
||||
|
@ -284,17 +285,17 @@ class Config():
|
|||
return self._user_config
|
||||
|
||||
@property
|
||||
def user_config_dict(self) -> Dict[str, Any]: # TODO Dataclass
|
||||
def user_config_dict(self) -> dict[str, T.Any]: # TODO Dataclass
|
||||
""" dict: The GUI config in dict form. """
|
||||
return self._user_config.config_dict
|
||||
|
||||
@property
|
||||
def user_theme(self) -> Dict[str, Any]: # TODO Dataclass
|
||||
def user_theme(self) -> dict[str, T.Any]: # TODO Dataclass
|
||||
""" dict: The GUI theme selection options. """
|
||||
return self._user_theme
|
||||
|
||||
@property
|
||||
def default_font(self) -> Tuple[str, int]:
|
||||
def default_font(self) -> tuple[str, int]:
|
||||
""" tuple: The selected font as configured in user settings. First item is the font (`str`)
|
||||
second item the font size (`int`). """
|
||||
font = self.user_config_dict["font"]
|
||||
|
@ -328,7 +329,7 @@ class Config():
|
|||
self._gui_objects.default_options = default
|
||||
self.project.set_default_options()
|
||||
|
||||
def set_command_notebook(self, notebook: "CommandNotebook") -> None:
|
||||
def set_command_notebook(self, notebook: CommandNotebook) -> None:
|
||||
""" Set the command notebook to the :attr:`command_notebook` attribute
|
||||
and enable the modified callback for :attr:`project`.
|
||||
|
||||
|
@ -385,7 +386,7 @@ class Config():
|
|||
""" Reload the user config from file. """
|
||||
self._user_config = UserConfig(None)
|
||||
|
||||
def set_cursor_busy(self, widget: Optional[tk.Widget] = None) -> None:
|
||||
def set_cursor_busy(self, widget: tk.Widget | None = None) -> None:
|
||||
""" Set the root or widget cursor to busy.
|
||||
|
||||
Parameters
|
||||
|
@ -399,7 +400,7 @@ class Config():
|
|||
component.config(cursor="watch") # type: ignore
|
||||
component.update_idletasks()
|
||||
|
||||
def set_cursor_default(self, widget: Optional[tk.Widget] = None) -> None:
|
||||
def set_cursor_default(self, widget: tk.Widget | None = None) -> None:
|
||||
""" Set the root or widget cursor to default.
|
||||
|
||||
Parameters
|
||||
|
@ -413,7 +414,7 @@ class Config():
|
|||
component.config(cursor="") # type: ignore
|
||||
component.update_idletasks()
|
||||
|
||||
def set_root_title(self, text: Optional[str] = None) -> None:
|
||||
def set_root_title(self, text: str | None = None) -> None:
|
||||
""" Set the main title text for Faceswap.
|
||||
|
||||
The title will always begin with 'Faceswap.py'. Additional text can be appended.
|
||||
|
|
|
@@ -2,24 +2,15 @@
 """ File browser utility functions for the Faceswap GUI. """
 import logging
 import platform
-import sys
 import tkinter as tk
 from tkinter import filedialog

-from typing import cast, Dict, IO, List, Optional, Tuple, Union
-
-if sys.version_info < (3, 8):
-    from typing_extensions import Literal
-else:
-    from typing import Literal
+import typing as T

 logger = logging.getLogger(__name__) # pylint: disable=invalid-name

-_FILETYPE = Literal["default", "alignments", "config_project", "config_task",
-                    "config_all", "csv", "image", "ini", "state", "log", "video"]
-_HANDLETYPE = Literal["open", "save", "filename", "filename_multi", "save_filename",
-                      "context", "dir"]
+_FILETYPE = T.Literal["default", "alignments", "config_project", "config_task",
+                      "config_all", "csv", "image", "ini", "state", "log", "video"]
+_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename",
+                        "context", "dir"]


 class FileHandler(): # pylint:disable=too-few-public-methods
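`T.Literal` keeps the accepted file and handle type strings statically checkable now that the Python 3.8 `typing_extensions` shim is gone. A small hedged illustration of how a type checker uses it (the function below is illustrative, not the FileHandler API):

import typing as T

_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename",
                        "context", "dir"]


def describe(handle_type: _HANDLETYPE) -> str:
    """ Map a handle type to a short description (illustrative only). """
    return {"open": "open an existing file",
            "save": "save to a file",
            "dir": "choose a folder"}.get(handle_type, "choose one or more filenames")


describe("dir")       # accepted
# describe("folder")  # rejected by mypy: not a member of _HANDLETYPE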
@ -72,14 +63,14 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
|
||||
def __init__(self,
|
||||
handle_type: _HANDLETYPE,
|
||||
file_type: Optional[_FILETYPE],
|
||||
title: Optional[str] = None,
|
||||
initial_folder: Optional[str] = None,
|
||||
initial_file: Optional[str] = None,
|
||||
command: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
variable: Optional[str] = None,
|
||||
parent: Optional[tk.Frame] = None) -> None:
|
||||
file_type: _FILETYPE | None,
|
||||
title: str | None = None,
|
||||
initial_folder: str | None = None,
|
||||
initial_file: str | None = None,
|
||||
command: str | None = None,
|
||||
action: str | None = None,
|
||||
variable: str | None = None,
|
||||
parent: tk.Frame | None = None) -> None:
|
||||
logger.debug("Initializing %s: (handle_type: '%s', file_type: '%s', title: '%s', "
|
||||
"initial_folder: '%s', initial_file: '%s', command: '%s', action: '%s', "
|
||||
"variable: %s, parent: %s)", self.__class__.__name__, handle_type, file_type,
|
||||
|
@ -101,35 +92,35 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def _filetypes(self) -> Dict[str, List[Tuple[str, str]]]:
|
||||
def _filetypes(self) -> dict[str, list[tuple[str, str]]]:
|
||||
""" dict: The accepted extensions for each file type for opening/saving """
|
||||
all_files = ("All files", "*.*")
|
||||
filetypes = dict(
|
||||
default=[all_files],
|
||||
alignments=[("Faceswap Alignments", "*.fsa"), all_files],
|
||||
config_project=[("Faceswap Project files", "*.fsw"), all_files],
|
||||
config_task=[("Faceswap Task files", "*.fst"), all_files],
|
||||
config_all=[("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
|
||||
csv=[("Comma separated values", "*.csv"), all_files],
|
||||
image=[("Bitmap", "*.bmp"),
|
||||
("JPG", "*.jpeg *.jpg"),
|
||||
("PNG", "*.png"),
|
||||
("TIFF", "*.tif *.tiff"),
|
||||
all_files],
|
||||
ini=[("Faceswap config files", "*.ini"), all_files],
|
||||
json=[("JSON file", "*.json"), all_files],
|
||||
model=[("Keras model files", "*.h5"), all_files],
|
||||
state=[("State files", "*.json"), all_files],
|
||||
log=[("Log files", "*.log"), all_files],
|
||||
video=[("Audio Video Interleave", "*.avi"),
|
||||
("Flash Video", "*.flv"),
|
||||
("Matroska", "*.mkv"),
|
||||
("MOV", "*.mov"),
|
||||
("MP4", "*.mp4"),
|
||||
("MPEG", "*.mpeg *.mpg *.ts *.vob"),
|
||||
("WebM", "*.webm"),
|
||||
("Windows Media Video", "*.wmv"),
|
||||
all_files])
|
||||
filetypes = {
|
||||
"default": [all_files],
|
||||
"alignments": [("Faceswap Alignments", "*.fsa"), all_files],
|
||||
"config_project": [("Faceswap Project files", "*.fsw"), all_files],
|
||||
"config_task": [("Faceswap Task files", "*.fst"), all_files],
|
||||
"config_all": [("Faceswap Project and Task files", "*.fst *.fsw"), all_files],
|
||||
"csv": [("Comma separated values", "*.csv"), all_files],
|
||||
"image": [("Bitmap", "*.bmp"),
|
||||
("JPG", "*.jpeg *.jpg"),
|
||||
("PNG", "*.png"),
|
||||
("TIFF", "*.tif *.tiff"),
|
||||
all_files],
|
||||
"ini": [("Faceswap config files", "*.ini"), all_files],
|
||||
"json": [("JSON file", "*.json"), all_files],
|
||||
"model": [("Keras model files", "*.h5"), all_files],
|
||||
"state": [("State files", "*.json"), all_files],
|
||||
"log": [("Log files", "*.log"), all_files],
|
||||
"video": [("Audio Video Interleave", "*.avi"),
|
||||
("Flash Video", "*.flv"),
|
||||
("Matroska", "*.mkv"),
|
||||
("MOV", "*.mov"),
|
||||
("MP4", "*.mp4"),
|
||||
("MPEG", "*.mpeg *.mpg *.ts *.vob"),
|
||||
("WebM", "*.webm"),
|
||||
("Windows Media Video", "*.wmv"),
|
||||
all_files]}
|
||||
|
||||
# Add in multi-select options and upper case extensions for Linux
|
||||
for key in filetypes:
|
||||
|
@ -142,32 +133,32 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
multi = [f"{key.title()} Files"]
|
||||
multi.append(" ".join([ftype[1]
|
||||
for ftype in filetypes[key] if ftype[0] != "All files"]))
|
||||
filetypes[key].insert(0, cast(Tuple[str, str], tuple(multi)))
|
||||
filetypes[key].insert(0, T.cast(tuple[str, str], tuple(multi)))
|
||||
return filetypes
|
||||
|
||||
@property
|
||||
def _contexts(self) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]:
|
||||
def _contexts(self) -> dict[str, dict[str, str | dict[str, str]]]:
|
||||
"""dict: Mapping of commands, actions and their corresponding file dialog for context
|
||||
handle types. """
|
||||
return dict(effmpeg=dict(input={"extract": "filename",
|
||||
"gen-vid": "dir",
|
||||
"get-fps": "filename",
|
||||
"get-info": "filename",
|
||||
"mux-audio": "filename",
|
||||
"rescale": "filename",
|
||||
"rotate": "filename",
|
||||
"slice": "filename"},
|
||||
output={"extract": "dir",
|
||||
"gen-vid": "save_filename",
|
||||
"get-fps": "nothing",
|
||||
"get-info": "nothing",
|
||||
"mux-audio": "save_filename",
|
||||
"rescale": "save_filename",
|
||||
"rotate": "save_filename",
|
||||
"slice": "save_filename"}))
|
||||
return {"effmpeg": {"input": {"extract": "filename",
|
||||
"gen-vid": "dir",
|
||||
"get-fps": "filename",
|
||||
"get-info": "filename",
|
||||
"mux-audio": "filename",
|
||||
"rescale": "filename",
|
||||
"rotate": "filename",
|
||||
"slice": "filename"},
|
||||
"output": {"extract": "dir",
|
||||
"gen-vid": "save_filename",
|
||||
"get-fps": "nothing",
|
||||
"get-info": "nothing",
|
||||
"mux-audio": "save_filename",
|
||||
"rescale": "save_filename",
|
||||
"rotate": "save_filename",
|
||||
"slice": "save_filename"}}}
|
||||
|
||||
@classmethod
|
||||
def _set_dummy_master(cls) -> Optional[tk.Frame]:
|
||||
def _set_dummy_master(cls) -> tk.Frame | None:
|
||||
""" Add an option to force black font on Linux file dialogs KDE issue that displays light
|
||||
font on white background).
|
||||
|
||||
|
@ -183,7 +174,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
if platform.system().lower() == "linux":
|
||||
frame = tk.Frame()
|
||||
frame.option_add("*foreground", "black")
|
||||
retval: Optional[tk.Frame] = frame
|
||||
retval: tk.Frame | None = frame
|
||||
else:
|
||||
retval = None
|
||||
return retval
|
||||
|
@ -196,7 +187,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
del self._dummy_master
|
||||
self._dummy_master = None
|
||||
|
||||
def _set_defaults(self) -> Dict[str, Optional[str]]:
|
||||
def _set_defaults(self) -> dict[str, str | None]:
|
||||
""" Set the default file type for the file dialog. Generally the first found file type
|
||||
will be used, but this is overridden if it is not appropriate.
|
||||
|
||||
|
@ -205,7 +196,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
dict:
|
||||
The default file extension for each file type
|
||||
"""
|
||||
defaults: Dict[str, Optional[str]] = {
|
||||
defaults: dict[str, str | None] = {
|
||||
key: next(ext for ext in val[0][1].split(" ")).replace("*", "")
|
||||
for key, val in self._filetypes.items()}
|
||||
defaults["default"] = None
|
||||
|
@ -215,15 +206,15 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
return defaults
|
||||
|
||||
def _set_kwargs(self,
|
||||
title: Optional[str],
|
||||
initial_folder: Optional[str],
|
||||
initial_file: Optional[str],
|
||||
file_type: Optional[_FILETYPE],
|
||||
command: Optional[str],
|
||||
action: Optional[str],
|
||||
variable: Optional[str],
|
||||
parent: Optional[tk.Frame]
|
||||
) -> Dict[str, Union[None, tk.Frame, str, List[Tuple[str, str]]]]:
|
||||
title: str | None,
|
||||
initial_folder: str | None,
|
||||
initial_file: str | None,
|
||||
file_type: _FILETYPE | None,
|
||||
command: str | None,
|
||||
action: str | None,
|
||||
variable: str | None,
|
||||
parent: tk.Frame | None
|
||||
) -> dict[str, None | tk.Frame | str | list[tuple[str, str]]]:
|
||||
""" Generate the required kwargs for the requested file dialog browser.
|
||||
|
||||
Parameters
|
||||
|
@ -259,8 +250,8 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
title, initial_folder, initial_file, file_type, command, action, variable,
|
||||
parent)
|
||||
|
||||
kwargs: Dict[str, Union[None, tk.Frame, str,
|
||||
List[Tuple[str, str]]]] = dict(master=self._dummy_master)
|
||||
kwargs: dict[str, None | tk.Frame | str | list[tuple[str, str]]] = {
|
||||
"master": self._dummy_master}
|
||||
|
||||
if self._handletype.lower() == "context":
|
||||
assert command is not None and action is not None and variable is not None
|
||||
|
@ -304,20 +295,20 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
The variable associated with this file dialog
|
||||
"""
|
||||
if self._contexts[command].get(variable, None) is not None:
|
||||
handletype = cast(Dict[str, Dict[str, Dict[str, str]]],
|
||||
self._contexts)[command][variable][action]
|
||||
handletype = T.cast(dict[str, dict[str, dict[str, str]]],
|
||||
self._contexts)[command][variable][action]
|
||||
else:
|
||||
handletype = cast(Dict[str, Dict[str, str]],
|
||||
self._contexts)[command][action]
|
||||
handletype = T.cast(dict[str, dict[str, str]],
|
||||
self._contexts)[command][action]
|
||||
logger.debug(handletype)
|
||||
self._handletype = cast(_HANDLETYPE, handletype)
|
||||
self._handletype = T.cast(_HANDLETYPE, handletype)
|
||||
|
||||
def _open(self) -> Optional[IO]:
|
||||
def _open(self) -> T.IO | None:
|
||||
""" Open a file. """
|
||||
logger.debug("Popping Open browser")
|
||||
return filedialog.askopenfile(**self._kwargs) # type: ignore
|
||||
|
||||
def _save(self) -> Optional[IO]:
|
||||
def _save(self) -> T.IO | None:
|
||||
""" Save a file. """
|
||||
logger.debug("Popping Save browser")
|
||||
return filedialog.asksaveasfile(**self._kwargs) # type: ignore
|
||||
|
@ -337,7 +328,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Popping Filename browser")
|
||||
return filedialog.askopenfilename(**self._kwargs) # type: ignore
|
||||
|
||||
def _filename_multi(self) -> Tuple[str, ...]:
|
||||
def _filename_multi(self) -> tuple[str, ...]:
|
||||
""" Get multiple existing file locations. """
|
||||
logger.debug("Popping Filename browser")
|
||||
return filedialog.askopenfilenames(**self._kwargs) # type: ignore
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
#!/usr/bin python3
|
||||
""" Utilities for handling images in the Faceswap GUI """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import cast, Dict, List, Optional, Sequence, Tuple
|
||||
import typing as T
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -14,15 +13,12 @@ from lib.training.preview_cv import PreviewBuffer
|
|||
|
||||
from .config import get_config, PATHCACHE
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_IMAGES: Optional["Images"] = None
|
||||
_PREVIEW_TRIGGER: Optional["PreviewTrigger"] = None
|
||||
_IMAGES: "Images" | None = None
|
||||
_PREVIEW_TRIGGER: "PreviewTrigger" | None = None
|
||||
TRAININGPREVIEW = ".gui_training_preview.png"
|
||||
|
||||
|
||||
|
@ -51,7 +47,7 @@ def get_images() -> "Images":
|
|||
return _IMAGES
|
||||
|
||||
|
||||
def _get_previews(image_path: str) -> List[str]:
|
||||
def _get_previews(image_path: str) -> list[str]:
|
||||
""" Get the images stored within the given directory.
|
||||
|
||||
Parameters
|
||||
|
@ -164,12 +160,12 @@ class PreviewExtract():
|
|||
self._output_path = ""
|
||||
|
||||
self._modified: float = 0.0
|
||||
self._filenames: List[str] = []
|
||||
self._images: Optional[np.ndarray] = None
|
||||
self._placeholder: Optional[np.ndarray] = None
|
||||
self._filenames: list[str] = []
|
||||
self._images: np.ndarray | None = None
|
||||
self._placeholder: np.ndarray | None = None
|
||||
|
||||
self._preview_image: Optional[Image.Image] = None
|
||||
self._preview_image_tk: Optional[ImageTk.PhotoImage] = None
|
||||
self._preview_image: Image.Image | None = None
|
||||
self._preview_image_tk: ImageTk.PhotoImage | None = None
|
||||
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
|
@ -228,7 +224,7 @@ class PreviewExtract():
|
|||
logger.debug("sorted folders: %s, return value: %s", folders, retval)
|
||||
return retval
|
||||
|
||||
def _get_newest_filenames(self, image_files: List[str]) -> List[str]:
|
||||
def _get_newest_filenames(self, image_files: list[str]) -> list[str]:
|
||||
""" Return image filenames that have been modified since the last check.
|
||||
|
||||
Parameters
|
||||
|
@ -281,8 +277,8 @@ class PreviewExtract():
|
|||
return retval
|
||||
|
||||
def _process_samples(self,
|
||||
samples: List[np.ndarray],
|
||||
filenames: List[str],
|
||||
samples: list[np.ndarray],
|
||||
filenames: list[str],
|
||||
num_images: int) -> bool:
|
||||
""" Process the latest sample images into a displayable image.
|
||||
|
||||
|
@ -321,8 +317,8 @@ class PreviewExtract():
|
|||
return True
|
||||
|
||||
def _load_images_to_cache(self,
|
||||
image_files: List[str],
|
||||
frame_dims: Tuple[int, int],
|
||||
image_files: list[str],
|
||||
frame_dims: tuple[int, int],
|
||||
thumbnail_size: int) -> bool:
|
||||
""" Load preview images to the image cache.
|
||||
|
||||
|
@ -349,7 +345,7 @@ class PreviewExtract():
|
|||
logger.debug("num_images: %s", num_images)
|
||||
if num_images == 0:
|
||||
return False
|
||||
samples: List[np.ndarray] = []
|
||||
samples: list[np.ndarray] = []
|
||||
start_idx = len(image_files) - num_images if len(image_files) > num_images else 0
|
||||
show_files = sorted(image_files, key=os.path.getctime)[start_idx:]
|
||||
dropped_files = []
|
||||
|
@ -405,7 +401,7 @@ class PreviewExtract():
|
|||
self._placeholder = placeholder
|
||||
logger.debug("Created placeholder. shape: %s", placeholder.shape)
|
||||
|
||||
def _place_previews(self, frame_dims: Tuple[int, int]) -> Image.Image:
|
||||
def _place_previews(self, frame_dims: tuple[int, int]) -> Image.Image:
|
||||
""" Format the preview thumbnails stored in the cache into a grid fitting the display
|
||||
panel.
|
||||
|
||||
|
@ -441,12 +437,12 @@ class PreviewExtract():
|
|||
placeholder = np.concatenate([np.expand_dims(self._placeholder, 0)] * remainder)
|
||||
samples = np.concatenate((samples, placeholder))
|
||||
|
||||
display = np.vstack([np.hstack(cast(Sequence, samples[row * cols: (row + 1) * cols]))
|
||||
display = np.vstack([np.hstack(T.cast("Sequence", samples[row * cols: (row + 1) * cols]))
|
||||
for row in range(rows)])
|
||||
logger.debug("display shape: %s", display.shape)
|
||||
return Image.fromarray(display)
|
||||
|
||||
def load_latest_preview(self, thumbnail_size: int, frame_dims: Tuple[int, int]) -> bool:
|
||||
def load_latest_preview(self, thumbnail_size: int, frame_dims: tuple[int, int]) -> bool:
|
||||
""" Load the latest preview image for extract and convert.
|
||||
|
||||
Retrieves the latest preview images from the faceswap output folder, resizes to thumbnails
|
||||
|
@ -524,7 +520,7 @@ class Images():
|
|||
def __init__(self) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
self._pathpreview = os.path.join(PATHCACHE, "preview")
|
||||
self._pathoutput: Optional[str] = None
|
||||
self._pathoutput: str | None = None
|
||||
self._batch_mode = False
|
||||
self._preview_train = PreviewTrain(self._pathpreview)
|
||||
self._preview_extract = PreviewExtract(self._pathpreview)
|
||||
|
@ -542,7 +538,7 @@ class Images():
|
|||
return self._preview_extract
|
||||
|
||||
@property
|
||||
def icons(self) -> Dict[str, ImageTk.PhotoImage]:
|
||||
def icons(self) -> dict[str, ImageTk.PhotoImage]:
|
||||
""" dict: The faceswap icons for all parts of the GUI. The dictionary key is the icon
|
||||
name (`str`) the value is the icon sized and formatted for display
|
||||
(:class:`PIL.ImageTK.PhotoImage`).
|
||||
|
@ -557,7 +553,7 @@ class Images():
|
|||
return self._icons
|
||||
|
||||
@staticmethod
|
||||
def _load_icons() -> Dict[str, ImageTk.PhotoImage]:
|
||||
def _load_icons() -> dict[str, ImageTk.PhotoImage]:
|
||||
""" Scan the icons cache folder and load the icons into :attr:`icons` for retrieval
|
||||
throughout the GUI.
|
||||
|
||||
|
@ -569,7 +565,7 @@ class Images():
|
|||
"""
|
||||
size = get_config().user_config_dict.get("icon_size", 16)
|
||||
size = int(round(size * get_config().scaling_factor))
|
||||
icons: Dict[str, ImageTk.PhotoImage] = {}
|
||||
icons: dict[str, ImageTk.PhotoImage] = {}
|
||||
pathicons = os.path.join(PATHCACHE, "icons")
|
||||
for fname in os.listdir(pathicons):
|
||||
name, ext = os.path.splitext(fname)
|
||||
|
@ -609,12 +605,12 @@ class PreviewTrigger():
|
|||
"""
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing: %s", self.__class__.__name__)
|
||||
self._trigger_files = dict(update=os.path.join(PATHCACHE, ".preview_trigger"),
|
||||
mask_toggle=os.path.join(PATHCACHE, ".preview_mask_toggle"))
|
||||
self._trigger_files = {"update": os.path.join(PATHCACHE, ".preview_trigger"),
|
||||
"mask_toggle": os.path.join(PATHCACHE, ".preview_mask_toggle")}
|
||||
logger.debug("Initialized: %s (trigger_files: %s)",
|
||||
self.__class__.__name__, self._trigger_files)
|
||||
|
||||
def set(self, trigger_type: Literal["update", "mask_toggle"]):
|
||||
def set(self, trigger_type: T.Literal["update", "mask_toggle"]):
|
||||
""" Place the trigger file into the cache folder
|
||||
|
||||
Parameters
|
||||
|
@ -629,7 +625,7 @@ class PreviewTrigger():
|
|||
pass
|
||||
logger.debug("Set preview trigger: %s", trigger)
|
||||
|
||||
def clear(self, trigger_type: Optional[Literal["update", "mask_toggle"]] = None) -> None:
|
||||
def clear(self, trigger_type: T.Literal["update", "mask_toggle"] | None = None) -> None:
|
||||
""" Remove the trigger file from the cache folder.
|
||||
|
||||
Parameters
|
||||
|
|
|
@@ -1,15 +1,17 @@
 #!/usr/bin/env python3
 """ Miscellaneous Utility functions for the GUI. Includes LongRunningTask object """
+from __future__ import annotations
 import logging
 import sys
+import typing as T

 from threading import Event, Thread
-from typing import (Any, Callable, cast, Dict, Optional, Tuple, Type, TYPE_CHECKING)
 from queue import Queue

 from .config import get_config

-if TYPE_CHECKING:
+if T.TYPE_CHECKING:
+    from collections.abc import Callable
     from types import TracebackType
     from lib.multithreading import _ErrorType

@@ -31,15 +33,15 @@ class LongRunningTask(Thread):
         cursor in the correct location. Default: ``None``.
     """
     _target: Callable
-    _args: Tuple
-    _kwargs: Dict[str, Any]
+    _args: tuple
+    _kwargs: dict[str, T.Any]
     _name: str

     def __init__(self,
-                 target: Optional[Callable] = None,
-                 name: Optional[str] = None,
-                 args: Tuple = (),
-                 kwargs: Optional[Dict[str, Any]] = None,
+                 target: Callable | None = None,
+                 name: str | None = None,
+                 args: tuple = (),
+                 kwargs: dict[str, T.Any] | None = None,
                  *,
                  daemon: bool = True,
                  widget=None):

@@ -48,7 +50,7 @@ class LongRunningTask(Thread):
                      daemon)
         super().__init__(target=target, name=name, args=args, kwargs=kwargs,
                          daemon=daemon)
-        self.err: "_ErrorType" = None
+        self.err: _ErrorType = None
         self._widget = widget
         self._config = get_config()
         self._config.set_cursor_busy(widget=self._widget)

@@ -70,8 +72,8 @@ class LongRunningTask(Thread):
             retval = self._target(*self._args, **self._kwargs)
             self._queue.put(retval)
         except Exception: # pylint: disable=broad-except
-            self.err = cast(Tuple[Type[BaseException], BaseException, "TracebackType"],
-                            sys.exc_info())
+            self.err = T.cast(tuple[type[BaseException], BaseException, "TracebackType"],
+                              sys.exc_info())
             assert self.err is not None
             logger.debug("Error in thread (%s): %s", self._name,
                          self.err[1].with_traceback(self.err[2]))

@@ -81,7 +83,7 @@ class LongRunningTask(Thread):
         # an argument that has a member that points to the thread.
         del self._target, self._args, self._kwargs

-    def get_result(self) -> Any:
+    def get_result(self) -> T.Any:
         """ Return the result from the given task.

         Returns
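The `sys.exc_info()` capture above is what lets the GUI re-raise a worker-thread failure on the main thread once the task finishes. A standalone sketch of the same pattern, with illustrative names rather than the faceswap API:

from __future__ import annotations

import sys
import threading
from types import TracebackType

# Shape of sys.exc_info(): (exception type, exception value, traceback), or all None
_ErrorType = (tuple[type[BaseException], BaseException, TracebackType]
              | tuple[None, None, None] | None)


class CapturingThread(threading.Thread):
    """ Run the target and store any exception so the caller can re-raise it after join(). """

    def __init__(self, target, *args, **kwargs) -> None:
        super().__init__(target=target, args=args, kwargs=kwargs)
        self.err: _ErrorType = None

    def run(self) -> None:
        try:
            super().run()
        except Exception: # pylint:disable=broad-except
            self.err = sys.exc_info()

    def check_and_raise(self) -> None:
        """ Re-raise a stored worker exception on the calling thread, keeping its traceback. """
        if self.err is not None and self.err[1] is not None:
            raise self.err[1].with_traceback(self.err[2])

Typical use is `thread.start()`, `thread.join()`, then `thread.check_and_raise()` from the caller.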
16 lib/image.py
@ -1,17 +1,17 @@
|
|||
#!/usr/bin python3
|
||||
""" Utilities for working with images and videos """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from ast import literal_eval
|
||||
from bisect import bisect
|
||||
from concurrent import futures
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from zlib import crc32
|
||||
|
||||
import cv2
|
||||
|
@ -24,7 +24,7 @@ from lib.multithreading import MultiThread
|
|||
from lib.queue_manager import queue_manager, QueueEmpty
|
||||
from lib.utils import convert_to_secs, FaceswapError, _video_extensions, get_image_paths
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.align.alignments import PNGHeaderDict
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||
|
@ -558,7 +558,7 @@ def update_existing_metadata(filename, metadata):
|
|||
|
||||
def encode_image(image: np.ndarray,
|
||||
extension: str,
|
||||
metadata: Optional["PNGHeaderDict"] = None) -> bytes:
|
||||
metadata: PNGHeaderDict | None = None) -> bytes:
|
||||
""" Encode an image.
|
||||
|
||||
Parameters
|
||||
|
@ -1433,8 +1433,8 @@ class ImagesSaver(ImageIO):
|
|||
|
||||
def _save(self,
|
||||
filename: str,
|
||||
image: Union[bytes, np.ndarray],
|
||||
sub_folder: Optional[str]) -> None:
|
||||
image: bytes | np.ndarray,
|
||||
sub_folder: str | None) -> None:
|
||||
""" Save a single image inside a ThreadPoolExecutor
|
||||
|
||||
Parameters
|
||||
|
@ -1468,8 +1468,8 @@ class ImagesSaver(ImageIO):
|
|||
|
||||
def save(self,
|
||||
filename: str,
|
||||
image: Union[bytes, np.ndarray],
|
||||
sub_folder: Optional[str] = None) -> None:
|
||||
image: bytes | np.ndarray,
|
||||
sub_folder: str | None = None) -> None:
|
||||
""" Save the given image in the background thread
|
||||
|
||||
Ensure that :func:`close` is called once all save operations are complete.
|
||||
|
|
|
@ -117,7 +117,7 @@ class ColorSpaceConvert(): # pylint:disable=too-few-public-methods
|
|||
self._xyz_multipliers = K.constant([116, 500, 200], dtype="float32")
|
||||
|
||||
@classmethod
|
||||
def _get_rgb_xyz_map(cls) -> T.Tuple[Tensor, Tensor]:
|
||||
def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]:
|
||||
""" Obtain the mapping and inverse mapping for rgb to xyz color space conversion.
|
||||
|
||||
Returns
|
||||
|
|
|
@@ -11,7 +11,17 @@ import time
 import traceback

 from datetime import datetime
 from typing import Union


+# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
+def _patched_format(self, record):
+    """ Autograph tf-2.10 has a bug with the 3.10 version of logging.PercentStyle._format(). It is
+    non-critical but spits out warnings. This is the Python 3.9 version of the function and should
+    be removed once fixed """
+    return self._fmt % record.__dict__ # pylint:disable=protected-access
+
+
+setattr(logging.PercentStyle, "_format", _patched_format)


 class FaceswapLogger(logging.Logger):

@@ -76,11 +86,11 @@ class ColoredFormatter(logging.Formatter):
     def __init__(self, fmt: str, pad_newlines: bool = False, **kwargs) -> None:
         super().__init__(fmt, **kwargs)
         self._use_color = self._get_color_compatibility()
-        self._level_colors = dict(CRITICAL="\033[31m", # red
-                                  ERROR="\033[31m", # red
-                                  WARNING="\033[33m", # yellow
-                                  INFO="\033[32m", # green
-                                  VERBOSE="\033[34m") # blue
+        self._level_colors = {"CRITICAL": "\033[31m", # red
+                              "ERROR": "\033[31m", # red
+                              "WARNING": "\033[33m", # yellow
+                              "INFO": "\033[32m", # green
+                              "VERBOSE": "\033[34m"} # blue
         self._default_color = "\033[0m"
         self._newline_padding = self._get_newline_padding(pad_newlines, fmt)

@@ -412,7 +422,7 @@ def _file_handler(loglevel,
     return handler


-def _stream_handler(loglevel: int, is_gui: bool) -> Union[logging.StreamHandler, TqdmHandler]:
+def _stream_handler(loglevel: int, is_gui: bool) -> logging.StreamHandler | TqdmHandler:
     """ Add a stream handler for the current Faceswap session. The stream handler will only ever
     output at a maximum of VERBOSE level to avoid spamming the console.
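The monkey patch in the first hunk only swaps `logging.PercentStyle._format` back to the pre-3.10 implementation so that TF 2.10's autograph stops emitting warnings; ordinary record formatting is unchanged. A quick illustrative check, not part of the commit:

import logging

formatter = logging.Formatter("%(levelname)s %(name)s: %(message)s")
record = logging.LogRecord("faceswap", logging.INFO, __file__, 1, "loaded %s faces", (42,), None)
print(formatter.format(record))  # -> INFO faceswap: loaded 42 faces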
@@ -1,8 +1,6 @@
 """ Auto clipper for clipping gradients. """
-from typing import List
-
 import numpy as np
 import tensorflow as tf
-import tensorflow_probability as tfp


 class AutoClipper(): # pylint:disable=too-few-public-methods

@@ -22,12 +20,56 @@ class AutoClipper(): # pylint:disable=too-few-public-methods
     original paper: https://arxiv.org/abs/2007.14469
     """
     def __init__(self, clip_percentile: int, history_size: int = 10000):
-        self._clip_percentile = clip_percentile
+        self._clip_percentile = tf.cast(clip_percentile, tf.float64)
         self._grad_history = tf.Variable(tf.zeros(history_size), trainable=False)
         self._index = tf.Variable(0, trainable=False)
         self._history_size = history_size

-    def __call__(self, grads_and_vars: List[tf.Tensor]) -> List[tf.Tensor]:
+    def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor:
+        """ Compute the clip percentile of the gradient history
+
+        Parameters
+        ----------
+        grad_history: :class:`tensorflow.Tensor`
+            Tge gradient history to calculate the clip percentile for
+
+        Returns
+        -------
+        :class:`tensorflow.Tensor`
+            A rank(:attr:`clip_percentile`) `Tensor`
+
+        Notes
+        -----
+        Adapted from
+        https://github.com/tensorflow/probability/blob/r0.14/tensorflow_probability/python/stats/quantiles.py
+        to remove reliance on full tensorflow_probability libraray
+        """
+        with tf.name_scope("percentile"):
+            frac_at_q_or_below = self._clip_percentile / 100.
+            sorted_hist = tf.sort(grad_history, axis=-1, direction="ASCENDING")
+
+            num = tf.cast(tf.shape(grad_history)[-1], tf.float64)
+
+            # get indices
+            indices = tf.round((num - 1) * frac_at_q_or_below)
+            indices = tf.clip_by_value(tf.cast(indices, tf.int32),
+                                       0,
+                                       tf.shape(grad_history)[-1] - 1)
+            gathered_hist = tf.gather(sorted_hist, indices, axis=-1)
+
+            # Propagate NaNs. Apparently tf.is_nan doesn't like other dtypes
+            nan_batch_members = tf.reduce_any(tf.math.is_nan(grad_history), axis=None)
+            right_rank_matched_shape = tf.pad(tf.shape(nan_batch_members),
+                                              paddings=[[0, tf.rank(self._clip_percentile)]],
+                                              constant_values=1)
+            nan_batch_members = tf.reshape(nan_batch_members, shape=right_rank_matched_shape)
+
+            nan = np.array(np.nan, gathered_hist.dtype.as_numpy_dtype)
+            gathered_hist = tf.where(nan_batch_members, nan, gathered_hist)
+
+            return gathered_hist
+
+    def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]:
         """ Call the AutoClip function.

         Parameters

@@ -40,8 +82,7 @@ class AutoClipper(): # pylint:disable=too-few-public-methods
         assign_idx = tf.math.mod(self._index, self._history_size)
         self._grad_history = self._grad_history[assign_idx].assign(total_norm)
         self._index = self._index.assign_add(1)
-        clip_value = tfp.stats.percentile(self._grad_history[: self._index],
-                                          q=self._clip_percentile)
+        clip_value = self._percentile(self._grad_history[: self._index])
         return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars]

     @classmethod
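For orientation, what the class now does end to end can be sketched with plain NumPy (illustrative only; the TensorFlow code above is the commit's implementation): keep a history of global gradient norms and clip each incoming gradient to the requested percentile of that history, which is all `tfp.stats.percentile` had been providing.

import numpy as np


def autoclip_step(grads, history, clip_percentile=10.0):
    """ One AutoClip step on plain numpy arrays (sketch, not the commit's code). """
    global_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    history.append(global_norm)
    clip_value = float(np.percentile(history, clip_percentile))  # stand-in for the tf percentile
    clipped = []
    for grad in grads:
        norm = float(np.linalg.norm(grad))
        scale = min(1.0, clip_value / (norm + 1e-12))            # same effect as tf.clip_by_norm
        clipped.append(grad * scale)
    return clipped, history

The hand-rolled TensorFlow version gathers the nearest entry of the sorted history rather than interpolating the way `np.percentile` does by default, which is close enough for a clipping threshold.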
@ -1,9 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Custom Feature Map Loss Functions for faceswap.py """
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, List, Tuple
|
||||
import typing as T
|
||||
|
||||
# Ignore linting errors from Tensorflow's thoroughly broken import system
|
||||
import tensorflow as tf
|
||||
|
@ -17,6 +17,9 @@ import numpy as np
|
|||
from lib.model.nets import AlexNet, SqueezeNet
|
||||
from lib.utils import GetModel
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -39,10 +42,10 @@ class NetInfo:
|
|||
"""
|
||||
model_id: int = 0
|
||||
model_name: str = ""
|
||||
net: Optional[Callable] = None
|
||||
init_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
net: Callable | None = None
|
||||
init_kwargs: dict[str, T.Any] = field(default_factory=dict)
|
||||
needs_init: bool = True
|
||||
outputs: List[Layer] = field(default_factory=list)
|
||||
outputs: list[Layer] = field(default_factory=list)
|
||||
|
||||
|
||||
class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods
|
||||
|
@ -67,7 +70,7 @@ class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Initialized: %s ", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def _nets(self) -> Dict[str, NetInfo]:
|
||||
def _nets(self) -> dict[str, NetInfo]:
|
||||
""" :class:`NetInfo`: The Information about the requested net."""
|
||||
return {
|
||||
"alex": NetInfo(model_id=15,
|
||||
|
@ -176,7 +179,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def _nets(self) -> Dict[str, NetInfo]:
|
||||
def _nets(self) -> dict[str, NetInfo]:
|
||||
""" :class:`NetInfo`: The Information about the requested net."""
|
||||
return {
|
||||
"alex": NetInfo(model_id=18,
|
||||
|
@ -186,7 +189,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods
|
|||
"vgg16": NetInfo(model_id=20,
|
||||
model_name="vgg16_lpips_v1.h5")}
|
||||
|
||||
def _linear_block(self, net_output_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
|
||||
""" Build a linear block for a trunk network output.
|
||||
|
||||
Parameters
|
||||
|
@ -319,7 +322,7 @@ class LPIPSLoss(): # pylint:disable=too-few-public-methods
|
|||
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
def _process_diffs(self, inputs: List[tf.Tensor]) -> List[tf.Tensor]:
|
||||
def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]:
|
||||
""" Perform processing on the Trunk Network outputs.
|
||||
|
||||
If :attr:`use_ldip` is enabled, process the diff values through the linear network,
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Custom Loss Functions for faceswap.py """
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from typing import Callable, List, Tuple
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -13,6 +12,9 @@ import tensorflow as tf
|
|||
from tensorflow.python.keras.engine import compile_utils # pylint:disable=no-name-in-module
|
||||
from tensorflow.keras import backend as K # pylint:disable=import-error
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -61,7 +63,7 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods
|
|||
self._ave_spectrum = ave_spectrum
|
||||
self._log_matrix = log_matrix
|
||||
self._batch_matrix = batch_matrix
|
||||
self._dims: Tuple[int, int] = (0, 0)
|
||||
self._dims: tuple[int, int] = (0, 0)
|
||||
|
||||
def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor:
|
||||
""" Crop the incoming batch of images into patches as defined by :attr:`_patch_factor.
|
||||
|
@ -470,7 +472,7 @@ class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods
|
|||
retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
|
||||
return retval
|
||||
|
||||
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]:
|
||||
def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> list[tf.Tensor]:
|
||||
""" Obtain the Laplacian Pyramid.
|
||||
|
||||
Parameters
|
||||
|
@ -564,9 +566,9 @@ class LossWrapper(tf.keras.losses.Loss):
|
|||
def __init__(self) -> None:
|
||||
logger.debug("Initializing: %s", self.__class__.__name__)
|
||||
super().__init__(name="LossWrapper")
|
||||
self._loss_functions: List[compile_utils.LossesContainer] = []
|
||||
self._loss_weights: List[float] = []
|
||||
self._mask_channels: List[int] = []
|
||||
self._loss_functions: list[compile_utils.LossesContainer] = []
|
||||
self._loss_weights: list[float] = []
|
||||
self._mask_channels: list[int] = []
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
def add_loss(self,
|
||||
|
@ -628,7 +630,7 @@ class LossWrapper(tf.keras.losses.Loss):
|
|||
y_true: tf.Tensor,
|
||||
y_pred: tf.Tensor,
|
||||
mask_channel: int,
|
||||
mask_prop: float = 1.0) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
mask_prop: float = 1.0) -> tuple[tf.Tensor, tf.Tensor]:
|
||||
""" Apply the mask to the input y_true and y_pred. If a mask is not required then
|
||||
return the unmasked inputs.
|
||||
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
""" TF Keras implementation of Perceptual Loss Functions for faceswap.py """
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -14,11 +12,6 @@ from tensorflow.keras import backend as K # pylint:disable=import-error
|
|||
|
||||
from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -101,7 +94,7 @@ class DSSIMObjective(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
return K.depthwise_conv2d(image, kernel, strides=(1, 1), padding="valid")
|
||||
|
||||
def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
|
||||
""" Obtain the structural similarity between a batch of true and predicted images.
|
||||
|
||||
Parameters
|
||||
|
@ -330,8 +323,8 @@ class LDRFLIPLoss(): # pylint:disable=too-few-public-methods
|
|||
lower_threshold_exponent: float = 0.4,
|
||||
upper_threshold_exponent: float = 0.95,
|
||||
epsilon: float = 1e-15,
|
||||
pixels_per_degree: Optional[float] = None,
|
||||
color_order: Literal["bgr", "rgb"] = "bgr") -> None:
|
||||
pixels_per_degree: float | None = None,
|
||||
color_order: T.Literal["bgr", "rgb"] = "bgr") -> None:
|
||||
logger.debug("Initializing: %s (computed_distance_exponent '%s', feature_exponent: %s, "
|
||||
"lower_threshold_exponent: %s, upper_threshold_exponent: %s, epsilon: %s, "
|
||||
"pixels_per_degree: %s, color_order: %s)", self.__class__.__name__,
|
||||
|
@ -525,7 +518,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
|
|||
self._spatial_filters, self._radius = self._generate_spatial_filters()
|
||||
self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb")
|
||||
|
||||
def _generate_spatial_filters(self) -> Tuple[tf.Tensor, int]:
|
||||
def _generate_spatial_filters(self) -> tuple[tf.Tensor, int]:
|
||||
""" Generates spatial contrast sensitivity filters with width depending on the number of
|
||||
pixels per degree of visual angle of the observer for channels "A", "RG" and "BY"
|
||||
|
||||
|
@ -559,7 +552,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
|
|||
b1_rg: float,
|
||||
b2_rg: float,
|
||||
b1_by: float,
|
||||
b2_by: float) -> Tuple[np.ndarray, int]:
|
||||
b2_by: float) -> tuple[np.ndarray, int]:
|
||||
""" TODO docstring """
|
||||
max_scale_parameter = max([b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by])
|
||||
delta_x = 1.0 / self._pixels_per_degree
|
||||
|
@ -570,7 +563,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods
|
|||
return domain, radius
|
||||
|
||||
@classmethod
|
||||
def _generate_weights(cls, channel: Dict[str, float], domain: np.ndarray) -> tf.Tensor:
|
||||
def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> tf.Tensor:
|
||||
""" TODO docstring """
|
||||
a_1, b_1, a_2, b_2 = channel["a1"], channel["b1"], channel["a2"], channel["b2"]
|
||||
grad = (a_1 * np.sqrt(np.pi / b_1) * np.exp(-np.pi ** 2 * domain / b_1) +
|
||||
|
@ -694,7 +687,7 @@ class MSSIMLoss(): # pylint:disable=too-few-public-methods
|
|||
filter_size: int = 11,
|
||||
filter_sigma: float = 1.5,
|
||||
max_value: float = 1.0,
|
||||
power_factors: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
|
||||
power_factors: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
|
||||
) -> None:
|
||||
self.filter_size = filter_size
|
||||
self.filter_sigma = filter_sigma
|
||||
|
|
|
@ -31,7 +31,7 @@ class _net(): # pylint:disable=too-few-public-methods
|
|||
The input shape for the model. Default: ``None``
|
||||
"""
|
||||
def __init__(self,
|
||||
input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
|
||||
input_shape: tuple[int, int, int] | None = None) -> None:
|
||||
logger.debug("Initializing: %s (input_shape: %s)", self.__class__.__name__, input_shape)
|
||||
self._input_shape = (None, None, 3) if input_shape is None else input_shape
|
||||
assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, (
|
||||
|
@ -56,7 +56,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods
|
|||
input_shape, Tuple, optional
|
||||
The input shape for the model. Default: ``None``
|
||||
"""
|
||||
def __init__(self, input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None:
|
||||
def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None:
|
||||
super().__init__(input_shape)
|
||||
self._feature_indices = [0, 3, 6, 8, 10] # For naming equivalent to PyTorch
|
||||
self._filters = [64, 192, 384, 256, 256] # Filters at each block
|
||||
|
@ -108,7 +108,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods
|
|||
name=name)(var_x)
|
||||
return var_x
|
||||
|
||||
def __call__(self) -> Model:
|
||||
def __call__(self) -> tf.keras.models.Model:
|
||||
""" Create the AlexNet Model
|
||||
|
||||
Returns
|
||||
|
@ -189,7 +189,7 @@ class SqueezeNet(_net): # pylint:disable=too-few-public-methods
|
|||
name=f"{name}.expand3x3")(squeezed)
|
||||
return layers.Concatenate(axis=-1, name=name)([expand1, expand3])
|
||||
|
||||
def __call__(self) -> Model:
|
||||
def __call__(self) -> tf.keras.models.Model:
|
||||
""" Create the SqueezeNet Model
|
||||
|
||||
Returns
|
||||
|
|
|
@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|||
|
||||
|
||||
_CONFIG: dict = {}
|
||||
_NAMES: T.Dict[str, int] = {}
|
||||
_NAMES: dict[str, int] = {}
|
||||
|
||||
|
||||
def set_config(configuration: dict) -> None:
|
||||
|
@ -189,7 +189,7 @@ class Conv2DOutput(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int]],
|
||||
kernel_size: int | tuple[int],
|
||||
activation: str = "sigmoid",
|
||||
padding: str = "same", **kwargs) -> None:
|
||||
self._name = kwargs.pop("name") if "name" in kwargs else _get_name(
|
||||
|
@ -265,11 +265,11 @@ class Conv2DBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
|
||||
strides: T.Union[int, T.Tuple[int, int]] = 2,
|
||||
kernel_size: int | tuple[int, int] = 5,
|
||||
strides: int | tuple[int, int] = 2,
|
||||
padding: str = "same",
|
||||
normalization: T.Optional[str] = None,
|
||||
activation: T.Optional[str] = "leakyrelu",
|
||||
normalization: str | None = None,
|
||||
activation: str | None = "leakyrelu",
|
||||
use_depthwise: bool = False,
|
||||
relu_alpha: float = 0.1,
|
||||
**kwargs) -> None:
|
||||
|
@ -362,8 +362,8 @@ class SeparableConv2DBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 5,
|
||||
strides: T.Union[int, T.Tuple[int, int]] = 2, **kwargs) -> None:
|
||||
kernel_size: int | tuple[int, int] = 5,
|
||||
strides: int | tuple[int, int] = 2, **kwargs) -> None:
|
||||
self._name = _get_name(f"separableconv2d_{filters}")
|
||||
logger.debug("name: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s)",
|
||||
self._name, filters, kernel_size, strides, kwargs)
|
||||
|
@ -434,11 +434,11 @@ class UpscaleBlock(): # pylint:disable=too-few-public-methods
|
|||
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
|
||||
kernel_size: int | tuple[int, int] = 3,
|
||||
padding: str = "same",
|
||||
scale_factor: int = 2,
|
||||
normalization: T.Optional[str] = None,
|
||||
activation: T.Optional[str] = "leakyrelu",
|
||||
normalization: str | None = None,
|
||||
activation: str | None = "leakyrelu",
|
||||
**kwargs) -> None:
|
||||
self._name = _get_name(f"upscale_{filters}")
|
||||
logger.debug("name: %s. filters: %s, kernel_size: %s, padding: %s, scale_factor: %s, "
|
||||
|
@ -521,9 +521,9 @@ class Upscale2xBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
|
||||
kernel_size: int | tuple[int, int] = 3,
|
||||
padding: str = "same",
|
||||
activation: T.Optional[str] = "leakyrelu",
|
||||
activation: str | None = "leakyrelu",
|
||||
interpolation: str = "bilinear",
|
||||
sr_ratio: float = 0.5,
|
||||
scale_factor: int = 2,
|
||||
|
@ -615,9 +615,9 @@ class UpscaleResizeImagesBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
|
||||
kernel_size: int | tuple[int, int] = 3,
|
||||
padding: str = "same",
|
||||
activation: T.Optional[str] = "leakyrelu",
|
||||
activation: str | None = "leakyrelu",
|
||||
scale_factor: int = 2,
|
||||
interpolation: str = "bilinear") -> None:
|
||||
self._name = _get_name(f"upscale_ri_{filters}")
|
||||
|
@ -700,9 +700,9 @@ class UpscaleDNYBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
|
||||
kernel_size: int | tuple[int, int] = 3,
|
||||
padding: str = "same",
|
||||
activation: T.Optional[str] = "leakyrelu",
|
||||
activation: str | None = "leakyrelu",
|
||||
size: int = 2,
|
||||
interpolation: str = "bilinear",
|
||||
**kwargs) -> None:
|
||||
|
@ -757,7 +757,7 @@ class ResidualBlock(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
filters: int,
|
||||
kernel_size: T.Union[int, T.Tuple[int, int]] = 3,
|
||||
kernel_size: int | tuple[int, int] = 3,
|
||||
padding: str = "same",
|
||||
**kwargs) -> None:
|
||||
self._name = _get_name(f"residual_{filters}")
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
#!/usr/bin python3
|
||||
""" Settings manager for Keras Backend """
|
||||
|
||||
from __future__ import annotations
|
||||
from contextlib import nullcontext
|
||||
import logging
|
||||
from typing import Callable, ContextManager, List, Optional, Union
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -14,6 +14,9 @@ from tensorflow.keras.models import load_model as k_load_model, Model # noqa:E5
|
|||
|
||||
from lib.utils import get_backend
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint:disable=invalid-name
|
||||
|
||||
|
||||
|
@ -52,9 +55,9 @@ class KSession():
|
|||
def __init__(self,
|
||||
name: str,
|
||||
model_path: str,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
model_kwargs: dict | None = None,
|
||||
allow_growth: bool = False,
|
||||
exclude_gpus: Optional[List[int]] = None,
|
||||
exclude_gpus: list[int] | None = None,
|
||||
cpu_mode: bool = False) -> None:
|
||||
logger.trace("Initializing: %s (name: %s, model_path: %s, " # type:ignore
|
||||
"model_kwargs: %s, allow_growth: %s, exclude_gpus: %s, cpu_mode: %s)",
|
||||
|
@ -67,12 +70,12 @@ class KSession():
|
|||
cpu_mode)
|
||||
self._model_path = model_path
|
||||
self._model_kwargs = {} if not model_kwargs else model_kwargs
|
||||
self._model: Optional[Model] = None
|
||||
self._model: Model | None = None
|
||||
logger.trace("Initialized: %s", self.__class__.__name__,) # type:ignore
|
||||
|
||||
def predict(self,
|
||||
feed: Union[List[np.ndarray], np.ndarray],
|
||||
batch_size: Optional[int] = None) -> Union[List[np.ndarray], np.ndarray]:
|
||||
feed: list[np.ndarray] | np.ndarray,
|
||||
batch_size: int | None = None) -> list[np.ndarray] | np.ndarray:
|
||||
""" Get predictions from the model.
|
||||
|
||||
This method is a wrapper for :func:`keras.predict()` function. For Tensorflow backends
|
||||
|
@ -98,7 +101,7 @@ class KSession():
|
|||
def _set_session(self,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: list,
|
||||
cpu_mode: bool) -> ContextManager:
|
||||
cpu_mode: bool) -> T.ContextManager:
|
||||
""" Sets the backend session options.
|
||||
|
||||
For CPU backends, this hides any GPUs from Tensorflow.
|
||||
|
|
|
@@ -1,19 +1,22 @@
 #!/usr/bin/env python3
 """ Multithreading/processing utils for faceswap """

+from __future__ import annotations
 import logging
+import typing as T
 from multiprocessing import cpu_count

 import queue as Queue
 import sys
 import threading
 from types import TracebackType
-from typing import Any, Callable, Dict, Generator, List, Tuple, Type, Optional, Set, Union

+if T.TYPE_CHECKING:
+    from collections.abc import Callable, Generator

 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
-_ErrorType = Optional[Union[Tuple[Type[BaseException], BaseException, TracebackType],
-                            Tuple[Any, Any, Any]]]
-_THREAD_NAMES: Set[str] = set()
+_ErrorType: T.TypeAlias = (tuple[type[BaseException], BaseException, TracebackType] |
+                           tuple[T.Any, T.Any, T.Any] | None)
+_THREAD_NAMES: set[str] = set()


 def total_cpus():
@ -62,17 +65,17 @@ class FSThread(threading.Thread):
|
|||
keyword arguments for the target invocation. Default: {}.
|
||||
"""
|
||||
_target: Callable
|
||||
_args: Tuple
|
||||
_kwargs: Dict[str, Any]
|
||||
_args: tuple
|
||||
_kwargs: dict[str, T.Any]
|
||||
_name: str
|
||||
|
||||
def __init__(self,
|
||||
target: Optional[Callable] = None,
|
||||
name: Optional[str] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
target: Callable | None = None,
|
||||
name: str | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict[str, T.Any] | None = None,
|
||||
*,
|
||||
daemon: Optional[bool] = None) -> None:
|
||||
daemon: bool | None = None) -> None:
|
||||
super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
|
||||
self.err: _ErrorType = None
|
||||
|
||||
|
@ -124,7 +127,7 @@ class MultiThread():
|
|||
target: Callable,
|
||||
*args,
|
||||
thread_count: int = 1,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
**kwargs) -> None:
|
||||
self._name = _get_name(name if name else target.__name__)
|
||||
logger.debug("Initializing %s: (target: '%s', thread_count: %s)",
|
||||
|
@ -132,7 +135,7 @@ class MultiThread():
|
|||
logger.trace("args: %s, kwargs: %s", args, kwargs) # type:ignore
|
||||
self.daemon = True
|
||||
self._thread_count = thread_count
|
||||
self._threads: List[FSThread] = []
|
||||
self._threads: list[FSThread] = []
|
||||
self._target = target
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
@ -144,7 +147,7 @@ class MultiThread():
|
|||
return any(thread.err for thread in self._threads)
|
||||
|
||||
@property
|
||||
def errors(self) -> List[_ErrorType]:
|
||||
def errors(self) -> list[_ErrorType]:
|
||||
""" list: List of thread error values """
|
||||
return [thread.err for thread in self._threads if thread.err]
|
||||
|
||||
|
@ -253,9 +256,9 @@ class BackgroundGenerator(MultiThread):
|
|||
def __init__(self,
|
||||
generator: Callable,
|
||||
prefetch: int = 1,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[Tuple] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None) -> None:
|
||||
name: str | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, T.Any] | None = None) -> None:
|
||||
super().__init__(name=name, target=self._run)
|
||||
self.queue: Queue.Queue = Queue.Queue(prefetch)
|
||||
self.generator = generator
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict
|
||||
|
||||
from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa
|
||||
from time import sleep
|
||||
|
@ -45,7 +44,7 @@ class _QueueManager():
|
|||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
|
||||
self.shutdown = threading.Event()
|
||||
self.queues: Dict[str, EventQueue] = {}
|
||||
self.queues: dict[str, EventQueue] = {}
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str:
|
||||
|
|
|
@ -6,8 +6,8 @@ import locale
|
|||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from subprocess import PIPE, Popen
|
||||
from typing import List, Optional
|
||||
|
||||
import psutil
|
||||
|
||||
|
@ -21,14 +21,14 @@ class _SysInfo(): # pylint:disable=too-few-public-methods
|
|||
def __init__(self) -> None:
|
||||
self._state_file = _State().state_file
|
||||
self._configs = _Configs().configs
|
||||
self._system = dict(platform=platform.platform(),
|
||||
system=platform.system().lower(),
|
||||
machine=platform.machine(),
|
||||
release=platform.release(),
|
||||
processor=platform.processor(),
|
||||
cpu_count=os.cpu_count())
|
||||
self._python = dict(implementation=platform.python_implementation(),
|
||||
version=platform.python_version())
|
||||
self._system = {"platform": platform.platform(),
|
||||
"system": platform.system().lower(),
|
||||
"machine": platform.machine(),
|
||||
"release": platform.release(),
|
||||
"processor": platform.processor(),
|
||||
"cpu_count": os.cpu_count()}
|
||||
self._python = {"implementation": platform.python_implementation(),
|
||||
"version": platform.python_version()}
|
||||
self._gpu = self._get_gpu_info()
|
||||
self._cuda_check = CudaCheck()
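A recurring change in this commit is replacing the dict() keyword constructor with a literal mapping. A small self-contained example of the pattern, using stand-in values rather than the real _SysInfo fields:

import os
import platform

# dict() constructor: keys must be valid identifiers and the call goes through a name lookup.
system_old = dict(platform=platform.platform(),
                  cpu_count=os.cpu_count())

# Literal syntax: allows arbitrary string keys and is the style this commit standardises on.
system_new = {"platform": platform.platform(),
              "cpu_count": os.cpu_count()}

assert system_old == system_new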
|
||||
|
||||
|
@ -66,7 +66,7 @@ class _SysInfo(): # pylint:disable=too-few-public-methods
|
|||
(hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix))
|
||||
else:
|
||||
prefix = os.path.dirname(sys.prefix)
|
||||
retval = (os.path.basename(prefix) == "envs")
|
||||
retval = os.path.basename(prefix) == "envs"
|
||||
return retval
|
||||
|
||||
@property
|
||||
|
@ -295,7 +295,7 @@ class _Configs(): # pylint:disable=too-few-public-methods
|
|||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
def _parse_configs(self, config_files: List[str]) -> str:
|
||||
def _parse_configs(self, config_files: list[str]) -> str:
|
||||
""" Parse the given list of config files into a human readable format.
|
||||
|
||||
Parameters
|
||||
|
@ -399,7 +399,7 @@ class _State(): # pylint:disable=too-few-public-methods
|
|||
return len(sys.argv) > 1 and sys.argv[1].lower() == "train"
|
||||
|
||||
@staticmethod
|
||||
def _get_arg(*args: str) -> Optional[str]:
|
||||
def _get_arg(*args: str) -> str | None:
|
||||
""" Obtain the value for a given command line option from sys.argv.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Package for handling alignments files, detected faces and aligned faces along with their
|
||||
associated objects. """
|
||||
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
import typing as T
|
||||
|
||||
from .augmentation import ImageAugmentation
|
||||
from .generator import PreviewDataGenerator, TrainingDataGenerator
|
||||
from .preview_cv import PreviewBuffer, TriggerType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from .preview_cv import PreviewBase
|
||||
Preview: Type[PreviewBase]
|
||||
Preview: type[PreviewBase]
|
||||
|
||||
try:
|
||||
from .preview_tk import PreviewTk as Preview
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Processes the augmentation of images for feeding into a Faceswap model. """
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Dict, Tuple, TYPE_CHECKING
|
||||
import typing as T
|
||||
|
||||
import cv2
|
||||
import numexpr as ne
|
||||
|
@ -11,7 +12,7 @@ from scipy.interpolate import griddata
|
|||
|
||||
from lib.image import batch_convert_color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.config import ConfigValueType
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
@ -56,7 +57,7 @@ class AugConstants:
|
|||
transform_zoom: float
|
||||
transform_shift: float
|
||||
warp_maps: np.ndarray
|
||||
warp_pad: Tuple[int, int]
|
||||
warp_pad: tuple[int, int]
|
||||
warp_slices: slice
|
||||
warp_lm_edge_anchors: np.ndarray
|
||||
warp_lm_grids: np.ndarray
|
||||
|
@ -79,7 +80,7 @@ class ImageAugmentation():
|
|||
def __init__(self,
|
||||
batchsize: int,
|
||||
processing_size: int,
|
||||
config: Dict[str, "ConfigValueType"]) -> None:
|
||||
config: dict[str, ConfigValueType]) -> None:
|
||||
logger.debug("Initializing %s: (batchsize: %s, processing_size: %s, "
|
||||
"config: %s)",
|
||||
self.__class__.__name__, batchsize, processing_size, config)
|
||||
|
@ -332,7 +333,7 @@ class ImageAugmentation():
|
|||
slices = self._constants.warp_slices
|
||||
rands = np.random.normal(size=(self._batchsize, 2, 5, 5),
|
||||
scale=self._warp_scale).astype("float32")
|
||||
batch_maps = ne.evaluate("m + r", local_dict=dict(m=self._constants.warp_maps, r=rands))
|
||||
batch_maps = ne.evaluate("m + r", local_dict={"m": self._constants.warp_maps, "r": rands})
|
||||
batch_interp = np.array([[cv2.resize(map_, self._constants.warp_pad)[slices, slices]
|
||||
for map_ in maps]
|
||||
for maps in batch_maps])
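numexpr.evaluate compiles the string expression and pulls its operands from local_dict; the change above only alters how that mapping is built. A standalone sketch of the same call pattern, with illustrative shapes in place of the real cached warp maps:

import numexpr as ne
import numpy as np

# Random perturbations added to pre-computed warp maps, evaluated by numexpr
# without creating Python-level temporaries.
warp_maps = np.zeros((4, 2, 5, 5), dtype="float32")              # stand-in for the cached maps
rands = np.random.normal(size=(4, 2, 5, 5), scale=5.0).astype("float32")

batch_maps = ne.evaluate("m + r", local_dict={"m": warp_maps, "r": rands})
print(batch_maps.shape)  # (4, 2, 5, 5)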
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Holds the data cache for training data generators """
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from threading import Lock
|
||||
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -16,25 +16,19 @@ from lib.align.aligned_face import CenteringType
|
|||
from lib.image import read_image_batch, read_image_meta_batch
|
||||
from lib.utils import FaceswapError
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import get_args, Literal
|
||||
else:
|
||||
from typing import get_args, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict
|
||||
from lib.config import ConfigValueType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FACE_CACHES: Dict[str, "_Cache"] = {}
|
||||
_FACE_CACHES: dict[str, "_Cache"] = {}
|
||||
|
||||
|
||||
def get_cache(side: Literal["a", "b"],
|
||||
filenames: Optional[List[str]] = None,
|
||||
config: Optional[Dict[str, "ConfigValueType"]] = None,
|
||||
size: Optional[int] = None,
|
||||
coverage_ratio: Optional[float] = None) -> "_Cache":
|
||||
def get_cache(side: T.Literal["a", "b"],
|
||||
filenames: list[str] | None = None,
|
||||
config: dict[str, ConfigValueType] | None = None,
|
||||
size: int | None = None,
|
||||
coverage_ratio: float | None = None) -> "_Cache":
|
||||
""" Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then
|
||||
create it.
|
||||
|
||||
|
@ -120,24 +114,24 @@ class _Cache():
|
|||
The coverage ratio that the model is using.
|
||||
"""
|
||||
def __init__(self,
|
||||
filenames: List[str],
|
||||
config: Dict[str, "ConfigValueType"],
|
||||
filenames: list[str],
|
||||
config: dict[str, ConfigValueType],
|
||||
size: int,
|
||||
coverage_ratio: float) -> None:
|
||||
logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)",
|
||||
self.__class__.__name__, len(filenames), size, coverage_ratio)
|
||||
self._lock = Lock()
|
||||
self._cache_info = dict(cache_full=False, has_reset=False)
|
||||
self._partially_loaded: List[str] = []
|
||||
self._cache_info = {"cache_full": False, "has_reset": False}
|
||||
self._partially_loaded: list[str] = []
|
||||
|
||||
self._image_count = len(filenames)
|
||||
self._cache: Dict[str, DetectedFace] = {}
|
||||
self._aligned_landmarks: Dict[str, np.ndarray] = {}
|
||||
self._cache: dict[str, DetectedFace] = {}
|
||||
self._aligned_landmarks: dict[str, np.ndarray] = {}
|
||||
self._extract_version = 0.0
|
||||
self._size = size
|
||||
|
||||
assert config["centering"] in get_args(CenteringType)
|
||||
self._centering: CenteringType = cast(CenteringType, config["centering"])
|
||||
assert config["centering"] in T.get_args(CenteringType)
|
||||
self._centering: CenteringType = T.cast(CenteringType, config["centering"])
|
||||
self._config = config
|
||||
self._coverage_ratio = coverage_ratio
|
||||
|
||||
|
@ -153,7 +147,7 @@ class _Cache():
|
|||
return self._cache_info["cache_full"]
|
||||
|
||||
@property
|
||||
def aligned_landmarks(self) -> Dict[str, np.ndarray]:
|
||||
def aligned_landmarks(self) -> dict[str, np.ndarray]:
|
||||
""" dict: The filename as key, aligned landmarks as value. """
|
||||
# Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate
|
||||
# all of the aligned landmarks for the entire cache.
|
||||
|
@ -185,7 +179,7 @@ class _Cache():
|
|||
self._cache_info["has_reset"] = False
|
||||
return retval
|
||||
|
||||
def get_items(self, filenames: List[str]) -> List[DetectedFace]:
|
||||
def get_items(self, filenames: list[str]) -> list[DetectedFace]:
|
||||
""" Obtain the cached items for a list of filenames. The returned list is in the same order
|
||||
as the provided filenames.
|
||||
|
||||
|
@ -202,7 +196,7 @@ class _Cache():
|
|||
"""
|
||||
return [self._cache[os.path.basename(filename)] for filename in filenames]
|
||||
|
||||
def cache_metadata(self, filenames: List[str]) -> np.ndarray:
|
||||
def cache_metadata(self, filenames: list[str]) -> np.ndarray:
|
||||
""" Obtain the batch with metadata for items that need caching and cache DetectedFace
|
||||
objects to :attr:`_cache`.
|
||||
|
||||
|
@ -267,7 +261,7 @@ class _Cache():
|
|||
|
||||
return batch
|
||||
|
||||
def pre_fill(self, filenames: List[str], side: Literal["a", "b"]) -> None:
|
||||
def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None:
|
||||
""" When warp to landmarks is enabled, the cache must be pre-filled, as each side needs
|
||||
access to the other side's alignments.
|
||||
|
||||
|
@ -294,7 +288,7 @@ class _Cache():
|
|||
self._cache[key] = detected_face
|
||||
self._partially_loaded.append(key)
|
||||
|
||||
def _validate_version(self, png_meta: "PNGHeaderDict", filename: str) -> None:
|
||||
def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None:
|
||||
""" Validate that there are not a mix of v1.0 extracted faces and v2.x faces.
|
||||
|
||||
Parameters
|
||||
|
@ -350,7 +344,7 @@ class _Cache():
|
|||
|
||||
def _load_detected_face(self,
|
||||
filename: str,
|
||||
alignments: "PNGHeaderAlignmentsDict") -> DetectedFace:
|
||||
alignments: PNGHeaderAlignmentsDict) -> DetectedFace:
|
||||
""" Load a :class:`DetectedFace` object and load its associated `aligned` property.
|
||||
|
||||
Parameters
|
||||
|
@ -387,13 +381,13 @@ class _Cache():
|
|||
The detected face object that holds the masks
|
||||
"""
|
||||
masks = [(self._get_face_mask(filename, detected_face))]
|
||||
for area in get_args(Literal["eye", "mouth"]):
|
||||
for area in T.get_args(T.Literal["eye", "mouth"]):
|
||||
masks.append(self._get_localized_mask(filename, detected_face, area))
|
||||
|
||||
detected_face.store_training_masks(masks, delete_masks=True)
|
||||
logger.trace("Stored masks for filename: %s)", filename) # type: ignore
|
||||
|
||||
def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> Optional[np.ndarray]:
|
||||
def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None:
|
||||
""" Obtain the training sized face mask from the :class:`DetectedFace` for the requested
|
||||
mask type.
|
||||
|
||||
|
@ -448,7 +442,7 @@ class _Cache():
|
|||
def _get_localized_mask(self,
|
||||
filename: str,
|
||||
detected_face: DetectedFace,
|
||||
area: Literal["eye", "mouth"]) -> Optional[np.ndarray]:
|
||||
area: T.Literal["eye", "mouth"]) -> np.ndarray | None:
|
||||
""" Obtain a localized mask for the given area if it is required for training.
|
||||
|
||||
Parameters
|
||||
|
@ -486,7 +480,7 @@ class RingBuffer(): # pylint: disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
batch_size: int,
|
||||
image_shape: Tuple[int, int, int],
|
||||
image_shape: tuple[int, int, int],
|
||||
buffer_size: int = 2,
|
||||
dtype: str = "uint8") -> None:
|
||||
logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, "
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Handles Data Augmentation for feeding Faceswap Models """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent import futures
|
||||
import typing as T
|
||||
|
||||
from concurrent import futures
|
||||
from random import shuffle, choice
|
||||
from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -21,18 +20,14 @@ from lib.utils import FaceswapError
|
|||
from . import ImageAugmentation
|
||||
from .cache import get_cache, RingBuffer
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import get_args, Literal
|
||||
else:
|
||||
from typing import get_args, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from lib.config import ConfigValueType
|
||||
from plugins.train.model._base import ModelBase
|
||||
from .cache import _Cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
BatchType = Tuple[np.ndarray, List[np.ndarray]]
|
||||
BatchType = tuple[np.ndarray, list[np.ndarray]]
|
||||
|
||||
|
||||
class DataGenerator():
|
||||
|
@ -57,10 +52,10 @@ class DataGenerator():
|
|||
objects of this size from the iterator.
|
||||
"""
|
||||
def __init__(self,
|
||||
config: Dict[str, "ConfigValueType"],
|
||||
model: "ModelBase",
|
||||
side: Literal["a", "b"],
|
||||
images: List[str],
|
||||
config: dict[str, ConfigValueType],
|
||||
model: ModelBase,
|
||||
side: T.Literal["a", "b"],
|
||||
images: list[str],
|
||||
batch_size: int) -> None:
|
||||
logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore
|
||||
"batch_size: %s, config: %s)", self.__class__.__name__, model.name, side,
|
||||
|
@ -83,11 +78,11 @@ class DataGenerator():
|
|||
self._buffer = RingBuffer(batch_size,
|
||||
(self._process_size, self._process_size, self._total_channels),
|
||||
dtype="uint8")
|
||||
self._face_cache: "_Cache" = get_cache(side,
|
||||
filenames=images,
|
||||
config=self._config,
|
||||
size=self._process_size,
|
||||
coverage_ratio=self._coverage_ratio)
|
||||
self._face_cache: _Cache = get_cache(side,
|
||||
filenames=images,
|
||||
config=self._config,
|
||||
size=self._process_size,
|
||||
coverage_ratio=self._coverage_ratio)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
|
@ -100,12 +95,12 @@ class DataGenerator():
|
|||
channels += 1
|
||||
|
||||
mults = [area for area in ["eye", "mouth"]
|
||||
if cast(int, self._config[f"{area}_multiplier"]) > 1]
|
||||
if T.cast(int, self._config[f"{area}_multiplier"]) > 1]
|
||||
if self._config["penalized_mask_loss"] and mults:
|
||||
channels += len(mults)
|
||||
return channels
|
||||
|
||||
def _get_output_sizes(self, model: "ModelBase") -> List[int]:
|
||||
def _get_output_sizes(self, model: ModelBase) -> list[int]:
|
||||
""" Obtain the size of each output tensor for the model.
|
||||
|
||||
Parameters
|
||||
|
@ -222,7 +217,7 @@ class DataGenerator():
|
|||
retval = self._process_batch(img_paths)
|
||||
yield retval
|
||||
|
||||
def _get_images_with_meta(self, filenames: List[str]) -> Tuple[np.ndarray, List[DetectedFace]]:
|
||||
def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]:
|
||||
""" Obtain the raw face images with associated :class:`DetectedFace` objects for this
|
||||
batch.
|
||||
|
||||
|
@ -253,9 +248,9 @@ class DataGenerator():
|
|||
return raw_faces, detected_faces
|
||||
|
||||
def _crop_to_coverage(self,
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
images: np.ndarray,
|
||||
detected_faces: List[DetectedFace],
|
||||
detected_faces: list[DetectedFace],
|
||||
batch: np.ndarray) -> None:
|
||||
""" Crops the training image out of the full extract image based on the centering and
|
||||
coverage used in the user's configuration settings.
|
||||
|
@ -286,7 +281,7 @@ class DataGenerator():
|
|||
for future in futures.as_completed(proc):
|
||||
batch[proc[future], ..., :3] = future.result()
|
||||
|
||||
def _apply_mask(self, detected_faces: List[DetectedFace], batch: np.ndarray) -> None:
|
||||
def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None:
|
||||
""" Applies the masks to the 4th channel of the batch.
|
||||
|
||||
If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1
|
||||
|
@ -312,7 +307,7 @@ class DataGenerator():
|
|||
logger.trace("side: %s, masks: %s, batch: %s", # type: ignore
|
||||
self._side, masks.shape, batch.shape)
|
||||
|
||||
def _process_batch(self, filenames: List[str]) -> BatchType:
|
||||
def _process_batch(self, filenames: list[str]) -> BatchType:
|
||||
""" Prepares data for feeding through subclassed methods.
|
||||
|
||||
If this is the first time a face has been loaded, then its metadata is extracted from the
|
||||
|
@ -345,9 +340,9 @@ class DataGenerator():
|
|||
return feed, targets
|
||||
|
||||
def process_batch(self,
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
images: np.ndarray,
|
||||
detected_faces: List[DetectedFace],
|
||||
detected_faces: list[DetectedFace],
|
||||
batch: np.ndarray) -> BatchType:
|
||||
""" Override for processing the batch for the current generator.
|
||||
|
||||
|
@ -391,7 +386,7 @@ class DataGenerator():
|
|||
The input uint8 array
|
||||
"""
|
||||
return ne.evaluate("x / c",
|
||||
local_dict=dict(x=in_array, c=np.float32(255)),
|
||||
local_dict={"x": in_array, "c": np.float32(255)},
|
||||
casting="unsafe")
|
||||
|
||||
|
||||
|
@ -417,10 +412,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
objects of this size from the iterator.
|
||||
"""
|
||||
def __init__(self,
|
||||
config: Dict[str, "ConfigValueType"],
|
||||
model: "ModelBase",
|
||||
side: Literal["a", "b"],
|
||||
images: List[str],
|
||||
config: dict[str, ConfigValueType],
|
||||
model: ModelBase,
|
||||
side: T.Literal["a", "b"],
|
||||
images: list[str],
|
||||
batch_size: int) -> None:
|
||||
super().__init__(config, model, side, images, batch_size)
|
||||
self._augment_color = not model.command_line_arguments.no_augment_color
|
||||
|
@ -434,10 +429,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
self._processing = ImageAugmentation(batch_size,
|
||||
self._process_size,
|
||||
self._config)
|
||||
self._nearest_landmarks: Dict[str, Tuple[str, ...]] = {}
|
||||
self._nearest_landmarks: dict[str, tuple[str, ...]] = {}
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def _create_targets(self, batch: np.ndarray) -> List[np.ndarray]:
|
||||
def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]:
|
||||
""" Compile target images, with masks, for the model output sizes.
|
||||
|
||||
Parameters
|
||||
|
@ -467,9 +462,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
return retval
|
||||
|
||||
def process_batch(self,
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
images: np.ndarray,
|
||||
detected_faces: List[DetectedFace],
|
||||
detected_faces: list[DetectedFace],
|
||||
batch: np.ndarray) -> BatchType:
|
||||
""" Performs the augmentation and compiles target images and samples.
|
||||
|
||||
|
@ -525,7 +520,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
if self._warp_to_landmarks:
|
||||
landmarks = np.array([face.aligned.landmarks for face in detected_faces])
|
||||
batch_dst_pts = self._get_closest_match(filenames, landmarks)
|
||||
warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts)
|
||||
warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts}
|
||||
else:
|
||||
warp_kwargs = {}
|
||||
|
||||
|
@ -545,7 +540,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
|
||||
return feed, targets
|
||||
|
||||
def _get_closest_match(self, filenames: List[str], batch_src_points: np.ndarray) -> np.ndarray:
|
||||
def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray:
|
||||
""" Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest
|
||||
matched 68 point landmarks from the opposite training set.
|
||||
|
||||
|
@ -563,7 +558,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
"""
|
||||
logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore
|
||||
"src_points: '%s')", filenames, batch_src_points)
|
||||
lm_side: Literal["a", "b"] = "a" if self._side == "b" else "b"
|
||||
lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b"
|
||||
other_cache = get_cache(lm_side)
|
||||
landmarks = other_cache.aligned_landmarks
|
||||
|
||||
|
@ -584,9 +579,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
return batch_dst_points
|
||||
|
||||
def _cache_closest_matches(self,
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
batch_src_points: np.ndarray,
|
||||
landmarks: Dict[str, np.ndarray]) -> List[Tuple[str, ...]]:
|
||||
landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]:
|
||||
""" Cache the nearest landmarks for this batch
|
||||
|
||||
Parameters
|
||||
|
@ -602,7 +597,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met
|
|||
logger.trace("Caching closest matches") # type:ignore
|
||||
dst_landmarks = list(landmarks.items())
|
||||
dst_points = np.array([lm[1] for lm in dst_landmarks])
|
||||
batch_closest_matches: List[Tuple[str, ...]] = []
|
||||
batch_closest_matches: list[tuple[str, ...]] = []
|
||||
|
||||
for filename, src_points in zip(filenames, batch_src_points):
|
||||
closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10]
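The argsort()[:10] line above selects the ten faces from the opposite side whose 68-point landmarks are closest, by mean squared distance, to the current face. A standalone sketch of that selection using random landmarks as stand-ins:

import numpy as np

rng = np.random.default_rng(0)
src_points = rng.random((68, 2)).astype("float32")        # landmarks for one face on this side
dst_points = rng.random((500, 68, 2)).astype("float32")   # landmarks for every face on the other side

# Mean squared distance between the source landmarks and each candidate,
# then the indices of the 10 closest candidates.
distances = np.mean(np.square(src_points - dst_points), axis=(1, 2))
closest = distances.argsort()[:10]
print(closest)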
|
||||
|
@ -637,7 +632,7 @@ class PreviewDataGenerator(DataGenerator):
|
|||
"""
|
||||
def _create_samples(self,
|
||||
images: np.ndarray,
|
||||
detected_faces: List[DetectedFace]) -> List[np.ndarray]:
|
||||
detected_faces: list[DetectedFace]) -> list[np.ndarray]:
|
||||
""" Compile the 'sample' images. These are the 100% coverage images which hold the model
|
||||
output in the preview window.
|
||||
|
||||
|
@ -658,24 +653,25 @@ class PreviewDataGenerator(DataGenerator):
|
|||
output_size = self._output_sizes[-1]
|
||||
full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2))
|
||||
|
||||
assert self._config["centering"] in get_args(CenteringType)
|
||||
assert self._config["centering"] in T.get_args(CenteringType)
|
||||
retval = np.empty((full_size, full_size, 3), dtype="float32")
|
||||
retval = self._to_float32(np.array([AlignedFace(face.landmarks_xy,
|
||||
image=images[idx],
|
||||
centering=cast(CenteringType,
|
||||
self._config["centering"]),
|
||||
size=full_size,
|
||||
dtype="uint8",
|
||||
is_aligned=True).face
|
||||
for idx, face in enumerate(detected_faces)]))
|
||||
retval = self._to_float32(np.array([
|
||||
AlignedFace(face.landmarks_xy,
|
||||
image=images[idx],
|
||||
centering=T.cast(CenteringType,
|
||||
self._config["centering"]),
|
||||
size=full_size,
|
||||
dtype="uint8",
|
||||
is_aligned=True).face
|
||||
for idx, face in enumerate(detected_faces)]))
|
||||
|
||||
logger.trace("Processed samples: %s", retval.shape) # type: ignore
|
||||
return [retval]
|
||||
|
||||
def process_batch(self,
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
images: np.ndarray,
|
||||
detected_faces: List[DetectedFace],
|
||||
detected_faces: list[DetectedFace],
|
||||
batch: np.ndarray) -> BatchType:
|
||||
""" Creates the full size preview images and the sub-cropped images for feeding the model's
|
||||
predict function.
|
||||
|
|
|
@ -4,36 +4,30 @@
|
|||
If Tkinter is installed, then this will be used to manage the preview image, otherwise we
|
||||
fallback to opencv's imshow
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from threading import Event, Lock
|
||||
from time import sleep
|
||||
|
||||
from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
TriggerType = Dict[Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]
|
||||
TriggerKeysType = Literal["m", "r", "s", "enter"]
|
||||
TriggerNamesType = Literal["toggle_mask", "refresh", "save", "quit"]
|
||||
TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event]
|
||||
TriggerKeysType = T.Literal["m", "r", "s", "enter"]
|
||||
TriggerNamesType = T.Literal["toggle_mask", "refresh", "save", "quit"]
|
||||
|
||||
|
||||
class PreviewBuffer():
|
||||
""" A thread safe class for holding preview images """
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing: %s", self.__class__.__name__)
|
||||
self._images: Dict[str, "np.ndarray"] = {}
|
||||
self._images: dict[str, np.ndarray] = {}
|
||||
self._lock = Lock()
|
||||
self._updated = Event()
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
@ -43,7 +37,7 @@ class PreviewBuffer():
|
|||
""" bool: ``True`` when new images have been loaded into the preview buffer """
|
||||
return self._updated.is_set()
|
||||
|
||||
def add_image(self, name: str, image: "np.ndarray") -> None:
|
||||
def add_image(self, name: str, image: np.ndarray) -> None:
|
||||
""" Add an image to the preview buffer in a thread safe way """
|
||||
logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape)
|
||||
with self._lock:
|
||||
|
@ -51,7 +45,7 @@ class PreviewBuffer():
|
|||
logger.debug("Added images: %s", list(self._images))
|
||||
self._updated.set()
|
||||
|
||||
def get_images(self) -> Generator[Tuple[str, "np.ndarray"], None, None]:
|
||||
def get_images(self) -> Generator[tuple[str, np.ndarray], None, None]:
|
||||
""" Get the latest images from the preview buffer. When iterator is exhausted clears the
|
||||
:attr:`updated` event.
|
||||
|
||||
|
@ -86,15 +80,15 @@ class PreviewBase(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
def __init__(self,
|
||||
preview_buffer: PreviewBuffer,
|
||||
triggers: Optional[TriggerType] = None) -> None:
|
||||
triggers: TriggerType | None = None) -> None:
|
||||
logger.debug("Initializing %s parent (triggers: %s)",
|
||||
self.__class__.__name__, triggers)
|
||||
self._triggers = triggers
|
||||
self._buffer = preview_buffer
|
||||
self._keymaps: Dict[TriggerKeysType, TriggerNamesType] = dict(m="toggle_mask",
|
||||
r="refresh",
|
||||
s="save",
|
||||
enter="quit")
|
||||
self._keymaps: dict[TriggerKeysType, TriggerNamesType] = {"m": "toggle_mask",
|
||||
"r": "refresh",
|
||||
"s": "save",
|
||||
"enter": "quit"}
|
||||
self._title = ""
|
||||
logger.debug("Initialized %s parent", self.__class__.__name__)
|
||||
|
||||
|
@ -141,7 +135,7 @@ class PreviewCV(PreviewBase): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Unable to import Tkinter. Falling back to OpenCV")
|
||||
super().__init__(preview_buffer, triggers=triggers)
|
||||
self._triggers: TriggerType = self._triggers
|
||||
self._windows: List[str] = []
|
||||
self._windows: list[str] = []
|
||||
|
||||
self._lookup = {ord(key): val
|
||||
for key, val in self._keymaps.items() if key != "enter"}
|
||||
|
|
|
@ -4,24 +4,25 @@
|
|||
If Tkinter is installed, then this will be used to manage the preview image, otherwise we
|
||||
fallback to opencv's imshow
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
|
||||
from datetime import datetime
|
||||
from platform import system
|
||||
from tkinter import ttk
|
||||
from math import ceil, floor
|
||||
|
||||
from typing import cast, List, Optional, Tuple, TYPE_CHECKING
|
||||
from PIL import Image, ImageTk
|
||||
|
||||
import cv2
|
||||
|
||||
from .preview_cv import PreviewBase, TriggerKeysType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from .preview_cv import PreviewBuffer, TriggerType
|
||||
|
||||
|
@ -38,18 +39,18 @@ class _Taskbar():
|
|||
taskbar: :class:`tkinter.ttk.Frame` or ``None``
|
||||
None if preview is a pop-up window otherwise ttk.Frame if taskbar is managed by the GUI
|
||||
"""
|
||||
def __init__(self, parent: tk.Frame, taskbar: Optional[ttk.Frame]) -> None:
|
||||
def __init__(self, parent: tk.Frame, taskbar: ttk.Frame | None) -> None:
|
||||
logger.debug("Initializing %s (parent: '%s', taskbar: %s)",
|
||||
self.__class__.__name__, parent, taskbar)
|
||||
self._is_standalone = taskbar is None
|
||||
self._gui_mapped: List[tk.Widget] = []
|
||||
self._gui_mapped: list[tk.Widget] = []
|
||||
self._frame = tk.Frame(parent) if taskbar is None else taskbar
|
||||
|
||||
self._min_max_scales = (20, 400)
|
||||
self._vars = dict(save=tk.BooleanVar(),
|
||||
scale=tk.StringVar(),
|
||||
slider=tk.IntVar(),
|
||||
interpolator=tk.IntVar())
|
||||
self._vars = {"save": tk.BooleanVar(),
|
||||
"scale": tk.StringVar(),
|
||||
"slider": tk.IntVar(),
|
||||
"interpolator": tk.IntVar()}
|
||||
self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST),
|
||||
("bicubic", cv2.INTER_CUBIC)]
|
||||
self._scale = self._add_scale_combo()
|
||||
|
@ -261,7 +262,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
|
|||
def __init__(self,
|
||||
parent: tk.Frame,
|
||||
scale_var: tk.StringVar,
|
||||
screen_dimensions: Tuple[int, int],
|
||||
screen_dimensions: tuple[int, int],
|
||||
is_standalone: bool) -> None:
|
||||
logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)",
|
||||
self.__class__.__name__, parent, scale_var, screen_dimensions)
|
||||
|
@ -272,7 +273,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors
|
|||
self._screen_dimensions = screen_dimensions
|
||||
self._var_scale = scale_var
|
||||
self._configure_scrollbars(frame)
|
||||
self._image: Optional[ImageTk.PhotoImage] = None
|
||||
self._image: ImageTk.PhotoImage | None = None
|
||||
self._image_id = self.create_image(self.width / 2,
|
||||
self.height / 2,
|
||||
anchor=tk.CENTER,
|
||||
|
@ -400,8 +401,8 @@ class _Image():
|
|||
logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)",
|
||||
self.__class__.__name__, save_variable, is_standalone)
|
||||
self._is_standalone = is_standalone
|
||||
self._source: Optional["np.ndarray"] = None
|
||||
self._display: Optional[ImageTk.PhotoImage] = None
|
||||
self._source: np.ndarray | None = None
|
||||
self._display: ImageTk.PhotoImage | None = None
|
||||
self._scale = 1.0
|
||||
self._interpolation = cv2.INTER_NEAREST
|
||||
|
||||
|
@ -416,7 +417,7 @@ class _Image():
|
|||
return self._display
|
||||
|
||||
@property
|
||||
def source(self) -> "np.ndarray":
|
||||
def source(self) -> np.ndarray:
|
||||
""" :class:`PIL.Image.Image`: The current source preview image """
|
||||
assert self._source is not None
|
||||
return self._source
|
||||
|
@ -426,7 +427,7 @@ class _Image():
|
|||
"""int: The current display scale as a percentage of original image size """
|
||||
return int(self._scale * 100)
|
||||
|
||||
def set_source_image(self, name: str, image: "np.ndarray") -> None:
|
||||
def set_source_image(self, name: str, image: np.ndarray) -> None:
|
||||
""" Set the source image to :attr:`source`
|
||||
|
||||
Parameters
|
||||
|
@ -542,7 +543,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
|
|||
self._taskbar = taskbar
|
||||
self._image = image
|
||||
|
||||
self._drag_data: List[float] = [0., 0.]
|
||||
self._drag_data: list[float] = [0., 0.]
|
||||
self._set_mouse_bindings()
|
||||
self._set_key_bindings(is_standalone)
|
||||
logger.debug("Initialized %s", self.__class__.__name__,)
|
||||
|
@ -604,7 +605,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods
|
|||
The key press event
|
||||
"""
|
||||
move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview
|
||||
visible = (move_axis()[1] - move_axis()[0])
|
||||
visible = move_axis()[1] - move_axis()[0]
|
||||
amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25
|
||||
logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, " # type: ignore
|
||||
"amount: %s)", move_axis, visible, amount)
|
||||
|
@ -671,10 +672,10 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
|
|||
Default: `None`
|
||||
"""
|
||||
def __init__(self,
|
||||
preview_buffer: "PreviewBuffer",
|
||||
parent: Optional[tk.Widget] = None,
|
||||
taskbar: Optional[ttk.Frame] = None,
|
||||
triggers: Optional["TriggerType"] = None) -> None:
|
||||
preview_buffer: PreviewBuffer,
|
||||
parent: tk.Widget | None = None,
|
||||
taskbar: ttk.Frame | None = None,
|
||||
triggers: TriggerType | None = None) -> None:
|
||||
logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent)
|
||||
super().__init__(preview_buffer, triggers=triggers)
|
||||
self._is_standalone = parent is None
|
||||
|
@ -745,7 +746,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
|
|||
logger.info(" Save Preview: Ctrl+s")
|
||||
logger.info("---------------------------------------------------")
|
||||
|
||||
def _get_geometry(self) -> Tuple[int, int]:
|
||||
def _get_geometry(self) -> tuple[int, int]:
|
||||
""" Obtain the geometry of the current screen (standalone) or the dimensions of the widget
|
||||
holding the preview window (GUI).
|
||||
|
||||
|
@ -780,7 +781,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
|
|||
half_screen = tuple(x // 2 for x in self._screen_dimensions)
|
||||
min_scales = (half_screen[0] / self._image.source.shape[1],
|
||||
half_screen[1] / self._image.source.shape[0])
|
||||
min_scale = min(1.0, min(min_scales))
|
||||
min_scale = min(1.0, *min_scales)
|
||||
min_scale = (ceil(min_scale * 10)) * 10
|
||||
|
||||
eight_screen = tuple(x * 8 for x in self._screen_dimensions)
|
||||
|
@ -884,7 +885,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods
|
|||
if self._triggers is None: # Don't need triggers for GUI
|
||||
return
|
||||
keypress = "enter" if event.keysym == "Return" else event.keysym
|
||||
key = cast(TriggerKeysType, keypress)
|
||||
key = T.cast(TriggerKeysType, keypress)
|
||||
logger.debug("Processing keypress '%s'", key)
|
||||
if key == "r":
|
||||
print("") # Let log print on different line from loss output
|
||||
|
|
53
lib/utils.py
|
@ -1,11 +1,12 @@
|
|||
#!/usr/bin python3
|
||||
""" Utilities available across all scripts """
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tkinter as tk
|
||||
import typing as T
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
|
@ -14,18 +15,12 @@ from re import finditer
|
|||
from socket import timeout as socket_timeout, error as socket_error
|
||||
from threading import get_ident
|
||||
from time import time
|
||||
from typing import cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
|
||||
from urllib import request, error as urlliberror
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import get_args, Literal
|
||||
else:
|
||||
from typing import get_args, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from http.client import HTTPResponse
|
||||
|
||||
# Global variables
|
||||
|
@ -34,8 +29,8 @@ _image_extensions = [ # pylint:disable=invalid-name
|
|||
_video_extensions = [ # pylint:disable=invalid-name
|
||||
".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv",
|
||||
".ts", ".vob"]
|
||||
_TF_VERS: Optional[Tuple[int, int]] = None
|
||||
ValidBackends = Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
|
||||
_TF_VERS: tuple[int, int] | None = None
|
||||
ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]
|
||||
|
||||
|
||||
class _Backend(): # pylint:disable=too-few-public-methods
|
||||
|
@ -44,7 +39,7 @@ class _Backend(): # pylint:disable=too-few-public-methods
|
|||
|
||||
If file doesn't exist and a variable hasn't been set, create the config file. """
|
||||
def __init__(self) -> None:
|
||||
self._backends: Dict[str, ValidBackends] = {"1": "cpu",
|
||||
self._backends: dict[str, ValidBackends] = {"1": "cpu",
|
||||
"2": "directml",
|
||||
"3": "nvidia",
|
||||
"4": "apple_silicon",
|
||||
|
@ -78,9 +73,9 @@ class _Backend(): # pylint:disable=too-few-public-methods
|
|||
"""
|
||||
# Check if environment variable is set, if so use that
|
||||
if "FACESWAP_BACKEND" in os.environ:
|
||||
fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
|
||||
assert fs_backend in get_args(ValidBackends), (
|
||||
f"Faceswap backend must be one of {get_args(ValidBackends)}")
|
||||
fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower())
|
||||
assert fs_backend in T.get_args(ValidBackends), (
|
||||
f"Faceswap backend must be one of {T.get_args(ValidBackends)}")
|
||||
print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}")
|
||||
return fs_backend
|
||||
# Intercept for sphinx docs build
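typing.get_args turns the Literal alias into a plain tuple of strings, which is what makes the assert above possible. A minimal reproduction of that validation; the FACESWAP_BACKEND variable name comes from the diff, the helper itself is a sketch:

import os
import typing as T

ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"]


def backend_from_env(default: str = "cpu") -> str:
    """ Read FACESWAP_BACKEND and check it against the allowed Literal values. """
    backend = os.environ.get("FACESWAP_BACKEND", default).lower()
    allowed = T.get_args(ValidBackends)  # ('nvidia', 'cpu', 'apple_silicon', 'directml', 'rocm')
    assert backend in allowed, f"Faceswap backend must be one of {allowed}"
    return backend


print(backend_from_env())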
|
||||
|
@ -163,11 +158,11 @@ def set_backend(backend: str) -> None:
|
|||
>>> set_backend("nvidia")
|
||||
"""
|
||||
global _FS_BACKEND # pylint:disable=global-statement
|
||||
backend = cast(ValidBackends, backend.lower())
|
||||
backend = T.cast(ValidBackends, backend.lower())
|
||||
_FS_BACKEND = backend
|
||||
|
||||
|
||||
def get_tf_version() -> Tuple[int, int]:
|
||||
def get_tf_version() -> tuple[int, int]:
|
||||
""" Obtain the major. minor version of currently installed Tensorflow.
|
||||
|
||||
Returns
|
||||
|
@ -179,7 +174,7 @@ def get_tf_version() -> Tuple[int, int]:
|
|||
-------
|
||||
>>> from lib.utils import get_tf_version
|
||||
>>> get_tf_version()
|
||||
(2, 9)
|
||||
(2, 10)
|
||||
"""
|
||||
global _TF_VERS # pylint:disable=global-statement
|
||||
if _TF_VERS is None:
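The docstring example moves from (2, 9) to (2, 10) to match the dependency bump in this commit. The function body is not shown in the hunk; one plausible way to derive such a (major, minor) tuple, offered as an assumption rather than the faceswap implementation, is:

# Requires tensorflow to be installed; the parsing itself is generic.
import tensorflow as tf


def tf_major_minor() -> tuple[int, int]:
    """ Return the installed Tensorflow version as a (major, minor) tuple. """
    major, minor = (int(part) for part in tf.__version__.split(".")[:2])
    return major, minor


print(tf_major_minor())  # e.g. (2, 10)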
|
||||
|
@ -225,7 +220,7 @@ def get_folder(path: str, make_folder: bool = True) -> str:
|
|||
return path
|
||||
|
||||
|
||||
def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]:
|
||||
def get_image_paths(directory: str, extension: str | None = None) -> list[str]:
|
||||
""" Gets the image paths from a given directory.
|
||||
|
||||
The function searches for files with the specified extension(s) in the given directory, and
|
||||
|
@ -274,7 +269,7 @@ def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str
|
|||
return dir_contents
|
||||
|
||||
|
||||
def get_dpi() -> Optional[float]:
|
||||
def get_dpi() -> float | None:
|
||||
""" Gets the DPI (dots per inch) of the display screen.
|
||||
|
||||
Returns
|
||||
|
@ -338,7 +333,7 @@ def convert_to_secs(*args: int) -> int:
|
|||
return retval
|
||||
|
||||
|
||||
def full_path_split(path: str) -> List[str]:
|
||||
def full_path_split(path: str) -> list[str]:
|
||||
""" Split a file path into all of its parts.
|
||||
|
||||
Parameters
|
||||
|
@ -360,7 +355,7 @@ def full_path_split(path: str) -> List[str]:
|
|||
['relative', 'path', 'to', 'file.txt']]
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
allparts: List[str] = []
|
||||
allparts: list[str] = []
|
||||
while True:
|
||||
parts = os.path.split(path)
|
||||
if parts[0] == path: # sentinel for absolute paths
|
||||
|
@ -410,7 +405,7 @@ def set_system_verbosity(log_level: str):
|
|||
warnings.simplefilter(action='ignore', category=warncat)
|
||||
|
||||
|
||||
def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None:
|
||||
def deprecation_warning(function: str, additional_info: str | None = None) -> None:
|
||||
""" Log a deprecation warning message.
|
||||
|
||||
This function logs a warning message to indicate that the specified function has been
|
||||
|
@ -436,7 +431,7 @@ def deprecation_warning(function: str, additional_info: Optional[str] = None) ->
|
|||
logger.warning(msg)
|
||||
|
||||
|
||||
def camel_case_split(identifier: str) -> List[str]:
|
||||
def camel_case_split(identifier: str) -> list[str]:
|
||||
""" Split a camelCase string into a list of its individual parts
|
||||
|
||||
Parameters
|
||||
|
@ -541,7 +536,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
|||
>>> model_downloader = GetModel("s3fd_keras_v2.h5", 11)
|
||||
"""
|
||||
|
||||
def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None:
|
||||
def __init__(self, model_filename: str | list[str], git_model_id: int) -> None:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
if not isinstance(model_filename, list):
|
||||
model_filename = [model_filename]
|
||||
|
@ -576,7 +571,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
|||
return retval
|
||||
|
||||
@property
|
||||
def model_path(self) -> Union[str, List[str]]:
|
||||
def model_path(self) -> str | list[str]:
|
||||
""" str or list[str]: The model path(s) in the cache folder.
|
||||
|
||||
Example
|
||||
|
@ -587,7 +582,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
|||
'/path/to/s3fd_keras_v2.h5'
|
||||
"""
|
||||
paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename]
|
||||
retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths
|
||||
retval: str | list[str] = paths[0] if len(paths) == 1 else paths
|
||||
self.logger.trace(retval) # type:ignore[attr-defined]
|
||||
return retval
|
||||
|
||||
|
@ -662,7 +657,7 @@ class GetModel(): # pylint:disable=too-few-public-methods
|
|||
self._url_download, self._cache_dir)
|
||||
sys.exit(1)
|
||||
|
||||
def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None:
|
||||
def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None:
|
||||
""" Write the model zip file to disk.
|
||||
|
||||
Parameters
|
||||
|
@ -762,8 +757,8 @@ class DebugTimes():
|
|||
"""
|
||||
def __init__(self,
|
||||
show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None:
|
||||
self._times: Dict[str, List[float]] = {}
|
||||
self._steps: Dict[str, float] = {}
|
||||
self._times: dict[str, list[float]] = {}
|
||||
self._steps: dict[str, float] = {}
|
||||
self._interval = 1
|
||||
self._display = {"min": show_min, "mean": show_mean, "max": show_max}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ msgid ""
|
|||
msgstr ""
|
||||
"Project-Id-Version: PACKAGE VERSION\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-06-11 23:20+0100\n"
|
||||
"POT-Creation-Date: 2023-06-25 13:39+0100\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
|
@ -227,7 +227,8 @@ msgid ""
|
|||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:198 plugins/train/_config.py:223
|
||||
#: plugins/train/_config.py:238 plugins/train/_config.py:265
|
||||
#: plugins/train/_config.py:238 plugins/train/_config.py:256
|
||||
#: plugins/train/_config.py:290
|
||||
msgid "optimizer"
|
||||
msgstr ""
|
||||
|
||||
|
@ -274,21 +275,40 @@ msgid ""
|
|||
"epsilon to 0.001 (1e-3)."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:258
|
||||
#: plugins/train/_config.py:262
|
||||
msgid ""
|
||||
"Apply AutoClipping to the gradients. AutoClip analyzes the "
|
||||
"gradient weights and adjusts the normalization value dynamically to fit the "
|
||||
"data. Can help prevent NaNs and improve model optimization at the expense of "
|
||||
"VRAM. Ref: AutoClip: Adaptive Gradient Clipping for Source Separation "
|
||||
"Networks https://arxiv.org/abs/2007.14469"
|
||||
"When to save the Optimizer Weights. Saving the optimizer weights is not "
|
||||
"necessary and will increase the model file size 3x (and by extension the "
|
||||
"amount of time it takes to save the model). However, it can be useful to "
|
||||
"save these weights if you want to guarantee that a resumed model carries off "
|
||||
"exactly from where it left off, rather than spending a few hundred "
|
||||
"iterations catching up.\n"
|
||||
"\t never - Don't save optimizer weights.\n"
|
||||
"\t always - Save the optimizer weights at every save iteration. Model saving "
|
||||
"will take longer, due to the increased file size, but you will always have "
|
||||
"the last saved optimizer state in your model file.\n"
|
||||
"\t exit - Only save the optimizer weights when explicitly terminating a "
|
||||
"model. This can be when the model is actively stopped or when the target "
|
||||
"iterations are met. Note: If the training session ends because of another "
|
||||
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
|
||||
"optimizer weights will NOT be saved."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:271 plugins/train/_config.py:283
|
||||
#: plugins/train/_config.py:297 plugins/train/_config.py:314
|
||||
#: plugins/train/_config.py:283
|
||||
msgid ""
|
||||
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
|
||||
"and adjusts the normalization value dynamically to fit the data. Can help "
|
||||
"prevent NaNs and improve model optimization at the expense of VRAM. Ref: "
|
||||
"AutoClip: Adaptive Gradient Clipping for Source Separation Networks https://"
|
||||
"arxiv.org/abs/2007.14469"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:296 plugins/train/_config.py:308
|
||||
#: plugins/train/_config.py:322 plugins/train/_config.py:339
|
||||
msgid "network"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:273
|
||||
#: plugins/train/_config.py:298
|
||||
msgid ""
|
||||
"Use reflection padding rather than zero padding with convolutions. Each "
|
||||
"convolution must pad the image boundaries to maintain the proper sizing. "
|
||||
|
@ -297,21 +317,21 @@ msgid ""
|
|||
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:286
|
||||
#: plugins/train/_config.py:311
|
||||
msgid ""
|
||||
"Enable the Tensorflow GPU 'allow_growth' configuration "
|
||||
"option. This option prevents Tensorflow from allocating all of the GPU VRAM "
|
||||
"at launch but can lead to higher VRAM fragmentation and slower performance. "
|
||||
"Should only be enabled if you are receiving errors regarding 'cuDNN fails to "
|
||||
"initialize' when commencing training."
|
||||
"Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
|
||||
"prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
|
||||
"lead to higher VRAM fragmentation and slower performance. Should only be "
|
||||
"enabled if you are receiving errors regarding 'cuDNN fails to initialize' "
|
||||
"when commencing training."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:299
|
||||
#: plugins/train/_config.py:324
|
||||
msgid ""
|
||||
"NVIDIA GPUs can run operations in float16 faster than in "
|
||||
"float32. Mixed precision allows you to use a mix of float16 with float32, to "
|
||||
"get the performance benefits from float16 and the numeric stability benefits "
|
||||
"from float32.\n"
|
||||
"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
|
||||
"precision allows you to use a mix of float16 with float32, to get the "
|
||||
"performance benefits from float16 and the numeric stability benefits from "
|
||||
"float32.\n"
|
||||
"\n"
|
||||
"This is untested on DirectML backend, but will run on most Nvidia models. it "
|
||||
"will only speed up training on more recent GPUs. Those with compute "
|
||||
|
@ -322,7 +342,7 @@ msgid ""
|
|||
"the most benefit."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:316
|
||||
#: plugins/train/_config.py:341
|
||||
msgid ""
|
||||
"If a 'NaN' is generated in the model, this means that the model has "
|
||||
"corrupted and the model is likely to start deteriorating from this point on. "
|
||||
|
@ -331,11 +351,11 @@ msgid ""
|
|||
"rescue your model."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:329
|
||||
#: plugins/train/_config.py:354
|
||||
msgid "convert"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:331
|
||||
#: plugins/train/_config.py:356
|
||||
msgid ""
|
||||
"[GPU Only]. The number of faces to feed through the model at once when "
|
||||
"running the Convert process.\n"
|
||||
|
@ -345,27 +365,27 @@ msgid ""
|
|||
"size."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:350
|
||||
#: plugins/train/_config.py:375
|
||||
msgid ""
|
||||
"Loss configuration options\n"
|
||||
"Loss is the mechanism by which a Neural Network judges how well it thinks "
|
||||
"that it is recreating a face."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:357 plugins/train/_config.py:369
|
||||
#: plugins/train/_config.py:382 plugins/train/_config.py:402
|
||||
#: plugins/train/_config.py:414 plugins/train/_config.py:434
|
||||
#: plugins/train/_config.py:446 plugins/train/_config.py:466
|
||||
#: plugins/train/_config.py:482 plugins/train/_config.py:498
|
||||
#: plugins/train/_config.py:515
|
||||
#: plugins/train/_config.py:382 plugins/train/_config.py:394
|
||||
#: plugins/train/_config.py:407 plugins/train/_config.py:427
|
||||
#: plugins/train/_config.py:439 plugins/train/_config.py:459
|
||||
#: plugins/train/_config.py:471 plugins/train/_config.py:491
|
||||
#: plugins/train/_config.py:507 plugins/train/_config.py:523
|
||||
#: plugins/train/_config.py:540
|
||||
msgid "loss"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:361
|
||||
#: plugins/train/_config.py:386
|
||||
msgid "The loss function to use."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:373
|
||||
#: plugins/train/_config.py:398
|
||||
msgid ""
|
||||
"The second loss function to use. If using a structural based loss (such as "
|
||||
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
|
||||
|
@ -373,7 +393,7 @@ msgid ""
|
|||
"function with the loss_weight_2 option."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:388
|
||||
#: plugins/train/_config.py:413
|
||||
msgid ""
|
||||
"The amount of weight to apply to the second loss function.\n"
|
||||
"\n"
|
||||
|
@ -391,13 +411,13 @@ msgid ""
|
|||
"\t 0 - Disables the second loss function altogether."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:406
|
||||
#: plugins/train/_config.py:431
|
||||
msgid ""
|
||||
"The third loss function to use. You can adjust the weighting of this loss "
|
||||
"function with the loss_weight_3 option."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:420
|
||||
#: plugins/train/_config.py:445
|
||||
msgid ""
|
||||
"The amount of weight to apply to the third loss function.\n"
|
||||
"\n"
|
||||
|
@ -415,13 +435,13 @@ msgid ""
|
|||
"\t 0 - Disables the third loss function altogether."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:438
|
||||
#: plugins/train/_config.py:463
|
||||
msgid ""
|
||||
"The fourth loss function to use. You can adjust the weighting of this loss "
|
||||
"function with the loss_weight_3 option."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:452
|
||||
#: plugins/train/_config.py:477
|
||||
msgid ""
|
||||
"The amount of weight to apply to the fourth loss function.\n"
|
||||
"\n"
|
||||
|
@ -439,7 +459,7 @@ msgid ""
|
|||
"\t 0 - Disables the fourth loss function altogether."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:471
|
||||
#: plugins/train/_config.py:496
|
||||
msgid ""
|
||||
"The loss function to use when learning a mask.\n"
|
||||
"\t MAE - Mean absolute error will guide reconstructions of each pixel "
|
||||
|
@ -451,7 +471,7 @@ msgid ""
|
|||
"susceptible to outliers and typically produces slightly blurrier results."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:488
|
||||
#: plugins/train/_config.py:513
|
||||
msgid ""
|
||||
"The amount of priority to give to the eyes.\n"
|
||||
"\n"
|
||||
|
@ -464,7 +484,7 @@ msgid ""
|
|||
"NB: Penalized Mask Loss must be enable to use this option."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:504
|
||||
#: plugins/train/_config.py:529
|
||||
msgid ""
|
||||
"The amount of priority to give to the mouth.\n"
|
||||
"\n"
|
||||
|
@ -477,7 +497,7 @@ msgid ""
|
|||
"NB: Penalized Mask Loss must be enable to use this option."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:517
|
||||
#: plugins/train/_config.py:542
|
||||
msgid ""
|
||||
"Image loss function is weighted by mask presence. For areas of the image "
|
||||
"without the facial mask, reconstruction errors will be ignored while the "
|
||||
|
@ -485,12 +505,12 @@ msgid ""
|
|||
"attention on the core face area."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:528 plugins/train/_config.py:570
|
||||
#: plugins/train/_config.py:584 plugins/train/_config.py:593
|
||||
#: plugins/train/_config.py:553 plugins/train/_config.py:595
|
||||
#: plugins/train/_config.py:609 plugins/train/_config.py:618
|
||||
msgid "mask"
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:531
|
||||
#: plugins/train/_config.py:556
|
||||
msgid ""
|
||||
"The mask to be used for training. If you have selected 'Learn Mask' or "
|
||||
"'Penalized Mask Loss' you must select a value other than 'none'. The "
|
||||
|
@ -528,7 +548,7 @@ msgid ""
|
|||
"performance."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:572
|
||||
#: plugins/train/_config.py:597
|
||||
msgid ""
|
||||
"Apply gaussian blur to the mask input. This has the effect of smoothing the "
|
||||
"edges of the mask, which can help with poorly calculated masks and give less "
|
||||
|
@ -538,13 +558,13 @@ msgid ""
|
|||
"number."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:586
|
||||
#: plugins/train/_config.py:611
|
||||
msgid ""
|
||||
"Sets pixels that are near white to white and near black to black. Set to 0 "
|
||||
"for off."
|
||||
msgstr ""
|
||||
|
||||
#: plugins/train/_config.py:595
|
||||
#: plugins/train/_config.py:620
|
||||
msgid ""
|
||||
"Dedicate a portion of the model to learning how to duplicate the input mask. "
|
||||
"Increases VRAM usage in exchange for learning a quick ability to try to "
|
||||
|
|
Binary file not shown.
|
@ -7,8 +7,8 @@ msgid ""
|
|||
msgstr ""
|
||||
"Project-Id-Version: \n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2023-06-11 23:20+0100\n"
|
||||
"PO-Revision-Date: 2023-06-20 17:06+0100\n"
|
||||
"POT-Creation-Date: 2023-06-25 13:39+0100\n"
|
||||
"PO-Revision-Date: 2023-06-25 13:42+0100\n"
|
||||
"Last-Translator: \n"
|
||||
"Language-Team: \n"
|
||||
"Language: ru_RU\n"
|
||||
|
@ -354,7 +354,8 @@ msgstr ""
|
|||
"повлияет только на запуск новой модели."
|
||||
|
||||
#: plugins/train/_config.py:198 plugins/train/_config.py:223
|
||||
#: plugins/train/_config.py:238 plugins/train/_config.py:265
|
||||
#: plugins/train/_config.py:238 plugins/train/_config.py:256
|
||||
#: plugins/train/_config.py:290
|
||||
msgid "optimizer"
|
||||
msgstr "оптимизатор"
|
||||
|
||||
|
@ -435,7 +436,41 @@ msgstr ""
|
|||
"Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе "
|
||||
"значения \"-3\" эпсилон будет равен 0,001 (1e-3)."
|
||||
|
||||
#: plugins/train/_config.py:258
#: plugins/train/_config.py:262
msgid ""
"When to save the Optimizer Weights. Saving the optimizer weights is not "
"necessary and will increase the model file size 3x (and by extension the "
"amount of time it takes to save the model). However, it can be useful to "
"save these weights if you want to guarantee that a resumed model carries off "
"exactly from where it left off, rather than spending a few hundred "
"iterations catching up.\n"
"\t never - Don't save optimizer weights.\n"
"\t always - Save the optimizer weights at every save iteration. Model saving "
"will take longer, due to the increased file size, but you will always have "
"the last saved optimizer state in your model file.\n"
"\t exit - Only save the optimizer weights when explicitly terminating a "
"model. This can be when the model is actively stopped or when the target "
"iterations are met. Note: If the training session ends because of another "
"reason (e.g. power outage, Out of Memory Error, NaN detected) then the "
"optimizer weights will NOT be saved."
msgstr ""
"Когда сохранять веса оптимизатора. Сохранение весов оптимизатора не является "
"необходимым и увеличит размер файла модели в 3 раза (и соответственно время, "
"необходимое для сохранения модели). Однако может быть полезно сохранить эти "
"веса, если вы хотите гарантировать, что возобновленная модель продолжит "
"работу именно с того места, где она остановилась, а не тратит несколько "
"сотен итераций на догонялки.\n"
"\t never - не сохранять веса оптимизатора.\n"
"\t always - сохранять веса оптимизатора при каждой итерации сохранения. "
"Сохранение модели займет больше времени из-за увеличенного размера файла, но "
"в файле модели всегда будет последнее сохраненное состояние оптимизатора.\n"
"\t exit - сохранять веса оптимизатора только при явном завершении модели. "
"Это может быть, когда модель активно останавливается или когда выполняются "
"целевые итерации. Примечание. Если сеанс обучения завершается по другой "
"причине (например, отключение питания, ошибка нехватки памяти, обнаружение "
"NaN), веса оптимизатора НЕ будут сохранены."

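For orientation, the "never / always / exit" choice described above maps onto whether the optimizer state is bundled into the saved model. A minimal sketch (hypothetical helper, not the faceswap save code):

    import tensorflow as tf

    def save_checkpoint(model: tf.keras.Model, path: str,
                        save_optimizer: str, is_exit: bool) -> None:
        """ Save the model, optionally bundling the optimizer weights.
        save_optimizer is assumed to be one of "never", "always" or "exit". """
        include = save_optimizer == "always" or (save_optimizer == "exit" and is_exit)
        # Skipping the optimizer keeps the file roughly a third of the size, but a
        # resumed session spends some iterations rebuilding optimizer statistics.
        model.save(path, include_optimizer=include)
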
#: plugins/train/_config.py:283
msgid ""
"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights "
"and adjusts the normalization value dynamically to fit the data. Can help "
@@ -449,12 +484,12 @@ msgstr ""
"ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping for Source "
"Separation Networks [ТОЛЬКО на английском] https://arxiv.org/abs/2007.14469"

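The AutoClip reference above (https://arxiv.org/abs/2007.14469) amounts to clipping each step's global gradient norm to a low percentile of the norms seen so far. A minimal sketch, assuming a simple history-based implementation:

    import numpy as np
    import tensorflow as tf

    grad_history: list[float] = []

    def autoclip(grads: list[tf.Tensor], percentile: float = 10.0) -> list[tf.Tensor]:
        """ Clip gradients to the given percentile of the observed global-norm history. """
        grad_history.append(float(tf.linalg.global_norm(grads)))
        clip_value = float(np.percentile(grad_history, percentile))
        clipped, _ = tf.clip_by_global_norm(grads, clip_value)
        return clipped
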
#: plugins/train/_config.py:271 plugins/train/_config.py:283
#: plugins/train/_config.py:297 plugins/train/_config.py:314
#: plugins/train/_config.py:296 plugins/train/_config.py:308
#: plugins/train/_config.py:322 plugins/train/_config.py:339
msgid "network"
msgstr "сеть"

#: plugins/train/_config.py:273
#: plugins/train/_config.py:298
msgid ""
"Use reflection padding rather than zero padding with convolutions. Each "
"convolution must pad the image boundaries to maintain the proper sizing. "
@@ -468,7 +503,7 @@ msgstr ""
"изображения.\n"
"\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt"

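In practice, reflection padding as described above mirrors the image border before a VALID convolution instead of padding with zeros. An illustrative sketch (the real faceswap layers wrap this differently):

    import tensorflow as tf

    def reflect_pad_conv(inputs: tf.Tensor, filters: int, kernel_size: int = 3) -> tf.Tensor:
        """ Mirror the border, then convolve without extra padding, which avoids
        the dark edge artifacts that zero padding can introduce. """
        pad = kernel_size // 2
        padded = tf.pad(inputs, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode="REFLECT")
        return tf.keras.layers.Conv2D(filters, kernel_size, padding="valid")(padded)
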
#: plugins/train/_config.py:286
#: plugins/train/_config.py:311
msgid ""
"Enable the Tensorflow GPU 'allow_growth' configuration option. This option "
"prevents Tensorflow from allocating all of the GPU VRAM at launch but can "
@@ -483,7 +518,7 @@ msgstr ""
"случае, если у вас появляются ошибки, рода 'cuDNN fails to initialize'(cuDNN "
"не может инициализироваться) при начале тренировки."

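The 'allow_growth' option corresponds to TensorFlow's on-demand GPU memory mode. A minimal sketch of what it toggles (illustrative; the faceswap plumbing differs):

    import tensorflow as tf

    # Grow GPU memory on demand instead of grabbing all VRAM at start-up.
    # Must run before any operation creates the GPU device context.
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
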
#: plugins/train/_config.py:299
#: plugins/train/_config.py:324
msgid ""
"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed "
"precision allows you to use a mix of float16 with float32, to get the "
@@ -512,7 +547,7 @@ msgstr ""
"ускорение. В основном RTX видеокарты и позже предлагают самое большое "
"ускорение."

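For reference, enabling Keras mixed precision is roughly the following (a sketch of the standard API, not the exact faceswap code):

    import tensorflow as tf

    # Compute in float16 where safe; variables stay float32 for numerical stability.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    # In a custom training loop the optimizer is wrapped so float16 gradients are
    # loss-scaled to avoid underflow.
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam())
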
#: plugins/train/_config.py:316
#: plugins/train/_config.py:341
msgid ""
"If a 'NaN' is generated in the model, this means that the model has "
"corrupted and the model is likely to start deteriorating from this point on. "
@@ -526,11 +561,11 @@ msgstr ""
"NaN. Последнее сохранение не будет содержать в себе NaN, так что у вас будет "
"возможность спасти вашу модель."

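One stock way to get this kind of NaN protection is a Keras callback (a hedged illustration; the plugin's own check is assumed to differ):

    import tensorflow as tf

    # Stop training as soon as a NaN/Inf loss appears so the last good checkpoint
    # on disk is not overwritten by a corrupted model.
    callbacks = [tf.keras.callbacks.TerminateOnNaN()]
    # model.fit(..., callbacks=callbacks)
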
#: plugins/train/_config.py:329
|
||||
#: plugins/train/_config.py:354
|
||||
msgid "convert"
|
||||
msgstr "конвертирование"
|
||||
|
||||
#: plugins/train/_config.py:331
|
||||
#: plugins/train/_config.py:356
|
||||
msgid ""
|
||||
"[GPU Only]. The number of faces to feed through the model at once when "
|
||||
"running the Convert process.\n"
|
||||
|
@ -546,7 +581,7 @@ msgstr ""
|
|||
"конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда "
|
||||
"стоит снизить размер пачки."
|
||||
|
||||
#: plugins/train/_config.py:350
|
||||
#: plugins/train/_config.py:375
|
||||
msgid ""
|
||||
"Loss configuration options\n"
|
||||
"Loss is the mechanism by which a Neural Network judges how well it thinks "
|
||||
|
@ -556,20 +591,20 @@ msgstr ""
|
|||
"Потеря - механизм, по которому Нейронная Сеть судит, насколько хорошо она "
|
||||
"воспроизводит лицо."
|
||||
|
||||
#: plugins/train/_config.py:357 plugins/train/_config.py:369
|
||||
#: plugins/train/_config.py:382 plugins/train/_config.py:402
|
||||
#: plugins/train/_config.py:414 plugins/train/_config.py:434
|
||||
#: plugins/train/_config.py:446 plugins/train/_config.py:466
|
||||
#: plugins/train/_config.py:482 plugins/train/_config.py:498
|
||||
#: plugins/train/_config.py:515
|
||||
#: plugins/train/_config.py:382 plugins/train/_config.py:394
|
||||
#: plugins/train/_config.py:407 plugins/train/_config.py:427
|
||||
#: plugins/train/_config.py:439 plugins/train/_config.py:459
|
||||
#: plugins/train/_config.py:471 plugins/train/_config.py:491
|
||||
#: plugins/train/_config.py:507 plugins/train/_config.py:523
|
||||
#: plugins/train/_config.py:540
|
||||
msgid "loss"
|
||||
msgstr "потери"
|
||||
|
||||
#: plugins/train/_config.py:361
|
||||
#: plugins/train/_config.py:386
|
||||
msgid "The loss function to use."
|
||||
msgstr "Какую функцию потерь стоит использовать."
|
||||
|
||||
#: plugins/train/_config.py:373
|
||||
#: plugins/train/_config.py:398
|
||||
msgid ""
|
||||
"The second loss function to use. If using a structural based loss (such as "
|
||||
"SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 "
|
||||
|
@ -581,7 +616,7 @@ msgstr ""
|
|||
"регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес "
|
||||
"этой функции потерь с помощью параметра loss_weight_2."
|
||||
|
||||
#: plugins/train/_config.py:388
|
||||
#: plugins/train/_config.py:413
|
||||
msgid ""
|
||||
"The amount of weight to apply to the second loss function.\n"
|
||||
"\n"
|
||||
|
@ -612,7 +647,7 @@ msgstr ""
|
|||
"4 раза перед добавлением к общей оценке потерь. \n"
|
||||
"\t 0 - Полностью отключает четвертую функцию потерь."
|
||||
|
||||
#: plugins/train/_config.py:406
|
||||
#: plugins/train/_config.py:431
|
||||
msgid ""
|
||||
"The third loss function to use. You can adjust the weighting of this loss "
|
||||
"function with the loss_weight_3 option."
|
||||
|
@ -620,7 +655,7 @@ msgstr ""
|
|||
"Третья используемая функция потерь. Вы можете настроить вес этой функции "
|
||||
"потерь с помощью параметра loss_weight_3."
|
||||
|
||||
#: plugins/train/_config.py:420
|
||||
#: plugins/train/_config.py:445
|
||||
msgid ""
|
||||
"The amount of weight to apply to the third loss function.\n"
|
||||
"\n"
|
||||
|
@ -651,7 +686,7 @@ msgstr ""
|
|||
"4 раза перед добавлением к общей оценке потерь. \n"
|
||||
"\t 0 - Полностью отключает четвертую функцию потерь."
|
||||
|
||||
#: plugins/train/_config.py:438
|
||||
#: plugins/train/_config.py:463
|
||||
msgid ""
|
||||
"The fourth loss function to use. You can adjust the weighting of this loss "
|
||||
"function with the loss_weight_3 option."
|
||||
|
@ -659,7 +694,7 @@ msgstr ""
|
|||
"Четвертая используемая функция потерь. Вы можете настроить вес этой функции "
|
||||
"потерь с помощью параметра 'loss_weight_4'."
|
||||
|
||||
#: plugins/train/_config.py:452
|
||||
#: plugins/train/_config.py:477
|
||||
msgid ""
|
||||
"The amount of weight to apply to the fourth loss function.\n"
|
||||
"\n"
|
||||
|
@ -690,7 +725,7 @@ msgstr ""
|
|||
"4 раза перед добавлением к общей оценке потерь. \n"
|
||||
"\t 0 - Полностью отключает четвертую функцию потерь."
|
||||
|
||||
#: plugins/train/_config.py:471
|
||||
#: plugins/train/_config.py:496
|
||||
msgid ""
|
||||
"The loss function to use when learning a mask.\n"
|
||||
"\t MAE - Mean absolute error will guide reconstructions of each pixel "
|
||||
|
@ -711,7 +746,7 @@ msgstr ""
|
|||
"данных. Как среднее значение, оно чувствительно к выбросам и обычно дает "
|
||||
"немного более размытые результаты."
|
||||
|
||||
#: plugins/train/_config.py:488
|
||||
#: plugins/train/_config.py:513
|
||||
msgid ""
|
||||
"The amount of priority to give to the eyes.\n"
|
||||
"\n"
|
||||
|
@ -731,7 +766,7 @@ msgstr ""
|
|||
"\n"
|
||||
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
|
||||
|
||||
#: plugins/train/_config.py:504
|
||||
#: plugins/train/_config.py:529
|
||||
msgid ""
|
||||
"The amount of priority to give to the mouth.\n"
|
||||
"\n"
|
||||
|
@ -751,7 +786,7 @@ msgstr ""
|
|||
"\n"
|
||||
"NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию."
|
||||
|
||||
#: plugins/train/_config.py:517
|
||||
#: plugins/train/_config.py:542
|
||||
msgid ""
|
||||
"Image loss function is weighted by mask presence. For areas of the image "
|
||||
"without the facial mask, reconstruction errors will be ignored while the "
|
||||
|
@ -763,12 +798,12 @@ msgstr ""
|
|||
"время как область лица с маской является приоритетной. Может повысить общее "
|
||||
"качество за счет концентрации внимания на основной области лица."
|
||||
|
||||
#: plugins/train/_config.py:528 plugins/train/_config.py:570
|
||||
#: plugins/train/_config.py:584 plugins/train/_config.py:593
|
||||
#: plugins/train/_config.py:553 plugins/train/_config.py:595
|
||||
#: plugins/train/_config.py:609 plugins/train/_config.py:618
|
||||
msgid "mask"
|
||||
msgstr "маска"
|
||||
|
||||
#: plugins/train/_config.py:531
|
||||
#: plugins/train/_config.py:556
|
||||
msgid ""
|
||||
"The mask to be used for training. If you have selected 'Learn Mask' or "
|
||||
"'Penalized Mask Loss' you must select a value other than 'none'. The "
|
||||
|
@ -840,7 +875,7 @@ msgstr ""
|
|||
"сообщества и для дальнейшего описания нуждается в тестировании. Профильные "
|
||||
"лица могут иметь низкую производительность."
|
||||
|
||||
#: plugins/train/_config.py:572
#: plugins/train/_config.py:597
msgid ""
"Apply gaussian blur to the mask input. This has the effect of smoothing the "
"edges of the mask, which can help with poorly calculated masks and give less "
@@ -856,7 +891,7 @@ msgstr ""
"должно быть нечетным, если передано четное число, то оно будет округлено до "
"следующего нечетного числа."

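A small sketch of the blur-and-odd-kernel behaviour the help text describes (assumed semantics, not the plugin's exact code):

    import cv2
    import numpy as np

    def blur_mask(mask: np.ndarray, kernel_size: int) -> np.ndarray:
        """ Gaussian-blur a single-channel mask, rounding an even kernel size up
        to the next odd number as required by OpenCV. """
        if kernel_size <= 0:
            return mask
        size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        return cv2.GaussianBlur(mask, (size, size), 0)
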
#: plugins/train/_config.py:586
#: plugins/train/_config.py:611
msgid ""
"Sets pixels that are near white to white and near black to black. Set to 0 "
"for off."
@@ -864,7 +899,7 @@ msgstr ""
"Устанавливает пиксели, которые почти белые - в белые и которые почти черные "
"- в черные. Установите 0, чтобы выключить."

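The thresholding described above amounts to clamping near-extreme mask values. A hedged sketch assuming an 8-bit mask and a percentage-style threshold parameter:

    import numpy as np

    def threshold_mask(mask: np.ndarray, threshold: int) -> np.ndarray:
        """ Snap near-black pixels to 0 and near-white pixels to 255.
        threshold is a percentage; 0 disables the operation. """
        if threshold == 0:
            return mask
        cutoff = 255 * threshold / 100
        out = mask.copy()
        out[out < cutoff] = 0
        out[out > 255 - cutoff] = 255
        return out
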
#: plugins/train/_config.py:595
#: plugins/train/_config.py:620
msgid ""
"Dedicate a portion of the model to learning how to duplicate the input mask. "
"Increases VRAM usage in exchange for learning a quick ability to try to "

@@ -1,8 +1,7 @@
#!/usr/bin/env python3
""" Plugin to blend the edges of the face between the swap and the original face. """
import logging
import sys
from typing import List, Optional, Tuple
import typing as T

import cv2
import numpy as np

@@ -11,12 +10,6 @@ from lib.align import BlurMask, DetectedFace
from lib.config import FaceswapConfig
from plugins.convert._config import Config

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal


logger = logging.getLogger(__name__)

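This hunk, like most of the commit, swaps the pre-3.10 typing spellings for the PEP 604 / PEP 585 forms. A before/after sketch of the pattern with hypothetical names:

    # Before (Python < 3.10 style)
    from typing import List, Optional, Tuple

    def load(paths: List[str], limit: Optional[int] = None) -> Tuple[int, int]:
        ...

    # After (Python 3.10+, as used throughout this commit; builtin generics and
    # "X | None" need no typing imports)
    def load(paths: list[str], limit: int | None = None) -> tuple[int, int]:
        ...
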
@ -44,8 +37,8 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
mask_type: str,
|
||||
output_size: int,
|
||||
coverage_ratio: float,
|
||||
configfile: Optional[str] = None,
|
||||
config: Optional[FaceswapConfig] = None) -> None:
|
||||
configfile: str | None = None,
|
||||
config: FaceswapConfig | None = None) -> None:
|
||||
logger.debug("Initializing %s: (mask_type: '%s', output_size: %s, coverage_ratio: %s, "
|
||||
"configfile: %s, config: %s)", self.__class__.__name__, mask_type,
|
||||
coverage_ratio, output_size, configfile, config)
|
||||
|
@ -61,8 +54,8 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
self._do_erode = any(amount != 0 for amount in self._erodes)
|
||||
|
||||
def _set_config(self,
|
||||
configfile: Optional[str],
|
||||
config: Optional[FaceswapConfig]) -> dict:
|
||||
configfile: str | None,
|
||||
config: FaceswapConfig | None) -> dict:
|
||||
""" Set the correct configuration for the plugin based on whether a config file
|
||||
or pre-loaded config has been passed in.
|
||||
|
||||
|
@ -123,8 +116,8 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
detected_face: DetectedFace,
|
||||
source_offset: np.ndarray,
|
||||
target_offset: np.ndarray,
|
||||
centering: Literal["legacy", "face", "head"],
|
||||
predicted_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
centering: T.Literal["legacy", "face", "head"],
|
||||
predicted_mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Obtain the requested mask type and perform any defined mask manipulations.
|
||||
|
||||
Parameters
|
||||
|
@ -171,8 +164,8 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
|
||||
def _get_mask(self,
|
||||
detected_face: DetectedFace,
|
||||
predicted_mask: Optional[np.ndarray],
|
||||
centering: Literal["legacy", "face", "head"],
|
||||
predicted_mask: np.ndarray | None,
|
||||
centering: T.Literal["legacy", "face", "head"],
|
||||
source_offset: np.ndarray,
|
||||
target_offset: np.ndarray) -> np.ndarray:
|
||||
""" Return the requested mask with any requested blurring applied.
|
||||
|
@ -229,7 +222,7 @@ class Mask(): # pylint:disable=too-few-public-methods
|
|||
|
||||
def _get_stored_mask(self,
|
||||
detected_face: DetectedFace,
|
||||
centering: Literal["legacy", "face", "head"],
|
||||
centering: T.Literal["legacy", "face", "head"],
|
||||
source_offset: np.ndarray,
|
||||
target_offset: np.ndarray) -> np.ndarray:
|
||||
""" get the requested stored mask from the detected face object.
|
||||
|
@@ -303,7 +296,7 @@ class Mask(): # pylint:disable=too-few-public-methods

return eroded[..., None]

def _get_erosion_kernels(self, mask: np.ndarray) -> List[np.ndarray]:
def _get_erosion_kernels(self, mask: np.ndarray) -> list[np.ndarray]:
""" Get the erosion kernels for each of the center, left, top right and bottom erosions.

An approximation is made based on the number of positive pixels within the mask to create

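For context, erosion kernels of this kind are typically built with OpenCV and sized from the mask area. A hedged sketch (the percentage-based scaling rule is an assumption, not the plugin's exact formula):

    import cv2
    import numpy as np

    def erosion_kernel(mask: np.ndarray, erode_percent: float) -> np.ndarray:
        """ Build an elliptical kernel whose size is a percentage of the square
        root of the mask's positive-pixel area. """
        area = np.count_nonzero(mask)
        size = max(1, int(abs(erode_percent) / 100 * np.sqrt(area)))
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

    # kernel = erosion_kernel(mask, erode_percent=3.0)
    # eroded = cv2.erode(mask, kernel, iterations=1)
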
@ -4,8 +4,7 @@
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
from typing import Any, List, Optional
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -14,7 +13,7 @@ from plugins.convert._config import Config
|
|||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_config(plugin_name: str, configfile: Optional[str] = None) -> dict:
|
||||
def get_config(plugin_name: str, configfile: str | None = None) -> dict:
|
||||
""" Obtain the configuration settings for the writer plugin.
|
||||
|
||||
Parameters
|
||||
|
@ -44,7 +43,7 @@ class Output():
|
|||
The full path to a custom configuration ini file. If ``None`` is passed
|
||||
then the file is loaded from the default location. Default: ``None``.
|
||||
"""
|
||||
def __init__(self, output_folder: str, configfile: Optional[str] = None) -> None:
|
||||
def __init__(self, output_folder: str, configfile: str | None = None) -> None:
|
||||
logger.debug("Initializing %s: (output_folder: '%s')",
|
||||
self.__class__.__name__, output_folder)
|
||||
self.config: dict = get_config(".".join(self.__module__.split(".")[-2:]),
|
||||
|
@ -69,7 +68,7 @@ class Output():
|
|||
retval = hasattr(self, "frame_order")
|
||||
return retval
|
||||
|
||||
def output_filename(self, filename: str, separate_mask: bool = False) -> List[str]:
|
||||
def output_filename(self, filename: str, separate_mask: bool = False) -> list[str]:
|
||||
""" Obtain the full path for the output file, including the correct extension, for the
|
||||
given input filename.
|
||||
|
||||
|
@ -124,7 +123,7 @@ class Output():
|
|||
logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore
|
||||
logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore
|
||||
|
||||
def write(self, filename: str, image: Any) -> None:
|
||||
def write(self, filename: str, image: T.Any) -> None:
|
||||
""" Override for specific frame writing method.
|
||||
|
||||
Parameters
|
||||
|
@ -137,7 +136,7 @@ class Output():
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument
|
||||
def pre_encode(self, image: np.ndarray) -> T.Any: # pylint: disable=unused-argument
|
||||
""" Some writer plugins support the pre-encoding of images prior to saving out. As
|
||||
patching is done in multiple threads, but writing is done in a single thread, it can
|
||||
speed up the process to do any pre-encoding as part of the converter process.
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Video output writer for faceswap.py converter """
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as T
|
||||
|
||||
from math import ceil
|
||||
from subprocess import CalledProcessError, check_output, STDOUT
|
||||
from typing import cast, Generator, List, Optional, Tuple
|
||||
|
||||
import imageio
|
||||
import imageio_ffmpeg as im_ffm
|
||||
|
@ -11,6 +13,9 @@ import numpy as np
|
|||
|
||||
from ._base import Output, logger
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
class Writer(Output):
|
||||
""" Video output writer using imageio-ffmpeg.
|
||||
|
@ -32,7 +37,7 @@ class Writer(Output):
|
|||
def __init__(self,
|
||||
output_folder: str,
|
||||
total_count: int,
|
||||
frame_ranges: Optional[List[Tuple[int, int]]],
|
||||
frame_ranges: list[tuple[int, int]] | None,
|
||||
source_video: str,
|
||||
**kwargs) -> None:
|
||||
super().__init__(output_folder, **kwargs)
|
||||
|
@ -40,11 +45,11 @@ class Writer(Output):
|
|||
total_count, frame_ranges, source_video)
|
||||
self._source_video: str = source_video
|
||||
self._output_filename: str = self._get_output_filename()
|
||||
self._frame_ranges: Optional[List[Tuple[int, int]]] = frame_ranges
|
||||
self.frame_order: List[int] = self._set_frame_order(total_count)
|
||||
self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame
|
||||
self._frame_ranges: list[tuple[int, int]] | None = frame_ranges
|
||||
self.frame_order: list[int] = self._set_frame_order(total_count)
|
||||
self._output_dimensions: str | None = None # Fix dims on 1st received frame
|
||||
# Need to know dimensions of first frame, so set writer then
|
||||
self._writer: Optional[Generator[None, np.ndarray, None]] = None
|
||||
self._writer: Generator[None, np.ndarray, None] | None = None
|
||||
|
||||
@property
|
||||
def _valid_tunes(self) -> dict:
|
||||
|
@ -63,7 +68,7 @@ class Writer(Output):
|
|||
return retval
|
||||
|
||||
@property
|
||||
def _output_params(self) -> List[str]:
|
||||
def _output_params(self) -> list[str]:
|
||||
""" list: The FFMPEG Output parameters """
|
||||
codec = self.config["codec"]
|
||||
tune = self.config["tune"]
|
||||
|
@ -86,11 +91,11 @@ class Writer(Output):
|
|||
return output_args
|
||||
|
||||
@property
|
||||
def _audio_codec(self) -> Optional[str]:
|
||||
def _audio_codec(self) -> str | None:
|
||||
""" str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default)
|
||||
or ``None`` if skip muxing has been selected in configuration options, or if frame ranges
|
||||
have been passed in the command line arguments. """
|
||||
retval: Optional[str] = "copy"
|
||||
retval: str | None = "copy"
|
||||
if self.config["skip_mux"]:
|
||||
logger.info("Skipping audio muxing due to configuration settings.")
|
||||
retval = None
|
||||
|
@ -169,7 +174,7 @@ class Writer(Output):
|
|||
logger.info("Outputting to: '%s'", retval)
|
||||
return retval
|
||||
|
||||
def _set_frame_order(self, total_count: int) -> List[int]:
|
||||
def _set_frame_order(self, total_count: int) -> list[int]:
|
||||
""" Obtain the full list of frames to be converted in order.
|
||||
|
||||
Parameters
|
||||
|
@ -191,7 +196,7 @@ class Writer(Output):
|
|||
logger.debug("frame_order: %s", retval)
|
||||
return retval
|
||||
|
||||
def _get_writer(self, frame_dims: Tuple[int, int]) -> Generator[None, np.ndarray, None]:
|
||||
def _get_writer(self, frame_dims: tuple[int, int]) -> Generator[None, np.ndarray, None]:
|
||||
""" Add the requested encoding options and return the writer.
|
||||
|
||||
Parameters
|
||||
|
@ -238,13 +243,13 @@ class Writer(Output):
|
|||
logger.trace("Received frame: (filename: '%s', shape: %s", # type:ignore[attr-defined]
|
||||
filename, image.shape)
|
||||
if not self._output_dimensions:
|
||||
input_dims = cast(Tuple[int, int], image.shape[:2])
|
||||
input_dims = T.cast(tuple[int, int], image.shape[:2])
|
||||
self._set_dimensions(input_dims)
|
||||
self._writer = self._get_writer(input_dims)
|
||||
self.cache_frame(filename, image)
|
||||
self._save_from_cache()
|
||||
|
||||
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
|
||||
def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
|
||||
""" Set the attribute :attr:`_output_dimensions` based on the first frame received.
|
||||
This protects against different sized images coming in and ensures all images are written
|
||||
to ffmpeg at the same size. Dimensions are mapped to a macro block size 8.
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Animated GIF writer for faceswap.py converter """
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||
import typing as T
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
|
||||
from ._base import Output, logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from imageio.core import format as im_format # noqa:F401
|
||||
|
||||
|
||||
|
@ -31,15 +32,16 @@ class Writer(Output):
|
|||
def __init__(self,
|
||||
output_folder: str,
|
||||
total_count: int,
|
||||
frame_ranges: Optional[List[Tuple[int, int]]],
|
||||
frame_ranges: list[tuple[int, int]] | None,
|
||||
**kwargs) -> None:
|
||||
logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges)
|
||||
super().__init__(output_folder, **kwargs)
|
||||
self.frame_order: List[int] = self._set_frame_order(total_count, frame_ranges)
|
||||
self._output_dimensions: Optional[Tuple[int, int]] = None # Fix dims on 1st received frame
|
||||
self.frame_order: list[int] = self._set_frame_order(total_count, frame_ranges)
|
||||
# Fix dims on 1st received frame
|
||||
self._output_dimensions: tuple[int, int] | None = None
|
||||
# Need to know dimensions of first frame, so set writer then
|
||||
self._writer: Optional[imageio.plugins.pillowmulti.GIFFormat.Writer] = None
|
||||
self._gif_file: Optional[str] = None # Set filename based on first file seen
|
||||
self._writer: imageio.plugins.pillowmulti.GIFFormat.Writer | None = None
|
||||
self._gif_file: str | None = None # Set filename based on first file seen
|
||||
|
||||
@property
|
||||
def _gif_params(self) -> dict:
|
||||
|
@ -50,7 +52,7 @@ class Writer(Output):
|
|||
|
||||
@staticmethod
|
||||
def _set_frame_order(total_count: int,
|
||||
frame_ranges: Optional[List[Tuple[int, int]]]) -> List[int]:
|
||||
frame_ranges: list[tuple[int, int]] | None) -> list[int]:
|
||||
""" Obtain the full list of frames to be converted in order.
|
||||
|
||||
Parameters
|
||||
|
@ -75,7 +77,7 @@ class Writer(Output):
|
|||
logger.debug("frame_order: %s", retval)
|
||||
return retval
|
||||
|
||||
def _get_writer(self) -> "im_format.Format.Writer":
|
||||
def _get_writer(self) -> im_format.Format.Writer:
|
||||
""" Obtain the GIF writer with the requested GIF encoding options.
|
||||
|
||||
Returns
|
||||
|
@ -145,7 +147,7 @@ class Writer(Output):
|
|||
self._gif_file = retval
|
||||
logger.info("Outputting to: '%s'", self._gif_file)
|
||||
|
||||
def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None:
|
||||
def _set_dimensions(self, frame_dims: tuple[int, int]) -> None:
|
||||
""" Set the attribute :attr:`_output_dimensions` based on the first frame received. This
|
||||
protects against different sized images coming in and ensure all images get written to the
|
||||
Gif at the sema dimensions. """
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
""" Image output writer for faceswap.py converter
|
||||
Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
@ -37,7 +35,7 @@ class Writer(Output):
|
|||
"transparency. Changing output format to 'png'")
|
||||
self.config["format"] = "png"
|
||||
|
||||
def _get_save_args(self) -> Tuple[int, ...]:
|
||||
def _get_save_args(self) -> tuple[int, ...]:
|
||||
""" Obtain the save parameters for the file format.
|
||||
|
||||
Returns
|
||||
|
@ -46,7 +44,7 @@ class Writer(Output):
|
|||
The OpenCV specific arguments for the selected file format
|
||||
"""
|
||||
filetype = self.config["format"]
|
||||
args: Tuple[int, ...] = tuple()
|
||||
args: tuple[int, ...] = tuple()
|
||||
if filetype == "jpg" and self.config["jpg_quality"] > 0:
|
||||
args = (cv2.IMWRITE_JPEG_QUALITY,
|
||||
self.config["jpg_quality"])
|
||||
|
@ -56,7 +54,7 @@ class Writer(Output):
|
|||
logger.debug(args)
|
||||
return args
|
||||
|
||||
def write(self, filename: str, image: List[bytes]) -> None:
|
||||
def write(self, filename: str, image: list[bytes]) -> None:
|
||||
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out
|
||||
the encoded mask to a sub-folder in the output directory.
|
||||
|
||||
|
@@ -77,7 +75,7 @@ class Writer(Output):
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to save image '%s'. Original Error: %s", filename, err)

def pre_encode(self, image: np.ndarray) -> List[bytes]:
def pre_encode(self, image: np.ndarray) -> list[bytes]:
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker.

Parameters

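The pre-encode step boils down to running the image encoder in the converter threads so the single writer thread only does disk I/O. A minimal sketch with an assumed extension parameter:

    import cv2
    import numpy as np

    def pre_encode_image(image: np.ndarray, extension: str = ".png") -> list[bytes]:
        """ Encode a BGR frame to the output format inside a worker thread. """
        ok, buffer = cv2.imencode(extension, image)
        if not ok:
            raise RuntimeError(f"Failed to encode image to {extension}")
        return [buffer.tobytes()]
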
@ -1,7 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Image output writer for faceswap.py converter """
|
||||
|
||||
from typing import Dict, List, Union
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
@ -25,7 +23,7 @@ class Writer(Output):
|
|||
super().__init__(output_folder, **kwargs)
|
||||
self._check_transparency_format()
|
||||
# Correct format namings for writing to byte stream
|
||||
self._format_dict = dict(jpg="JPEG", jp2="JPEG 2000", tif="TIFF")
|
||||
self._format_dict = {"jpg": "JPEG", "jp2": "JPEG 2000", "tif": "TIFF"}
|
||||
self._separate_mask = self.config["draw_transparent"] and self.config["separate_mask"]
|
||||
self._kwargs = self._get_save_kwargs()
|
||||
|
||||
|
@ -38,7 +36,7 @@ class Writer(Output):
|
|||
"transparency. Changing output format to 'png'")
|
||||
self.config["format"] = "png"
|
||||
|
||||
def _get_save_kwargs(self) -> Dict[str, Union[bool, int, str]]:
|
||||
def _get_save_kwargs(self) -> dict[str, bool | int | str]:
|
||||
""" Return the save parameters for the file format
|
||||
|
||||
Returns
|
||||
|
@ -59,7 +57,7 @@ class Writer(Output):
|
|||
logger.debug(kwargs)
|
||||
return kwargs
|
||||
|
||||
def write(self, filename: str, image: List[BytesIO]) -> None:
|
||||
def write(self, filename: str, image: list[BytesIO]) -> None:
|
||||
""" Write out the pre-encoded image to disk. If separate mask has been selected, write out
|
||||
the encoded mask to a sub-folder in the output directory.
|
||||
|
||||
|
@ -80,7 +78,7 @@ class Writer(Output):
|
|||
except Exception as err: # pylint: disable=broad-except
|
||||
logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
|
||||
|
||||
def pre_encode(self, image: np.ndarray) -> List[BytesIO]:
|
||||
def pre_encode(self, image: np.ndarray) -> list[BytesIO]:
|
||||
""" Pre_encode the image in lib/convert.py threads as it is a LOT quicker
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
""" Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and
|
||||
:mod:`~plugins.extract.mask` Plugins
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Any, Callable, Dict, Generator, List, Optional,
|
||||
Sequence, Union, Tuple, TYPE_CHECKING)
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa
|
||||
|
@ -18,12 +17,8 @@ from lib.utils import GetModel, FaceswapError
|
|||
from ._config import Config
|
||||
from .pipeline import ExtractMedia
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from queue import Queue
|
||||
import cv2
|
||||
from lib.align import DetectedFace
|
||||
|
@ -37,7 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
# TODO Run with warnings mode
|
||||
|
||||
|
||||
def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str, Any]:
|
||||
def _get_config(plugin_name: str, configfile: str | None = None) -> dict[str, T.Any]:
|
||||
""" Return the configuration for the requested model
|
||||
|
||||
Parameters
|
||||
|
@ -56,7 +51,7 @@ def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str,
|
|||
return Config(plugin_name, configfile=configfile).config_dict
|
||||
|
||||
|
||||
BatchType = Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"]
|
||||
BatchType = T.Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -84,13 +79,12 @@ class ExtractorBatch:
|
|||
data: dict
|
||||
Any specific data required during the processing phase for a particular plugin
|
||||
"""
|
||||
image: List[np.ndarray] = field(default_factory=list)
|
||||
detected_faces: Sequence[Union["DetectedFace",
|
||||
List["DetectedFace"]]] = field(default_factory=list)
|
||||
filename: List[str] = field(default_factory=list)
|
||||
image: list[np.ndarray] = field(default_factory=list)
|
||||
detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list)
|
||||
filename: list[str] = field(default_factory=list)
|
||||
feed: np.ndarray = np.array([])
|
||||
prediction: np.ndarray = np.array([])
|
||||
data: List[Dict[str, Any]] = field(default_factory=list)
|
||||
data: list[dict[str, T.Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class Extractor():
|
||||
|
@ -157,10 +151,10 @@ class Extractor():
|
|||
|
||||
"""
|
||||
def __init__(self,
|
||||
git_model_id: Optional[int] = None,
|
||||
model_filename: Optional[Union[str, List[str]]] = None,
|
||||
exclude_gpus: Optional[List[int]] = None,
|
||||
configfile: Optional[str] = None,
|
||||
git_model_id: int | None = None,
|
||||
model_filename: str | list[str] | None = None,
|
||||
exclude_gpus: list[int] | None = None,
|
||||
configfile: str | None = None,
|
||||
instance: int = 0) -> None:
|
||||
logger.debug("Initializing %s: (git_model_id: %s, model_filename: %s, exclude_gpus: %s, "
|
||||
"configfile: %s, instance: %s, )", self.__class__.__name__, git_model_id,
|
||||
|
@ -176,9 +170,9 @@ class Extractor():
|
|||
be a list of strings """
|
||||
|
||||
# << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> #
|
||||
self.name: Optional[str] = None
|
||||
self.name: str | None = None
|
||||
self.input_size = 0
|
||||
self.color_format: Literal["BGR", "RGB", "GRAY"] = "BGR"
|
||||
self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR"
|
||||
self.vram = 0
|
||||
self.vram_warnings = 0 # Will run at this with warnings
|
||||
self.vram_per_batch = 0
|
||||
|
@ -187,7 +181,7 @@ class Extractor():
|
|||
self.queue_size = 1
|
||||
""" int: Queue size for all internal queues. Set in :func:`initialize()` """
|
||||
|
||||
self.model: Optional[Union["KSession", "cv2.dnn.Net"]] = None
|
||||
self.model: KSession | cv2.dnn.Net | None = None
|
||||
"""varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """
|
||||
|
||||
# For detectors that support batching, this should be set to the calculated batch size
|
||||
|
@ -196,26 +190,26 @@ class Extractor():
|
|||
""" int: Batchsize for feeding this model. The number of images the model should
|
||||
feed through at once. """
|
||||
|
||||
self._queues: Dict[str, "Queue"] = {}
|
||||
self._queues: dict[str, Queue] = {}
|
||||
""" dict: in + out queues and internal queues for this plugin, """
|
||||
|
||||
self._threads: List[MultiThread] = []
|
||||
self._threads: list[MultiThread] = []
|
||||
""" list: Internal threads for this plugin """
|
||||
|
||||
self._extract_media: Dict[str, ExtractMedia] = {}
|
||||
self._extract_media: dict[str, ExtractMedia] = {}
|
||||
""" dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being
|
||||
processed. Stored at input for pairing back up on output of extractor process """
|
||||
|
||||
# << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >>> #
|
||||
self._plugin_type: Optional[Literal["align", "detect", "recognition", "mask"]] = None
|
||||
self._plugin_type: T.Literal["align", "detect", "recognition", "mask"] | None = None
|
||||
""" str: Plugin type. ``detect`, ``align``, ``recognise`` or ``mask`` set in
|
||||
``<plugin_type>._base`` """
|
||||
|
||||
# << Objects for splitting frame's detected faces and rejoining them >>
|
||||
# << for post-detector pliugins >>
|
||||
self._faces_per_filename: Dict[str, int] = {} # Tracking for recompiling batches
|
||||
self._rollover: Optional[ExtractMedia] = None # batch rollover items
|
||||
self._output_faces: List["DetectedFace"] = [] # Recompiled output faces from plugin
|
||||
self._faces_per_filename: dict[str, int] = {} # Tracking for recompiling batches
|
||||
self._rollover: ExtractMedia | None = None # batch rollover items
|
||||
self._output_faces: list[DetectedFace] = [] # Recompiled output faces from plugin
|
||||
|
||||
logger.debug("Initialized _base %s", self.__class__.__name__)
|
||||
|
||||
|
@ -361,7 +355,7 @@ class Extractor():
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_batch(self, queue: "Queue") -> Tuple[bool, BatchType]:
|
||||
def get_batch(self, queue: Queue) -> tuple[bool, BatchType]:
|
||||
""" **Override method** (at `<plugin_type>` level)
|
||||
|
||||
This method should be overridden at the `<plugin_type>` level (IE.
|
||||
|
@ -403,7 +397,7 @@ class Extractor():
|
|||
for thread in self._threads:
|
||||
thread.check_and_raise_error()
|
||||
|
||||
def rollover_collector(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia]:
def rollover_collector(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia:
""" For extractors after the Detectors, the number of detected faces per frame vs extractor
batch size mean that faces will need to be split/re-joined with frames. The rollover
collector can be used to rollover items that don't fit in a batch.

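The rollover idea is general: when frames carry a variable number of faces, whatever does not fit the fixed batch is held back and emitted first on the next call. A hypothetical helper, not the faceswap implementation:

    from collections import deque
    from typing import Any

    class Rollover:
        """ Accumulate items into fixed-size batches, carrying leftovers over. """

        def __init__(self, batchsize: int) -> None:
            self._batchsize = batchsize
            self._pending: deque[Any] = deque()

        def batch(self, items: list[Any]) -> list[Any]:
            self._pending.extend(items)
            take = min(self._batchsize, len(self._pending))
            return [self._pending.popleft() for _ in range(take)]  # leftovers stay queued
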
@ -425,7 +419,7 @@ class Extractor():
|
|||
if self._rollover is not None:
|
||||
logger.trace("Getting from _rollover: (filename: `%s`, faces: %s)", # type:ignore
|
||||
self._rollover.filename, len(self._rollover.detected_faces))
|
||||
item: Union[Literal["EOF"], ExtractMedia] = self._rollover
|
||||
item: T.Literal["EOF"] | ExtractMedia = self._rollover
|
||||
self._rollover = None
|
||||
else:
|
||||
next_item = self._get_item(queue)
|
||||
|
@ -442,9 +436,8 @@ class Extractor():
|
|||
# <<< INIT METHODS >>> #
|
||||
@classmethod
|
||||
def _get_model(cls,
|
||||
git_model_id: Optional[int],
|
||||
model_filename: Optional[Union[str, List[str]]]
|
||||
) -> Optional[Union[str, List[str]]]:
|
||||
git_model_id: int | None,
|
||||
model_filename: str | list[str] | None) -> str | list[str] | None:
|
||||
""" Check if model is available, if not, download and unzip it """
|
||||
if model_filename is None:
|
||||
logger.debug("No model_filename specified. Returning None")
|
||||
|
@ -496,9 +489,9 @@ class Extractor():
|
|||
self.name, self._plugin_type.title(), self.batchsize)
|
||||
|
||||
def _add_queues(self,
|
||||
in_queue: "Queue",
|
||||
out_queue: "Queue",
|
||||
queues: List[str]) -> None:
|
||||
in_queue: Queue,
|
||||
out_queue: Queue,
|
||||
queues: list[str]) -> None:
|
||||
""" Add the queues
|
||||
in_queue and out_queue should be previously created queue manager queues.
|
||||
queues should be a list of queue names """
|
||||
|
@ -533,8 +526,8 @@ class Extractor():
|
|||
def _add_thread(self,
|
||||
name: str,
|
||||
function: Callable[[BatchType], BatchType],
|
||||
in_queue: "Queue",
|
||||
out_queue: "Queue") -> None:
|
||||
in_queue: Queue,
|
||||
out_queue: Queue) -> None:
|
||||
""" Add a MultiThread thread to self._threads """
|
||||
logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)",
|
||||
name, function, in_queue, out_queue)
|
||||
|
@ -546,8 +539,8 @@ class Extractor():
|
|||
logger.debug("Added thread: %s", name)
|
||||
|
||||
def _obtain_batch_item(self, function: Callable[[BatchType], BatchType],
|
||||
in_queue: "Queue",
|
||||
out_queue: "Queue") -> Optional[BatchType]:
|
||||
in_queue: Queue,
|
||||
out_queue: Queue) -> BatchType | None:
|
||||
""" Obtain the batch item from the in queue for the current process.
|
||||
|
||||
Parameters
|
||||
|
@ -564,7 +557,7 @@ class Extractor():
|
|||
:class:`ExtractorBatch` or ``None``
|
||||
The batch, if one exists, or ``None`` if queue is exhausted
|
||||
"""
|
||||
batch: Union[Literal["EOF"], BatchType, ExtractMedia]
|
||||
batch: T.Literal["EOF"] | BatchType | ExtractMedia
|
||||
if function.__name__ == "_process_input": # Process input items to batches
|
||||
exhausted, batch = self.get_batch(in_queue)
|
||||
if exhausted:
|
||||
|
@ -585,8 +578,8 @@ class Extractor():
|
|||
|
||||
def _thread_process(self,
|
||||
function: Callable[[BatchType], BatchType],
|
||||
in_queue: "Queue",
|
||||
out_queue: "Queue") -> None:
|
||||
in_queue: Queue,
|
||||
out_queue: Queue) -> None:
|
||||
""" Perform a plugin function in a thread
|
||||
|
||||
Parameters
|
||||
|
@ -629,7 +622,7 @@ class Extractor():
|
|||
out_queue.put("EOF")
|
||||
|
||||
# <<< QUEUE METHODS >>> #
|
||||
def _get_item(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia, BatchType]:
|
||||
def _get_item(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia | BatchType:
|
||||
""" Yield one item from a queue """
|
||||
item = queue.get()
|
||||
if isinstance(item, ExtractMedia):
|
||||
|
|
|
@ -12,12 +12,12 @@ For each source item, the plugin must pass a dict to finalize containing:
|
|||
>>> "landmarks": [list of 68 point face landmarks]
|
||||
>>> "detected_faces": [<list of DetectedFace objects>]}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from time import sleep
|
||||
from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -28,12 +28,8 @@ from lib.utils import FaceswapError
|
|||
from plugins.extract._base import BatchType, Extractor, ExtractMedia, ExtractorBatch
|
||||
from .processing import AlignedFilter, ReAlign
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from queue import Queue
|
||||
from lib.align import DetectedFace
|
||||
from lib.align.aligned_face import CenteringType
|
||||
|
@ -77,9 +73,9 @@ class AlignerBatch(ExtractorBatch):
|
|||
The masks used to filter out re-feed values for passing to the re-aligner.
|
||||
"""
|
||||
batch_id: int = 0
|
||||
detected_faces: List["DetectedFace"] = field(default_factory=list)
|
||||
detected_faces: list[DetectedFace] = field(default_factory=list)
|
||||
landmarks: np.ndarray = np.array([])
|
||||
refeeds: List[np.ndarray] = field(default_factory=list)
|
||||
refeeds: list[np.ndarray] = field(default_factory=list)
|
||||
second_pass: bool = False
|
||||
second_pass_masks: np.ndarray = np.array([])
|
||||
|
||||
|
@ -142,11 +138,11 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
git_model_id: Optional[int] = None,
|
||||
model_filename: Optional[str] = None,
|
||||
configfile: Optional[str] = None,
|
||||
git_model_id: int | None = None,
|
||||
model_filename: str | None = None,
|
||||
configfile: str | None = None,
|
||||
instance: int = 0,
|
||||
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None,
|
||||
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
|
||||
re_feed: int = 0,
|
||||
re_align: bool = False,
|
||||
disable_filter: bool = False,
|
||||
|
@ -160,9 +156,9 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
instance=instance,
|
||||
**kwargs)
|
||||
self._plugin_type = "align"
|
||||
self.realign_centering: "CenteringType" = "face" # overide for plugin specific centering
|
||||
self.realign_centering: CenteringType = "face" # overide for plugin specific centering
|
||||
self._eof_seen = False
|
||||
self._normalize_method: Optional[Literal["clahe", "hist", "mean"]] = None
|
||||
self._normalize_method: T.Literal["clahe", "hist", "mean"] | None = None
|
||||
self._re_feed = re_feed
|
||||
self._filter = AlignedFilter(feature_filter=self.config["aligner_features"],
|
||||
min_scale=self.config["aligner_min_scale"],
|
||||
|
@ -181,8 +177,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def set_normalize_method(self,
|
||||
method: Optional[Literal["none", "clahe", "hist", "mean"]]) -> None:
|
||||
def set_normalize_method(self, method: T.Literal["none", "clahe", "hist", "mean"] | None
|
||||
) -> None:
|
||||
""" Set the normalization method for feeding faces into the aligner.
|
||||
|
||||
Parameters
|
||||
|
@ -191,14 +187,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
The normalization method to apply to faces prior to feeding into the model
|
||||
"""
|
||||
method = None if method is None or method.lower() == "none" else method
|
||||
self._normalize_method = cast(Optional[Literal["clahe", "hist", "mean"]], method)
|
||||
self._normalize_method = T.cast(T.Literal["clahe", "hist", "mean"] | None, method)
|
||||
|
||||
def initialize(self, *args, **kwargs) -> None:
|
||||
""" Add a call to add model input size to the re-aligner """
|
||||
self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering)
|
||||
super().initialize(*args, **kwargs)
|
||||
|
||||
def _handle_realigns(self, queue: "Queue") -> Optional[Tuple[bool, AlignerBatch]]:
|
||||
def _handle_realigns(self, queue: Queue) -> tuple[bool, AlignerBatch] | None:
|
||||
""" Handle any items waiting for a second pass through the aligner.
|
||||
|
||||
If EOF has been recieved and items are still being processed through the first pass
|
||||
|
@ -242,7 +238,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
return None
|
||||
|
||||
def get_batch(self, queue: "Queue") -> Tuple[bool, AlignerBatch]:
|
||||
def get_batch(self, queue: Queue) -> tuple[bool, AlignerBatch]:
|
||||
""" Get items for inputting into the aligner from the queue in batches
|
||||
|
||||
Items are returned from the ``queue`` in batches of
|
||||
|
@ -548,7 +544,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
"\n3) Enable 'Single Process' mode.")
|
||||
raise FaceswapError(msg) from err
|
||||
|
||||
def _process_refeeds(self, batch: AlignerBatch) -> List[AlignerBatch]:
|
||||
def _process_refeeds(self, batch: AlignerBatch) -> list[AlignerBatch]:
|
||||
""" Process the output for each selected re-feed
|
||||
|
||||
Parameters
|
||||
|
@ -562,7 +558,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
List of :class:`AlignerBatch` objects. Each object in the list contains the
|
||||
results for each selected re-feed
|
||||
"""
|
||||
retval: List[AlignerBatch] = []
|
||||
retval: list[AlignerBatch] = []
|
||||
if batch.second_pass:
|
||||
# Re-insert empty sub-patches for re-population in ReAlign for filtered out batches
|
||||
selected_idx = 0
|
||||
|
@ -605,8 +601,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
return retval
|
||||
|
||||
def _get_refeed_filter_masks(self,
|
||||
subbatches: List[AlignerBatch],
|
||||
original_masks: Optional[np.ndarray] = None) -> np.ndarray:
|
||||
subbatches: list[AlignerBatch],
|
||||
original_masks: np.ndarray | None = None) -> np.ndarray:
|
||||
""" Obtain the boolean mask array for masking out failed re-feed results if filter refeed
|
||||
has been selected
|
||||
|
||||
|
@ -663,7 +659,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
landmarks.shape)
|
||||
return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32")
|
||||
|
||||
def _process_output_first_pass(self, subbatches: List[AlignerBatch]) -> Tuple[np.ndarray,
|
||||
def _process_output_first_pass(self, subbatches: list[AlignerBatch]) -> tuple[np.ndarray,
|
||||
np.ndarray]:
|
||||
""" Process the output from the aligner if this is the first or only pass.
|
||||
|
||||
|
@ -696,7 +692,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method
|
|||
return all_landmarks, masks
|
||||
|
||||
def _process_output_second_pass(self,
|
||||
subbatches: List[AlignerBatch],
|
||||
subbatches: list[AlignerBatch],
|
||||
masks: np.ndarray) -> np.ndarray:
|
||||
""" Process the output from the aligner if this is the first or only pass.
|
||||
|
||||
|
|
|
@ -1,21 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Processing methods for aligner plugins """
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lib.align import AlignedFace
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.align import DetectedFace
|
||||
from .aligner import AlignerBatch
|
||||
from lib.align.aligned_face import CenteringType
|
||||
|
@ -72,16 +67,16 @@ class AlignedFilter():
|
|||
min_scale > 0.0 or
|
||||
distance > 0.0 or
|
||||
roll > 0.0)
|
||||
self._counts: Dict[str, int] = dict(features=0,
|
||||
min_scale=0,
|
||||
max_scale=0,
|
||||
distance=0,
|
||||
roll=0)
|
||||
self._counts: dict[str, int] = {"features": 0,
|
||||
"min_scale": 0,
|
||||
"max_scale": 0,
|
||||
"distance": 0,
|
||||
"roll": 0}
|
||||
logger.debug("Initialized %s: ", self.__class__.__name__)
|
||||
|
||||
def _scale_test(self,
|
||||
face: AlignedFace,
|
||||
minimum_dimension: int) -> Optional[Literal["min", "max"]]:
|
||||
minimum_dimension: int) -> T.Literal["min", "max"] | None:
|
||||
""" Test if a face is below or above the min/max size thresholds. Returns as soon as a test
|
||||
fails.
|
||||
|
||||
|
@ -116,9 +111,9 @@ class AlignedFilter():
|
|||
|
||||
def _handle_filtered(self,
|
||||
key: str,
|
||||
face: "DetectedFace",
|
||||
faces: List["DetectedFace"],
|
||||
sub_folders: List[Optional[str]],
|
||||
face: DetectedFace,
|
||||
faces: list[DetectedFace],
|
||||
sub_folders: list[str | None],
|
||||
sub_folder_index: int) -> None:
|
||||
""" Add the filtered item to the filter counts.
|
||||
|
||||
|
@ -145,8 +140,8 @@ class AlignedFilter():
|
|||
faces.append(face)
|
||||
sub_folders[sub_folder_index] = f"_align_filt_{key}"
|
||||
|
||||
def __call__(self, faces: List["DetectedFace"], minimum_dimension: int
|
||||
) -> Tuple[List["DetectedFace"], List[Optional[str]]]:
|
||||
def __call__(self, faces: list[DetectedFace], minimum_dimension: int
|
||||
) -> tuple[list[DetectedFace], list[str | None]]:
|
||||
""" Apply the filter to the incoming batch
|
||||
|
||||
Parameters
|
||||
|
@ -165,11 +160,11 @@ class AlignedFilter():
|
|||
List of ``Nones`` if saving filtered faces has not been selected or list of ``Nones``
|
||||
and sub folder names corresponding the filtered face location
|
||||
"""
|
||||
sub_folders: List[Optional[str]] = [None for _ in range(len(faces))]
|
||||
sub_folders: list[str | None] = [None for _ in range(len(faces))]
|
||||
if not self._active:
|
||||
return faces, sub_folders
|
||||
|
||||
retval: List["DetectedFace"] = []
|
||||
retval: list[DetectedFace] = []
|
||||
for idx, face in enumerate(faces):
|
||||
aligned = AlignedFace(landmarks=face.landmarks_xy, centering="face")
|
||||
|
||||
|
@ -194,8 +189,8 @@ class AlignedFilter():
|
|||
return retval, sub_folders
|
||||
|
||||
def filtered_mask(self,
|
||||
batch: "AlignerBatch",
|
||||
skip: Optional[Union[np.ndarray, List[int]]] = None) -> np.ndarray:
|
||||
batch: AlignerBatch,
|
||||
skip: np.ndarray | list[int] | None = None) -> np.ndarray:
|
||||
""" Obtain a list of boolean values for the given batch indicating whether they pass the
|
||||
filter test.
|
||||
|
||||
|
@ -262,13 +257,14 @@ class ReAlign():
|
|||
self._active = active
|
||||
self._do_refeeds = do_refeeds
|
||||
self._do_filter = do_filter
|
||||
self._centering: "CenteringType" = "face"
|
||||
self._centering: CenteringType = "face"
|
||||
self._size = 0
|
||||
self._tracked_lock = Lock()
|
||||
self._tracked_batchs: Dict[int, Dict[Literal["filtered_landmarks"], List[np.ndarray]]] = {}
|
||||
self._tracked_batchs: dict[int,
|
||||
dict[T.Literal["filtered_landmarks"], list[np.ndarray]]] = {}
|
||||
# TODO. Probably does not need to be a list, just alignerbatch
|
||||
self._queue_lock = Lock()
|
||||
self._queued: List["AlignerBatch"] = []
|
||||
self._queued: list[AlignerBatch] = []
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
|
@ -301,7 +297,7 @@ class ReAlign():
|
|||
with self._tracked_lock:
|
||||
return bool(self._tracked_batchs)
|
||||
|
||||
def set_input_size_and_centering(self, input_size: int, centering: "CenteringType") -> None:
|
||||
def set_input_size_and_centering(self, input_size: int, centering: CenteringType) -> None:
|
||||
""" Set the input size of the loaded plugin once the model has been loaded
|
||||
|
||||
Parameters
|
||||
|
@ -344,7 +340,7 @@ class ReAlign():
|
|||
with self._tracked_lock:
|
||||
del self._tracked_batchs[batch_id]
|
||||
|
||||
def add_batch(self, batch: "AlignerBatch") -> None:
|
||||
def add_batch(self, batch: AlignerBatch) -> None:
|
||||
""" Add first pass alignments to the queue for picking up for re-alignment, update their
|
||||
:attr:`second_pass` attribute to ``True`` and clear attributes not required.
|
||||
|
||||
|
@ -362,7 +358,7 @@ class ReAlign():
|
|||
batch.data = []
|
||||
self._queued.append(batch)
|
||||
|
||||
def get_batch(self) -> "AlignerBatch":
|
||||
def get_batch(self) -> AlignerBatch:
|
||||
""" Retrieve the next batch currently queued for re-alignment
|
||||
|
||||
Returns
|
||||
|
@ -376,7 +372,7 @@ class ReAlign():
|
|||
retval.filename)
|
||||
return retval
|
||||
|
||||
def process_batch(self, batch: "AlignerBatch") -> List[np.ndarray]:
|
||||
def process_batch(self, batch: AlignerBatch) -> list[np.ndarray]:
|
||||
""" Pre process a batch object for re-aligning through the aligner.
|
||||
|
||||
Parameters
|
||||
|
@ -391,8 +387,8 @@ class ReAlign():
|
|||
"""
|
||||
logger.trace("Processing batch: %s, landmarks: %s", # type: ignore[attr-defined]
|
||||
batch.filename, [b.shape for b in batch.landmarks])
|
||||
retval: List[np.ndarray] = []
|
||||
filtered_landmarks: List[np.ndarray] = []
|
||||
retval: list[np.ndarray] = []
|
||||
filtered_landmarks: list[np.ndarray] = []
|
||||
for landmarks, masks in zip(batch.landmarks, batch.second_pass_masks):
|
||||
if not np.all(masks): # At least one face has not already been filtered
|
||||
aligned_faces = [AlignedFace(lms,
|
||||
|
@ -415,7 +411,7 @@ class ReAlign():
|
|||
batch.landmarks = np.array([]) # Clear the old landmarks
|
||||
return retval
|
||||
|
||||
def _transform_to_frame(self, batch: "AlignerBatch") -> np.ndarray:
|
||||
def _transform_to_frame(self, batch: AlignerBatch) -> np.ndarray:
|
||||
""" Transform the predicted landmarks from the aligned face image back into frame
|
||||
co-ordinates
|
||||
|
||||
|
@ -430,14 +426,14 @@ class ReAlign():
|
|||
:class:`numpy.ndarray`
|
||||
The landmarks transformed to frame space
|
||||
"""
|
||||
faces: List[AlignedFace] = batch.data[0]["aligned_faces"]
|
||||
faces: list[AlignedFace] = batch.data[0]["aligned_faces"]
|
||||
retval = np.array([aligned.transform_points(landmarks, invert=True)
|
||||
for landmarks, aligned in zip(batch.landmarks, faces)])
|
||||
logger.trace("Transformed points: original max: %s, " # type: ignore[attr-defined]
|
||||
"new max: %s", batch.landmarks.max(), retval.max())
|
||||
return retval
|
||||
|
||||
def _re_insert_filtered(self, batch: "AlignerBatch", masks: np.ndarray) -> np.ndarray:
|
||||
def _re_insert_filtered(self, batch: AlignerBatch, masks: np.ndarray) -> np.ndarray:
|
||||
""" Re-insert landmarks that were filtered out from the re-align process back into the
|
||||
landmark results
|
||||
|
||||
|
@ -473,7 +469,7 @@ class ReAlign():
|
|||
|
||||
return retval
|
||||
|
||||
def process_output(self, subbatches: List["AlignerBatch"], batch_masks: np.ndarray) -> None:
|
||||
def process_output(self, subbatches: list[AlignerBatch], batch_masks: np.ndarray) -> None:
|
||||
""" Process the output from the re-align pass.
|
||||
|
||||
- Transform landmarks from aligned face space to face space
|
||||
|
|
|
@ -23,15 +23,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from typing import cast, List, Tuple, TYPE_CHECKING
|
||||
import typing as T
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ._base import Aligner, AlignerBatch, BatchType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if T.TYPE_CHECKING:
|
||||
from lib.align.detected_face import DetectedFace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -89,9 +90,9 @@ class Align(Aligner):
assert isinstance(batch, AlignerBatch)
lfaces, roi, offsets = self.align_image(batch)
batch.feed = np.array(lfaces)[..., :3]
batch.data.append(dict(roi=roi, offsets=offsets))
batch.data.append({"roi": roi, "offsets": offsets})

def _get_box_and_offset(self, face: "DetectedFace") -> Tuple[List[int], int]:
def _get_box_and_offset(self, face: DetectedFace) -> tuple[list[int], int]:
"""Obtain the bounding box and offset from a detected face.


@ -108,17 +109,17 @@ class Align(Aligner):
The offset of the box (difference between half width vs height)
"""

box = cast(List[int], [face.left,
face.top,
face.right,
face.bottom])
diff_height_width = cast(int, face.height) - cast(int, face.width)
box = T.cast(list[int], [face.left,
face.top,
face.right,
face.bottom])
diff_height_width = T.cast(int, face.height) - T.cast(int, face.width)
offset = int(abs(diff_height_width / 2))
return box, offset

def align_image(self, batch: AlignerBatch) -> Tuple[List[np.ndarray],
List[List[int]],
List[Tuple[int, int]]]:
def align_image(self, batch: AlignerBatch) -> tuple[list[np.ndarray],
list[list[int]],
list[tuple[int, int]]]:
""" Align the incoming image for prediction

Parameters
@ -159,8 +160,8 @@ class Align(Aligner):

@classmethod
def move_box(cls,
box: List[int],
offset: Tuple[int, int]) -> List[int]:
box: list[int],
offset: tuple[int, int]) -> list[int]:
"""Move the box to direction specified by vector offset

Parameters
@ -182,7 +183,7 @@ class Align(Aligner):
return [left, top, right, bottom]

@staticmethod
def get_square_box(box: List[int]) -> List[int]:
def get_square_box(box: list[int]) -> list[int]:
"""Get a square box out of the given box, by expanding it.

Parameters
@ -226,7 +227,7 @@ class Align(Aligner):
return [left, top, right, bottom]

@classmethod
def pad_image(cls, box: List[int], image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
def pad_image(cls, box: list[int], image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
"""Pad image if face-box falls outside of boundaries

Parameters

@ -3,8 +3,9 @@
Code adapted and modified from:
https://github.com/1adrianb/face-alignment
"""
from __future__ import annotations
import logging
from typing import cast, List, TYPE_CHECKING
import typing as T

import cv2
import numpy as np
@ -12,7 +13,7 @@ import numpy as np
from lib.model.session import KSession
from ._base import Aligner, AlignerBatch, BatchType

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align import DetectedFace

logger = logging.getLogger(__name__)
@ -76,10 +77,10 @@ class Align(Aligner):
|
|||
logger.trace("Aligning faces around center") # type:ignore[attr-defined]
|
||||
center_scale = self.get_center_scale(batch.detected_faces)
|
||||
batch.feed = np.array(self.crop(batch, center_scale))[..., :3]
|
||||
batch.data.append(dict(center_scale=center_scale))
|
||||
batch.data.append({"center_scale": center_scale})
|
||||
logger.trace("Aligned image around center") # type:ignore[attr-defined]
|
||||
|
||||
def get_center_scale(self, detected_faces: List["DetectedFace"]) -> np.ndarray:
|
||||
def get_center_scale(self, detected_faces: list[DetectedFace]) -> np.ndarray:
|
||||
""" Get the center and set scale of bounding box
|
||||
|
||||
Parameters
|
||||
|
@ -95,11 +96,11 @@ class Align(Aligner):
|
|||
logger.trace("Calculating center and scale") # type:ignore[attr-defined]
|
||||
center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32')
|
||||
for index, face in enumerate(detected_faces):
|
||||
x_center = (cast(int, face.left) + face.right) / 2.0
|
||||
y_center = (cast(int, face.top) + face.bottom) / 2.0 - cast(int, face.height) * 0.12
|
||||
scale = (cast(int, face.width) + cast(int, face.height)) * self.reference_scale
|
||||
center_scale[index, :, 0] = np.full(68, x_center, dtype='float32')
|
||||
center_scale[index, :, 1] = np.full(68, y_center, dtype='float32')
|
||||
x_ctr = (T.cast(int, face.left) + face.right) / 2.0
|
||||
y_ctr = (T.cast(int, face.top) + face.bottom) / 2.0 - T.cast(int, face.height) * 0.12
|
||||
scale = (T.cast(int, face.width) + T.cast(int, face.height)) * self.reference_scale
|
||||
center_scale[index, :, 0] = np.full(68, x_ctr, dtype='float32')
|
||||
center_scale[index, :, 1] = np.full(68, y_ctr, dtype='float32')
|
||||
center_scale[index, :, 2] = np.full(68, scale, dtype='float32')
|
||||
logger.trace("Calculated center and scale: %s", center_scale) # type:ignore[attr-defined]
|
||||
return center_scale
|
||||
|
@ -144,7 +145,7 @@ class Align(Aligner):
|
|||
dsize=(self.input_size, self.input_size),
|
||||
interpolation=interp)
|
||||
|
||||
def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> List[np.ndarray]:
|
||||
def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> list[np.ndarray]:
|
||||
""" Crop image around the center point
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -15,9 +15,11 @@ To get a :class:`~lib.align.DetectedFace` object use the function:

>>> face = self._to_detected_face(<face left>, <face top>, <face right>, <face bottom>)
"""
from __future__ import annotations
import logging
import typing as T

from dataclasses import dataclass, field
from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING, Union

import cv2
import numpy as np
@ -30,7 +32,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch
from plugins.extract.pipeline import ExtractMedia

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue

logger = logging.getLogger(__name__)
@ -53,10 +56,10 @@ class DetectorBatch(ExtractorBatch):
|
|||
initial_feed: :class:`numpy.ndarray`
|
||||
Used to hold the initial :attr:`feed` when rotate images is enabled
|
||||
"""
|
||||
detected_faces: List[List["DetectedFace"]] = field(default_factory=list)
|
||||
rotation_matrix: List[np.ndarray] = field(default_factory=list)
|
||||
scale: List[float] = field(default_factory=list)
|
||||
pad: List[Tuple[int, int]] = field(default_factory=list)
|
||||
detected_faces: list[list["DetectedFace"]] = field(default_factory=list)
|
||||
rotation_matrix: list[np.ndarray] = field(default_factory=list)
|
||||
scale: list[float] = field(default_factory=list)
|
||||
pad: list[tuple[int, int]] = field(default_factory=list)
|
||||
initial_feed: np.ndarray = np.array([])
|
||||
|
||||
|
||||
|
@ -95,11 +98,11 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
git_model_id: Optional[int] = None,
|
||||
model_filename: Optional[Union[str, List[str]]] = None,
|
||||
configfile: Optional[str] = None,
|
||||
git_model_id: int | None = None,
|
||||
model_filename: str | list[str] | None = None,
|
||||
configfile: str | None = None,
|
||||
instance: int = 0,
|
||||
rotation: Optional[str] = None,
|
||||
rotation: str | None = None,
|
||||
min_size: int = 0,
|
||||
**kwargs) -> None:
|
||||
logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__,
|
||||
|
@ -117,7 +120,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
logger.debug("Initialized _base %s", self.__class__.__name__)
|
||||
|
||||
# <<< QUEUE METHODS >>> #
|
||||
def get_batch(self, queue: "Queue") -> Tuple[bool, DetectorBatch]:
|
||||
def get_batch(self, queue: Queue) -> tuple[bool, DetectorBatch]:
|
||||
""" Get items for inputting to the detector plugin in batches
|
||||
|
||||
Items are received as :class:`~plugins.extract.pipeline.ExtractMedia` objects and converted
|
||||
|
@ -271,7 +274,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
""" Wrap models predict function in rotations """
|
||||
assert isinstance(batch, DetectorBatch)
|
||||
batch.rotation_matrix = [np.array([]) for _ in range(len(batch.feed))]
|
||||
found_faces: List[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
|
||||
found_faces: list[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))]
|
||||
for angle in self.rotation:
|
||||
# Rotate the batch and insert placeholders for already found faces
|
||||
self._rotate_batch(batch, angle)
|
||||
|
@ -301,9 +304,9 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
"degrees",
|
||||
angle)
|
||||
|
||||
found_faces = cast(List[np.ndarray], ([face if not found.any() else found
|
||||
for face, found in zip(batch.prediction,
|
||||
found_faces)]))
|
||||
found_faces = T.cast(list[np.ndarray], ([face if not found.any() else found
|
||||
for face, found in zip(batch.prediction,
|
||||
found_faces)]))
|
||||
|
||||
if all(face.any() for face in found_faces):
|
||||
logger.trace("Faces found for all images") # type:ignore[attr-defined]
|
||||
|
@ -317,7 +320,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
# <<< DETECTION IMAGE COMPILATION METHODS >>> #
|
||||
def _compile_detection_image(self, item: ExtractMedia
|
||||
) -> Tuple[np.ndarray, float, Tuple[int, int]]:
|
||||
) -> tuple[np.ndarray, float, tuple[int, int]]:
|
||||
""" Compile the detection image for feeding into the model
|
||||
|
||||
Parameters
|
||||
|
@ -345,7 +348,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
image.shape, scale, pad)
|
||||
return image, scale, pad
|
||||
|
||||
def _set_scale(self, image_size: Tuple[int, int]) -> float:
|
||||
def _set_scale(self, image_size: tuple[int, int]) -> float:
|
||||
""" Set the scale factor for incoming image
|
||||
|
||||
Parameters
|
||||
|
@ -362,7 +365,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
logger.trace("Detector scale: %s", scale) # type:ignore[attr-defined]
|
||||
return scale
|
||||
|
||||
def _set_padding(self, image_size: Tuple[int, int], scale: float) -> Tuple[int, int]:
|
||||
def _set_padding(self, image_size: tuple[int, int], scale: float) -> tuple[int, int]:
|
||||
""" Set the image padding for non-square images
|
||||
|
||||
Parameters
|
||||
|
@ -382,7 +385,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
return pad_left, pad_top
|
||||
|
||||
@staticmethod
|
||||
def _scale_image(image: np.ndarray, image_size: Tuple[int, int], scale: float) -> np.ndarray:
|
||||
def _scale_image(image: np.ndarray, image_size: tuple[int, int], scale: float) -> np.ndarray:
|
||||
""" Scale the image and optional pad to given size
|
||||
|
||||
Parameters
|
||||
|
@ -439,8 +442,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
return image
|
||||
|
||||
# <<< FINALIZE METHODS >>> #
|
||||
def _remove_zero_sized_faces(self, batch_faces: List[List[DetectedFace]]
|
||||
) -> List[List[DetectedFace]]:
|
||||
def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]]
|
||||
) -> list[list[DetectedFace]]:
|
||||
""" Remove items from batch_faces where detected face is of zero size or face falls
|
||||
entirely outside of image
|
||||
|
||||
|
@ -463,8 +466,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
logger.trace("Output sizes: %s", [len(face) for face in retval]) # type: ignore
|
||||
return retval
|
||||
|
||||
def _filter_small_faces(self, detected_faces: List[List[DetectedFace]]
|
||||
) -> List[List[DetectedFace]]:
|
||||
def _filter_small_faces(self, detected_faces: list[list[DetectedFace]]
|
||||
) -> list[list[DetectedFace]]:
|
||||
""" Filter out any faces smaller than the min size threshold
|
||||
|
||||
Parameters
|
||||
|
@ -493,7 +496,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
# <<< IMAGE ROTATION METHODS >>> #
|
||||
@staticmethod
|
||||
def _get_rotation_angles(rotation: Optional[str]) -> List[int]:
|
||||
def _get_rotation_angles(rotation: str | None) -> list[int]:
|
||||
""" Set the rotation angles.
|
||||
|
||||
Parameters
|
||||
|
@ -544,8 +547,8 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
batch.initial_feed = batch.feed.copy()
|
||||
return
|
||||
|
||||
feeds: List[np.ndarray] = []
|
||||
rotmats: List[np.ndarray] = []
|
||||
feeds: list[np.ndarray] = []
|
||||
rotmats: list[np.ndarray] = []
|
||||
for img, faces, rotmat in zip(batch.initial_feed,
|
||||
batch.prediction,
|
||||
batch.rotation_matrix):
|
||||
|
@ -605,7 +608,7 @@ class Detector(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
def _rotate_image_by_angle(self,
|
||||
image: np.ndarray,
|
||||
angle: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
angle: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Rotate an image by a given angle.
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -34,7 +34,7 @@ class Detect(Detector):
|
|||
self.kwargs = self._validate_kwargs()
|
||||
self.color_format = "RGB"
|
||||
|
||||
def _validate_kwargs(self) -> T.Dict[str, T.Union[int, float, T.List[float]]]:
|
||||
def _validate_kwargs(self) -> dict[str, int | float | list[float]]:
|
||||
""" Validate that config options are correct. If not reset to default """
|
||||
valid = True
|
||||
threshold = [self.config["threshold_1"],
|
||||
|
@ -164,7 +164,7 @@ class PNet(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
cpu_mode: bool,
|
||||
input_size: int,
|
||||
min_size: int,
|
||||
|
@ -185,10 +185,10 @@ class PNet(KSession):
|
|||
self._pnet_scales = self._calculate_scales(min_size, factor)
|
||||
self._pnet_sizes = [(int(input_size * scale), int(input_size * scale))
|
||||
for scale in self._pnet_scales]
|
||||
self._pnet_input: T.Optional[T.List[np.ndarray]] = None
|
||||
self._pnet_input: list[np.ndarray] | None = None
|
||||
|
||||
@staticmethod
|
||||
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
|
||||
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
|
||||
""" Keras P-Network Definition for MTCNN """
|
||||
input_ = Input(shape=(None, None, 3))
|
||||
var_x = Conv2D(10, (3, 3), strides=1, padding='valid', name='conv1')(input_)
|
||||
|
@ -204,7 +204,7 @@ class PNet(KSession):
|
|||
|
||||
def _calculate_scales(self,
|
||||
minsize: int,
|
||||
factor: float) -> T.List[float]:
|
||||
factor: float) -> list[float]:
|
||||
""" Calculate multi-scale
|
||||
|
||||
Parameters
|
||||
|
@ -231,7 +231,7 @@ class PNet(KSession):
|
|||
logger.trace(scales) # type:ignore
|
||||
return scales
|
||||
|
||||
def __call__(self, images: np.ndarray) -> T.List[np.ndarray]:
|
||||
def __call__(self, images: np.ndarray) -> list[np.ndarray]:
|
||||
""" first stage - fast proposal network (p-net) to obtain face candidates
|
||||
|
||||
Parameters
|
||||
|
@ -245,8 +245,8 @@ class PNet(KSession):
|
|||
List of face candidates from P-Net
|
||||
"""
|
||||
batch_size = images.shape[0]
|
||||
rectangles: T.List[T.List[T.List[T.Union[int, float]]]] = [[] for _ in range(batch_size)]
|
||||
scores: T.List[T.List[np.ndarray]] = [[] for _ in range(batch_size)]
|
||||
rectangles: list[list[list[int | float]]] = [[] for _ in range(batch_size)]
|
||||
scores: list[list[np.ndarray]] = [[] for _ in range(batch_size)]
|
||||
|
||||
if self._pnet_input is None:
|
||||
self._pnet_input = [np.empty((batch_size, rheight, rwidth, 3), dtype="float32")
|
||||
|
@ -278,7 +278,7 @@ class PNet(KSession):
|
|||
class_probabilities: np.ndarray,
|
||||
roi: np.ndarray,
|
||||
size: int,
|
||||
scale: float) -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
scale: float) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Detect face position and calibrate bounding box on 12net feature map(matrix version)
|
||||
|
||||
Parameters
|
||||
|
@ -344,7 +344,7 @@ class RNet(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
cpu_mode: bool,
|
||||
input_size: int,
|
||||
threshold: float) -> None:
|
||||
|
@ -360,7 +360,7 @@ class RNet(KSession):
|
|||
self._threshold = threshold
|
||||
|
||||
@staticmethod
|
||||
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
|
||||
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
|
||||
""" Keras R-Network Definition for MTCNN """
|
||||
input_ = Input(shape=(24, 24, 3))
|
||||
var_x = Conv2D(28, (3, 3), strides=1, padding='valid', name='conv1')(input_)
|
||||
|
@ -383,8 +383,8 @@ class RNet(KSession):
|
|||
|
||||
def __call__(self,
|
||||
images: np.ndarray,
|
||||
rectangle_batch: T.List[np.ndarray],
|
||||
) -> T.List[np.ndarray]:
|
||||
rectangle_batch: list[np.ndarray],
|
||||
) -> list[np.ndarray]:
|
||||
""" second stage - refinement of face candidates with r-net
|
||||
|
||||
Parameters
|
||||
|
@ -399,7 +399,7 @@ class RNet(KSession):
|
|||
List
|
||||
List of :class:`numpy.ndarray` refined face candidates from R-Net
|
||||
"""
|
||||
ret: T.List[np.ndarray] = []
|
||||
ret: list[np.ndarray] = []
|
||||
for idx, (rectangles, image) in enumerate(zip(rectangle_batch, images)):
|
||||
if not np.any(rectangles):
|
||||
ret.append(np.array([]))
|
||||
|
@ -474,7 +474,7 @@ class ONet(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
cpu_mode: bool,
|
||||
input_size: int,
|
||||
threshold: float) -> None:
|
||||
|
@ -490,7 +490,7 @@ class ONet(KSession):
|
|||
self._threshold = threshold
|
||||
|
||||
@staticmethod
|
||||
def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
|
||||
def model_definition() -> tuple[list[Tensor], list[Tensor]]:
|
||||
""" Keras O-Network for MTCNN """
|
||||
input_ = Input(shape=(48, 48, 3))
|
||||
var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv1')(input_)
|
||||
|
@ -516,8 +516,8 @@ class ONet(KSession):
|
|||
|
||||
def __call__(self,
|
||||
images: np.ndarray,
|
||||
rectangle_batch: T.List[np.ndarray]
|
||||
) -> T.List[T.Tuple[np.ndarray, np.ndarray]]:
|
||||
rectangle_batch: list[np.ndarray]
|
||||
) -> list[tuple[np.ndarray, np.ndarray]]:
|
||||
""" Third stage - further refinement and facial landmarks positions with o-net
|
||||
|
||||
Parameters
|
||||
|
@ -532,7 +532,7 @@ class ONet(KSession):
|
|||
List
|
||||
List of refined final candidates, scores and landmark points from O-Net
|
||||
"""
|
||||
ret: T.List[T.Tuple[np.ndarray, np.ndarray]] = []
|
||||
ret: list[tuple[np.ndarray, np.ndarray]] = []
|
||||
for idx, rectangles in enumerate(rectangle_batch):
|
||||
if not np.any(rectangles):
|
||||
ret.append((np.empty((0, 5)), np.empty(0)))
|
||||
|
@ -552,7 +552,7 @@ class ONet(KSession):
|
|||
def _filter_face_48net(self, class_probabilities: np.ndarray,
|
||||
roi: np.ndarray,
|
||||
points: np.ndarray,
|
||||
rectangles: np.ndarray) -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
rectangles: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
""" Filter face position and calibrate bounding box on 12net's output
|
||||
|
||||
Parameters
|
||||
|
@ -623,13 +623,13 @@ class MTCNN(): # pylint: disable=too-few-public-methods
|
|||
Default: `0.709`
|
||||
"""
|
||||
def __init__(self,
|
||||
model_path: T.List[str],
|
||||
model_path: list[str],
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
cpu_mode: bool,
|
||||
input_size: int = 640,
|
||||
minsize: int = 20,
|
||||
threshold: T.Optional[T.List[float]] = None,
|
||||
threshold: list[float] | None = None,
|
||||
factor: float = 0.709) -> None:
|
||||
logger.debug("Initializing: %s: (model_path: '%s', allow_growth: %s, exclude_gpus: %s, "
|
||||
"input_size: %s, minsize: %s, threshold: %s, factor: %s)",
|
||||
|
@ -660,7 +660,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods
|
|||
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
def detect_faces(self, batch: np.ndarray) -> T.Tuple[np.ndarray, T.Tuple[np.ndarray]]:
|
||||
def detect_faces(self, batch: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray]]:
|
||||
"""Detects faces in an image, and returns bounding boxes and points for them.
|
||||
|
||||
Parameters
|
||||
|
@ -684,7 +684,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods
|
|||
def nms(rectangles: np.ndarray,
|
||||
scores: np.ndarray,
|
||||
threshold: float,
|
||||
method: str = "iom") -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
method: str = "iom") -> tuple[np.ndarray, np.ndarray]:
|
||||
""" apply non-maximum suppression on ROIs in same scale(matrix version)
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -125,10 +125,10 @@ class L2Norm(keras.layers.Layer):
|
|||
class SliceO2K(keras.layers.Layer):
|
||||
""" Custom Keras Slice layer generated by onnx2keras. """
|
||||
def __init__(self,
|
||||
starts: T.List[int],
|
||||
ends: T.List[int],
|
||||
axes: T.Optional[T.List[int]] = None,
|
||||
steps: T.Optional[T.List[int]] = None,
|
||||
starts: list[int],
|
||||
ends: list[int],
|
||||
axes: list[int] | None = None,
|
||||
steps: list[int] | None = None,
|
||||
**kwargs) -> None:
|
||||
self._starts = starts
|
||||
self._ends = ends
|
||||
|
@ -136,7 +136,7 @@ class SliceO2K(keras.layers.Layer):
|
|||
self._steps = steps
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _get_slices(self, dimensions: int) -> T.List[T.Tuple[int, ...]]:
|
||||
def _get_slices(self, dimensions: int) -> list[tuple[int, ...]]:
|
||||
""" Obtain slices for the given number of dimensions.
|
||||
|
||||
Parameters
|
||||
|
@ -154,7 +154,7 @@ class SliceO2K(keras.layers.Layer):
|
|||
assert len(axes) == len(steps) == len(self._starts) == len(self._ends)
|
||||
return list(zip(axes, self._starts, self._ends, steps))
|
||||
|
||||
def compute_output_shape(self, input_shape: T.Tuple[int, ...]) -> T.Tuple[int, ...]:
|
||||
def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Computes the output shape of the layer.
|
||||
|
||||
Assumes that the layer will be built to match that input shape provided.
|
||||
|
@ -230,7 +230,7 @@ class S3fd(KSession):
|
|||
model_path: str,
|
||||
model_kwargs: dict,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
confidence: float) -> None:
|
||||
logger.debug("Initializing: %s: (model_path: '%s', model_kwargs: %s, allow_growth: %s, "
|
||||
"exclude_gpus: %s, confidence: %s)", self.__class__.__name__, model_path,
|
||||
|
@ -246,7 +246,7 @@ class S3fd(KSession):
|
|||
self.average_img = np.array([104.0, 117.0, 123.0])
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
def model_definition(self) -> T.Tuple[T.List[Tensor], T.List[Tensor]]:
|
||||
def model_definition(self) -> tuple[list[Tensor], list[Tensor]]:
|
||||
""" Keras S3FD Model Definition, adapted from FAN pytorch implementation. """
|
||||
input_ = Input(shape=(640, 640, 3))
|
||||
var_x = self.conv_block(input_, 64, 1, 2)
|
||||
|
@ -396,7 +396,7 @@ class S3fd(KSession):
|
|||
batch = batch - self.average_img
|
||||
return batch
|
||||
|
||||
def finalize_predictions(self, bounding_boxes_scales: T.List[np.ndarray]) -> np.ndarray:
|
||||
def finalize_predictions(self, bounding_boxes_scales: list[np.ndarray]) -> np.ndarray:
|
||||
""" Process the output from the model to obtain faces
|
||||
|
||||
Parameters
|
||||
|
@ -413,7 +413,7 @@ class S3fd(KSession):
|
|||
ret.append(finallist)
|
||||
return np.array(ret, dtype="object")
|
||||
|
||||
def _post_process(self, bboxlist: T.List[np.ndarray]) -> np.ndarray:
|
||||
def _post_process(self, bboxlist: list[np.ndarray]) -> np.ndarray:
|
||||
""" Perform post processing on output
|
||||
TODO: do this on the batch.
|
||||
"""
|
||||
|
|
|
@ -12,9 +12,11 @@ For each source item, the plugin must pass a dict to finalize containing:
>>> {"filename": <filename of source frame>,
>>> "detected_faces": <list of bounding box dicts from lib/plugins/extract/detect/_base>}
"""
from __future__ import annotations
import logging
import typing as T

from dataclasses import dataclass, field
from typing import Generator, List, Optional, Tuple, TYPE_CHECKING

import cv2
import numpy as np
@ -25,7 +27,8 @@ from lib.align import AlignedFace, transform_image
from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch, ExtractMedia

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align import DetectedFace
from lib.align.aligned_face import CenteringType
@ -44,9 +47,9 @@ class MaskerBatch(ExtractorBatch):
|
|||
roi_masks: list
|
||||
The region of interest masks for the batch
|
||||
"""
|
||||
detected_faces: List["DetectedFace"] = field(default_factory=list)
|
||||
roi_masks: List[np.ndarray] = field(default_factory=list)
|
||||
feed_faces: List[AlignedFace] = field(default_factory=list)
|
||||
detected_faces: list[DetectedFace] = field(default_factory=list)
|
||||
roi_masks: list[np.ndarray] = field(default_factory=list)
|
||||
feed_faces: list[AlignedFace] = field(default_factory=list)
|
||||
|
||||
|
||||
class Masker(Extractor): # pylint:disable=abstract-method
|
||||
|
@ -77,9 +80,9 @@ class Masker(Extractor): # pylint:disable=abstract-method
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
git_model_id: Optional[int] = None,
|
||||
model_filename: Optional[str] = None,
|
||||
configfile: Optional[str] = None,
|
||||
git_model_id: int | None = None,
|
||||
model_filename: str | None = None,
|
||||
configfile: str | None = None,
|
||||
instance: int = 0,
|
||||
**kwargs) -> None:
|
||||
logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile)
|
||||
|
@ -93,11 +96,11 @@ class Masker(Extractor): # pylint:disable=abstract-method
|
|||
|
||||
self._plugin_type = "mask"
|
||||
self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-")
|
||||
self._storage_centering: "CenteringType" = "face" # Centering to store the mask at
|
||||
self._storage_centering: CenteringType = "face" # Centering to store the mask at
|
||||
self._storage_size = 128 # Size to store masks at. Leave this at default
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def get_batch(self, queue: "Queue") -> Tuple[bool, MaskerBatch]:
|
||||
def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]:
|
||||
""" Get items for inputting into the masker from the queue in batches
|
||||
|
||||
Items are returned from the ``queue`` in batches of
|
||||
|
|
|
@ -49,7 +49,7 @@ class Mask(Masker):
|
|||
# Separate storage for face and head masks
|
||||
self._storage_name = f"{self._storage_name}_{self._storage_centering}"
|
||||
|
||||
def _check_weights_selection(self, configfile: T.Optional[str]) -> T.Tuple[bool, int]:
|
||||
def _check_weights_selection(self, configfile: str | None) -> tuple[bool, int]:
|
||||
""" Check which weights have been selected.
|
||||
|
||||
This is required for passing along the correct file name for the corresponding weights
|
||||
|
@ -73,7 +73,7 @@ class Mask(Masker):
|
|||
version = 1 if not is_faceswap else 2 if config.get("include_hair") else 3
|
||||
return is_faceswap, version
|
||||
|
||||
def _get_segment_indices(self) -> T.List[int]:
|
||||
def _get_segment_indices(self) -> list[int]:
|
||||
""" Obtain the segment indices to include within the face mask area based on user
|
||||
configuration settings.
|
||||
|
||||
|
@ -163,7 +163,7 @@ class Mask(Masker):
|
|||
# SOFTWARE.
|
||||
|
||||
|
||||
_NAME_TRACKER: T.Set[str] = set()
|
||||
_NAME_TRACKER: set[str] = set()
|
||||
|
||||
|
||||
def _get_name(name: str, start_idx: int = 1) -> str:
|
||||
|
@ -554,7 +554,7 @@ class BiSeNet(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]],
|
||||
exclude_gpus: list[int] | None,
|
||||
input_size: int,
|
||||
num_classes: int,
|
||||
cpu_mode: bool) -> None:
|
||||
|
@ -569,7 +569,7 @@ class BiSeNet(KSession):
|
|||
self.define_model(self._model_definition)
|
||||
self.load_model_weights()
|
||||
|
||||
def _model_definition(self) -> T.Tuple[Tensor, T.List[Tensor]]:
|
||||
def _model_definition(self) -> tuple[Tensor, list[Tensor]]:
|
||||
""" Definition of the VGG Obstructed Model.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -1,14 +1,15 @@
#!/usr/bin/env python3
""" Components Mask for faceswap.py """
from __future__ import annotations
import logging
from typing import List, Tuple, TYPE_CHECKING
import typing as T

import cv2
import numpy as np

from ._base import BatchType, Masker

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.aligned_face import AlignedFace

logger = logging.getLogger(__name__)
@ -36,7 +37,7 @@ class Mask(Masker):
|
|||
|
||||
def predict(self, feed: np.ndarray) -> np.ndarray:
|
||||
""" Run model to get predictions """
|
||||
faces: List["AlignedFace"] = feed[1]
|
||||
faces: list[AlignedFace] = feed[1]
|
||||
feed = feed[0]
|
||||
for mask, face in zip(feed, faces):
|
||||
parts = self.parse_parts(np.array(face.landmarks))
|
||||
|
@ -51,7 +52,7 @@ class Mask(Masker):
|
|||
return
|
||||
|
||||
@staticmethod
|
||||
def parse_parts(landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
|
||||
def parse_parts(landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
|
||||
""" Component face hull mask """
|
||||
r_jaw = (landmarks[0:9], landmarks[17:18])
|
||||
l_jaw = (landmarks[8:17], landmarks[26:27])
|
||||
|
|
|
@ -1,7 +1,8 @@
#!/usr/bin/env python3
""" Extended Mask for faceswap.py """
from __future__ import annotations
import logging
from typing import List, Tuple, TYPE_CHECKING
import typing as T

import cv2
import numpy as np
@ -9,7 +10,7 @@ from ._base import BatchType, Masker

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.aligned_face import AlignedFace

@ -35,7 +36,7 @@ class Mask(Masker):
|
|||
|
||||
def predict(self, feed: np.ndarray) -> np.ndarray:
|
||||
""" Run model to get predictions """
|
||||
faces: List["AlignedFace"] = feed[1]
|
||||
faces: list[AlignedFace] = feed[1]
|
||||
feed = feed[0]
|
||||
for mask, face in zip(feed, faces):
|
||||
parts = self.parse_parts(np.array(face.landmarks))
|
||||
|
@ -78,7 +79,7 @@ class Mask(Masker):
|
|||
landmarks[17:22] = top_l + ((top_l - bot_l) // 2)
|
||||
landmarks[22:27] = top_r + ((top_r - bot_r) // 2)
|
||||
|
||||
def parse_parts(self, landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
|
||||
def parse_parts(self, landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]:
|
||||
""" Extended face hull mask """
|
||||
self._adjust_mask_top(landmarks)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ Model file sourced from...
https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5
"""
import logging
from typing import cast
import typing as T

import numpy as np
from lib.model.session import KSession
@ -52,7 +52,7 @@ class Mask(Masker):
def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """
assert isinstance(batch, MaskerBatch)
batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3]
batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
for feed in batch.feed_faces], dtype="float32") / 255.0
logger.trace("feed shape: %s", batch.feed.shape) # type: ignore

@ -94,7 +94,7 @@ class VGGClear(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]]):
|
||||
exclude_gpus: list[int] | None):
|
||||
super().__init__("VGG Obstructed",
|
||||
model_path,
|
||||
allow_growth=allow_growth,
|
||||
|
@ -103,7 +103,7 @@ class VGGClear(KSession):
|
|||
self.load_model_weights()
|
||||
|
||||
@classmethod
|
||||
def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
|
||||
def _model_definition(cls) -> tuple[Tensor, Tensor]:
|
||||
""" Definition of the VGG Obstructed Model.
|
||||
|
||||
Returns
|
||||
|
@ -210,7 +210,7 @@ class _ScorePool(): # pylint:disable=too-few-public-methods
|
|||
crop: tuple
|
||||
The amount of 2D cropping to apply. Tuple of `ints`
|
||||
"""
|
||||
def __init__(self, level: int, scale: float, crop: T.Tuple[int, int]):
|
||||
def __init__(self, level: int, scale: float, crop: tuple[int, int]):
|
||||
self._name = f"_pool{level}"
|
||||
self._cropping = (crop, crop)
|
||||
self._scale = scale
|
||||
|
|
|
@ -90,7 +90,7 @@ class VGGObstructed(KSession):
|
|||
def __init__(self,
|
||||
model_path: str,
|
||||
allow_growth: bool,
|
||||
exclude_gpus: T.Optional[T.List[int]]) -> None:
|
||||
exclude_gpus: list[int] | None) -> None:
|
||||
super().__init__("VGG Obstructed",
|
||||
model_path,
|
||||
allow_growth=allow_growth,
|
||||
|
@ -99,7 +99,7 @@ class VGGObstructed(KSession):
|
|||
self.load_model_weights()
|
||||
|
||||
@classmethod
|
||||
def _model_definition(cls) -> T.Tuple[Tensor, Tensor]:
|
||||
def _model_definition(cls) -> tuple[Tensor, Tensor]:
|
||||
""" Definition of the VGG Obstructed Model.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -8,10 +8,9 @@ together.
This module sets up a pipeline for the extraction workflow, loading detect, align and mask
plugins either in parallel or in series, giving easy access to input and output.
"""

from __future__ import annotations
import logging
import sys
from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
import typing as T

import cv2

@ -20,13 +19,9 @@ from lib.queue_manager import EventQueue, queue_manager, QueueEmpty
from lib.utils import get_backend
from plugins.plugin_loader import PluginLoader

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

if TYPE_CHECKING:
if T.TYPE_CHECKING:
import numpy as np
from collections.abc import Generator
from lib.align.alignments import PNGHeaderSourceDict
from lib.align.detected_face import DetectedFace
from plugins.extract._base import Extractor as PluginExtractor
@ -102,16 +97,16 @@ class Extractor():
|
|||
:attr:`final_pass` to indicate to the caller which phase is being processed
|
||||
"""
|
||||
def __init__(self,
|
||||
detector: Optional[str],
|
||||
aligner: Optional[str],
|
||||
masker: Optional[Union[str, List[str]]],
|
||||
recognition: Optional[str] = None,
|
||||
configfile: Optional[str] = None,
|
||||
detector: str | None,
|
||||
aligner: str | None,
|
||||
masker: str | list[str] | None,
|
||||
recognition: str | None = None,
|
||||
configfile: str | None = None,
|
||||
multiprocess: bool = False,
|
||||
exclude_gpus: Optional[List[int]] = None,
|
||||
rotate_images: Optional[str] = None,
|
||||
exclude_gpus: list[int] | None = None,
|
||||
rotate_images: str | None = None,
|
||||
min_size: int = 0,
|
||||
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None,
|
||||
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None,
|
||||
re_feed: int = 0,
|
||||
re_align: bool = False,
|
||||
disable_filter: bool = False) -> None:
|
||||
|
@ -122,8 +117,9 @@ class Extractor():
|
|||
recognition, configfile, multiprocess, exclude_gpus, rotate_images, min_size,
|
||||
normalize_method, re_feed, re_align, disable_filter)
|
||||
self._instance = _get_instance()
|
||||
maskers = [cast(Optional[str],
|
||||
masker)] if not isinstance(masker, list) else cast(List[Optional[str]], masker)
|
||||
maskers = [T.cast(str | None,
|
||||
masker)] if not isinstance(masker, list) else T.cast(list[str | None],
|
||||
masker)
|
||||
self._flow = self._set_flow(detector, aligner, maskers, recognition)
|
||||
self._exclude_gpus = exclude_gpus
|
||||
# We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting
|
||||
|
@ -220,13 +216,13 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@property
|
||||
def aligner(self) -> "Aligner":
|
||||
def aligner(self) -> Aligner:
|
||||
""" The currently selected aligner plugin """
|
||||
assert self._align is not None
|
||||
return self._align
|
||||
|
||||
@property
|
||||
def recognition(self) -> "Identity":
|
||||
def recognition(self) -> Identity:
|
||||
""" The currently selected recognition plugin """
|
||||
assert self._recognition is not None
|
||||
return self._recognition
|
||||
|
@ -237,7 +233,7 @@ class Extractor():
|
|||
self._phase_index = 0
|
||||
|
||||
def set_batchsize(self,
|
||||
plugin_type: Literal["align", "detect"],
|
||||
plugin_type: T.Literal["align", "detect"],
|
||||
batchsize: int) -> None:
|
||||
""" Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`.
|
||||
|
||||
|
@ -311,7 +307,7 @@ class Extractor():
|
|||
|
||||
# <<< INTERNAL METHODS >>> #
|
||||
@property
|
||||
def _parallel_scaling(self) -> Dict[int, float]:
|
||||
def _parallel_scaling(self) -> dict[int, float]:
|
||||
""" dict: key is number of parallel plugins being loaded, value is the scaling factor that
|
||||
the total base vram for those plugins should be scaled by
|
||||
|
||||
|
@ -335,7 +331,7 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@property
|
||||
def _vram_per_phase(self) -> Dict[str, float]:
|
||||
def _vram_per_phase(self) -> dict[str, float]:
|
||||
""" dict: The amount of vram required for each phase in :attr:`_flow`. """
|
||||
retval = {}
|
||||
for phase in self._flow:
|
||||
|
@ -359,7 +355,7 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@property
|
||||
def _current_phase(self) -> List[str]:
|
||||
def _current_phase(self) -> list[str]:
|
||||
""" list: The current phase from :attr:`_phases` that is running through the extractor. """
|
||||
retval = self._phases[self._phase_index]
|
||||
logger.trace(retval) # type: ignore
|
||||
|
@ -384,7 +380,7 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@property
|
||||
def _all_plugins(self) -> List["PluginExtractor"]:
|
||||
def _all_plugins(self) -> list[PluginExtractor]:
|
||||
""" Return list of all plugin objects in this pipeline """
|
||||
retval = []
|
||||
for phase in self._flow:
|
||||
|
@ -396,7 +392,7 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@property
|
||||
def _active_plugins(self) -> List["PluginExtractor"]:
|
||||
def _active_plugins(self) -> list[PluginExtractor]:
|
||||
""" Return the plugins that are currently active based on pass """
|
||||
retval = []
|
||||
for phase in self._current_phase:
|
||||
|
@ -407,10 +403,10 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@staticmethod
|
||||
def _set_flow(detector: Optional[str],
|
||||
aligner: Optional[str],
|
||||
masker: List[Optional[str]],
|
||||
recognition: Optional[str]) -> List[str]:
|
||||
def _set_flow(detector: str | None,
|
||||
aligner: str | None,
|
||||
masker: list[str | None],
|
||||
recognition: str | None) -> list[str]:
|
||||
""" Set the flow list based on the input plugins
|
||||
|
||||
Parameters
|
||||
|
@ -441,7 +437,7 @@ class Extractor():
|
|||
return retval
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_type_and_index(flow_phase: str) -> Tuple[str, Optional[int]]:
|
||||
def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]:
|
||||
""" Obtain the plugin type and index for the plugin for the given flow phase.
|
||||
|
||||
When multiple plugins for the same phase are allowed (e.g. Mask) this will return
|
||||
|
@ -463,14 +459,14 @@ class Extractor():
|
|||
"""
|
||||
sidx = flow_phase.split("_")[-1]
|
||||
if sidx.isdigit():
|
||||
idx: Optional[int] = int(sidx)
|
||||
idx: int | None = int(sidx)
|
||||
plugin_type = "_".join(flow_phase.split("_")[:-1])
|
||||
else:
|
||||
plugin_type = flow_phase
|
||||
idx = None
|
||||
return plugin_type, idx
|
||||
|
||||
def _add_queues(self) -> Dict[str, EventQueue]:
|
||||
def _add_queues(self) -> dict[str, EventQueue]:
|
||||
""" Add the required processing queues to Queue Manager """
|
||||
queues = {}
|
||||
tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow]
|
||||
|
@ -483,7 +479,7 @@ class Extractor():
|
|||
return queues
|
||||
|
||||
@staticmethod
|
||||
def _get_vram_stats() -> Dict[str, Union[int, str]]:
|
||||
def _get_vram_stats() -> dict[str, int | str]:
|
||||
""" Obtain statistics on available VRAM and subtract a constant buffer from available vram.
|
||||
|
||||
Returns
|
||||
|
@ -494,10 +490,10 @@ class Extractor():
|
|||
vram_buffer = 256 # Leave a buffer for VRAM allocation
|
||||
gpu_stats = GPUStats()
|
||||
stats = gpu_stats.get_card_most_free()
|
||||
retval: Dict[str, Union[int, str]] = {"count": gpu_stats.device_count,
|
||||
"device": stats.device,
|
||||
"vram_free": int(stats.free - vram_buffer),
|
||||
"vram_total": int(stats.total)}
|
||||
retval: dict[str, int | str] = {"count": gpu_stats.device_count,
|
||||
"device": stats.device,
|
||||
"vram_free": int(stats.free - vram_buffer),
|
||||
"vram_total": int(stats.total)}
|
||||
logger.debug(retval)
|
||||
return retval
|
||||
|
||||
|
@ -521,13 +517,13 @@ class Extractor():
|
|||
self._vram_stats["device"],
|
||||
self._vram_stats["vram_free"],
|
||||
self._vram_stats["vram_total"])
|
||||
if cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required:
|
||||
if T.cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required:
|
||||
logger.warning("Not enough free VRAM for parallel processing. "
|
||||
"Switching to serial")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _set_phases(self, multiprocess: bool) -> List[List[str]]:
|
||||
def _set_phases(self, multiprocess: bool) -> list[list[str]]:
|
||||
""" If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit
|
||||
into VRAM, otherwise return the single flow.
|
||||
|
||||
|
@ -541,9 +537,9 @@ class Extractor():
|
|||
list:
|
||||
The jobs to be undertaken split into phases that fit into GPU RAM
|
||||
"""
|
||||
phases: List[List[str]] = []
|
||||
current_phase: List[str] = []
|
||||
available = cast(int, self._vram_stats["vram_free"])
|
||||
phases: list[list[str]] = []
|
||||
current_phase: list[str] = []
|
||||
available = T.cast(int, self._vram_stats["vram_free"])
|
||||
for phase in self._flow:
|
||||
num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0])
|
||||
num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0
|
||||
|
@ -576,12 +572,12 @@ class Extractor():
|
|||
|
||||
# << INTERNAL PLUGIN HANDLING >> #
|
||||
def _load_align(self,
|
||||
aligner: Optional[str],
|
||||
configfile: Optional[str],
|
||||
normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]],
|
||||
aligner: str | None,
|
||||
configfile: str | None,
|
||||
normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None,
|
||||
re_feed: int,
|
||||
re_align: bool,
|
||||
disable_filter: bool) -> Optional["Aligner"]:
|
||||
disable_filter: bool) -> Aligner | None:
|
||||
""" Set global arguments and load aligner plugin
|
||||
|
||||
Parameters
|
||||
|
@ -619,10 +615,10 @@ class Extractor():
|
|||
return plugin
|
||||
|
||||
def _load_detect(self,
|
||||
detector: Optional[str],
|
||||
rotation: Optional[str],
|
||||
detector: str | None,
|
||||
rotation: str | None,
|
||||
min_size: int,
|
||||
configfile: Optional[str]) -> Optional["Detector"]:
|
||||
configfile: str | None) -> Detector | None:
|
||||
""" Set global arguments and load detector plugin """
|
||||
if detector is None or detector.lower() == "none":
|
||||
logger.debug("No detector selected. Returning None")
|
||||
|
@ -637,8 +633,8 @@ class Extractor():
|
|||
return plugin
|
||||
|
||||
def _load_mask(self,
|
||||
masker: Optional[str],
|
||||
configfile: Optional[str]) -> Optional["Masker"]:
|
||||
masker: str | None,
|
||||
configfile: str | None) -> Masker | None:
|
||||
""" Set global arguments and load masker plugin
|
||||
|
||||
Parameters
|
||||
|
@ -664,8 +660,8 @@ class Extractor():
|
|||
return plugin
|
||||
|
||||
def _load_recognition(self,
|
||||
recognition: Optional[str],
|
||||
configfile: Optional[str]) -> Optional["Identity"]:
|
||||
recognition: str | None,
|
||||
configfile: str | None) -> Identity | None:
|
||||
""" Set global arguments and load recognition plugin """
|
||||
if recognition is None or recognition.lower() == "none":
|
||||
logger.debug("No recognition selected. Returning None")
|
||||
|
@ -716,16 +712,16 @@ class Extractor():
|
|||
gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0]
|
||||
scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback)
|
||||
plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling
|
||||
if plugins_required + batch_required <= cast(int, self._vram_stats["vram_free"]):
|
||||
if plugins_required + batch_required <= T.cast(int, self._vram_stats["vram_free"]):
|
||||
logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, "
|
||||
"vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"])
|
||||
return
|
||||
# Hacky split across plugins that use vram
|
||||
available_vram = (cast(int, self._vram_stats["vram_free"])
|
||||
available_vram = (T.cast(int, self._vram_stats["vram_free"])
|
||||
- plugins_required) // len(gpu_plugins)
|
||||
self._set_plugin_batchsize(gpu_plugins, available_vram)
|
||||
|
||||
def _set_plugin_batchsize(self, gpu_plugins: List[str], available_vram: float) -> None:
|
||||
def _set_plugin_batchsize(self, gpu_plugins: list[str], available_vram: float) -> None:
|
||||
""" Set the batch size for the given plugin based on given available vram.
|
||||
Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to
|
||||
zero division error.
|
||||
|
@ -802,20 +798,20 @@ class ExtractMedia():

def __init__(self,
filename: str,
image: "np.ndarray",
detected_faces: Optional[List["DetectedFace"]] = None,
image: np.ndarray,
detected_faces: list[DetectedFace] | None = None,
is_aligned: bool = False) -> None:
logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore
"detected_faces: %s, is_aligned: %s)", self.__class__.__name__, filename,
image.shape, detected_faces, is_aligned)
self._filename = filename
self._image: Optional["np.ndarray"] = image
self._image_shape = cast(Tuple[int, int, int], image.shape)
self._detected_faces: List["DetectedFace"] = ([] if detected_faces is None
else detected_faces)
self._image: np.ndarray | None = image
self._image_shape = T.cast(tuple[int, int, int], image.shape)
self._detected_faces: list[DetectedFace] = ([] if detected_faces is None
else detected_faces)
self._is_aligned = is_aligned
self._frame_metadata: Optional["PNGHeaderSourceDict"] = None
self._sub_folders: List[Optional[str]] = []
self._frame_metadata: PNGHeaderSourceDict | None = None
self._sub_folders: list[str | None] = []

@property
def filename(self) -> str:
@ -823,23 +819,23 @@ class ExtractMedia():
|
|||
return self._filename
|
||||
|
||||
@property
|
||||
def image(self) -> "np.ndarray":
|
||||
def image(self) -> np.ndarray:
|
||||
""" :class:`numpy.ndarray`: The source frame for this object. """
|
||||
assert self._image is not None
|
||||
return self._image
|
||||
|
||||
@property
|
||||
def image_shape(self) -> Tuple[int, int, int]:
|
||||
def image_shape(self) -> tuple[int, int, int]:
|
||||
""" tuple: The shape of the stored :attr:`image`. """
|
||||
return self._image_shape
|
||||
|
||||
@property
|
||||
def image_size(self) -> Tuple[int, int]:
|
||||
def image_size(self) -> tuple[int, int]:
|
||||
""" tuple: The (`height`, `width`) of the stored :attr:`image`. """
|
||||
return self._image_shape[:2]
|
||||
|
||||
@property
|
||||
def detected_faces(self) -> List["DetectedFace"]:
|
||||
def detected_faces(self) -> list[DetectedFace]:
|
||||
"""list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """
|
||||
return self._detected_faces
|
||||
|
||||
|
@ -849,7 +845,7 @@ class ExtractMedia():
|
|||
return self._is_aligned
|
||||
|
||||
@property
|
||||
def frame_metadata(self) -> "PNGHeaderSourceDict":
|
||||
def frame_metadata(self) -> PNGHeaderSourceDict:
|
||||
""" dict: The frame metadata that has been added from an aligned image. This property
|
||||
should only be called after :func:`add_frame_metadata` has been called when processing
|
||||
an aligned face. For all other instances an assertion error will be raised.
|
||||
|
@ -863,13 +859,13 @@ class ExtractMedia():
|
|||
return self._frame_metadata
|
||||
|
||||
@property
|
||||
def sub_folders(self) -> List[Optional[str]]:
|
||||
def sub_folders(self) -> list[str | None]:
|
||||
""" list: The sub_folders that the faces should be output to. Used when binning filter
|
||||
output is enabled. The list corresponds to the list of detected faces
|
||||
"""
|
||||
return self._sub_folders
|
||||
|
||||
def get_image_copy(self, color_format: Literal["BGR", "RGB", "GRAY"]) -> "np.ndarray":
|
||||
def get_image_copy(self, color_format: T.Literal["BGR", "RGB", "GRAY"]) -> np.ndarray:
|
||||
""" Get a copy of the image in the requested color format.
|
||||
|
||||
Parameters
|
||||
|
@ -887,7 +883,7 @@ class ExtractMedia():
|
|||
image = getattr(self, f"_image_as_{color_format.lower()}")()
|
||||
return image
|
||||
|
||||
def add_detected_faces(self, faces: List["DetectedFace"]) -> None:
|
||||
def add_detected_faces(self, faces: list[DetectedFace]) -> None:
|
||||
""" Add detected faces to the object. Called at the end of each extraction phase.
|
||||
|
||||
Parameters
|
||||
|
@ -900,7 +896,7 @@ class ExtractMedia():
|
|||
[(face.left, face.right, face.top, face.bottom) for face in faces])
|
||||
self._detected_faces = faces
|
||||
|
||||
def add_sub_folders(self, folders: List[Optional[str]]) -> None:
|
||||
def add_sub_folders(self, folders: list[str | None]) -> None:
|
||||
""" Add detected faces to the object. Called at the end of each extraction phase.
|
||||
|
||||
Parameters
|
||||
|
@ -922,7 +918,7 @@ class ExtractMedia():
|
|||
del self._image
|
||||
self._image = None
|
||||
|
||||
def set_image(self, image: "np.ndarray") -> None:
|
||||
def set_image(self, image: np.ndarray) -> None:
|
||||
""" Add the image back into :attr:`image`
|
||||
|
||||
Required for multi-phase extraction adds the image back to this object.
|
||||
|
@ -936,7 +932,7 @@ class ExtractMedia():
|
|||
self._filename, image.shape)
|
||||
self._image = image
|
||||
|
||||
def add_frame_metadata(self, metadata: "PNGHeaderSourceDict") -> None:
|
||||
def add_frame_metadata(self, metadata: PNGHeaderSourceDict) -> None:
|
||||
""" Add the source frame metadata from an aligned PNG's header data.
|
||||
|
||||
metadata: dict
|
||||
|
@ -944,11 +940,11 @@ class ExtractMedia():
|
|||
"""
|
||||
logger.trace("Adding PNG Source data for '%s': %s", # type:ignore
|
||||
self._filename, metadata)
|
||||
dims = cast(Tuple[int, int], metadata["source_frame_dims"])
|
||||
dims = T.cast(tuple[int, int], metadata["source_frame_dims"])
|
||||
self._image_shape = (*dims, 3)
|
||||
self._frame_metadata = metadata
|
||||
|
||||
def _image_as_bgr(self) -> "np.ndarray":
|
||||
def _image_as_bgr(self) -> np.ndarray:
|
||||
""" Get a copy of the source frame in BGR format.
|
||||
|
||||
Returns
|
||||
|
@ -957,7 +953,7 @@ class ExtractMedia():
|
|||
A copy of :attr:`image` in BGR color format """
|
||||
return self.image[..., :3].copy()
|
||||
|
||||
def _image_as_rgb(self) -> "np.ndarray":
|
||||
def _image_as_rgb(self) -> np.ndarray:
|
||||
""" Get a copy of the source frame in RGB format.
|
||||
|
||||
Returns
|
||||
|
@ -966,7 +962,7 @@ class ExtractMedia():
|
|||
A copy of :attr:`image` in RGB color format """
|
||||
return self.image[..., 2::-1].copy()
|
||||
|
||||
def _image_as_gray(self) -> "np.ndarray":
|
||||
def _image_as_gray(self) -> np.ndarray:
|
||||
""" Get a copy of the source frame in gray-scale format.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -17,7 +17,6 @@ To get a :class:`~lib.align.DetectedFace` object use the function:
"""
from __future__ import annotations
import logging
import sys
import typing as T

from dataclasses import dataclass, field
@ -31,13 +30,8 @@ from lib.utils import FaceswapError
from plugins.extract._base import BatchType, Extractor, ExtractorBatch
from plugins.extract.pipeline import ExtractMedia

if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal


if T.TYPE_CHECKING:
from collections.abc import Generator
from queue import Queue
from lib.align.aligned_face import CenteringType

@ -50,8 +44,8 @@ class RecogBatch(ExtractorBatch):
|
|||
|
||||
Inherits from :class:`~plugins.extract._base.ExtractorBatch`
|
||||
"""
|
||||
detected_faces: T.List["DetectedFace"] = field(default_factory=list)
|
||||
feed_faces: T.List[AlignedFace] = field(default_factory=list)
|
||||
detected_faces: list[DetectedFace] = field(default_factory=list)
|
||||
feed_faces: list[AlignedFace] = field(default_factory=list)
|
||||
|
||||
|
||||
class Identity(Extractor): # pylint:disable=abstract-method
|
||||
|
@ -82,9 +76,9 @@ class Identity(Extractor): # pylint:disable=abstract-method
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
git_model_id: T.Optional[int] = None,
|
||||
model_filename: T.Optional[str] = None,
|
||||
configfile: T.Optional[str] = None,
|
||||
git_model_id: int | None = None,
|
||||
model_filename: str | None = None,
|
||||
configfile: str | None = None,
|
||||
instance: int = 0,
|
||||
**kwargs):
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
|
@ -119,7 +113,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
|
|||
logger.debug("Obtained detected face: (filename: %s, detected_face: %s)",
|
||||
item.filename, item.detected_faces)
|
||||
|
||||
def get_batch(self, queue: Queue) -> T.Tuple[bool, RecogBatch]:
|
||||
def get_batch(self, queue: Queue) -> tuple[bool, RecogBatch]:
|
||||
""" Get items for inputting into the recognition from the queue in batches
|
||||
|
||||
Items are returned from the ``queue`` in batches of
|
||||
|
@ -226,7 +220,7 @@ class Identity(Extractor): # pylint:disable=abstract-method
|
|||
"\n3) Enable 'Single Process' mode.")
|
||||
raise FaceswapError(msg) from err
|
||||
|
||||
def finalize(self, batch: BatchType) -> T.Generator[ExtractMedia, None, None]:
|
||||
def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]:
|
||||
""" Finalize the output from Masker
|
||||
|
||||
This should be called as the final task of each `plugin`.
|
||||
|
@ -301,8 +295,8 @@ class IdentityFilter():
|
|||
def __init__(self, save_output: bool) -> None:
|
||||
logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output)
|
||||
self._save_output = save_output
|
||||
self._filter: T.Optional[np.ndarray] = None
|
||||
self._nfilter: T.Optional[np.ndarray] = None
|
||||
self._filter: np.ndarray | None = None
|
||||
self._nfilter: np.ndarray | None = None
|
||||
self._threshold = 0.0
|
||||
self._filter_enabled: bool = False
|
||||
self._nfilter_enabled: bool = False
|
||||
|
@ -357,7 +351,7 @@ class IdentityFilter():
|
|||
return retval
|
||||
|
||||
def _get_matches(self,
|
||||
filter_type: Literal["filter", "nfilter"],
|
||||
filter_type: T.Literal["filter", "nfilter"],
|
||||
identities: np.ndarray) -> np.ndarray:
|
||||
""" Obtain the average and minimum distances for each face against the source identities
|
||||
to test against
|
||||
|
@ -386,9 +380,9 @@ class IdentityFilter():
|
|||
return retval
|
||||
|
||||
def _filter_faces(self,
|
||||
faces: T.List[DetectedFace],
|
||||
sub_folders: T.List[T.Optional[str]],
|
||||
should_filter: T.List[bool]) -> T.List[DetectedFace]:
|
||||
faces: list[DetectedFace],
|
||||
sub_folders: list[str | None],
|
||||
should_filter: list[bool]) -> list[DetectedFace]:
|
||||
""" Filter the detected faces, either removing filtered faces from the list of detected
|
||||
faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if
|
||||
saving output is enabled.
|
||||
|
@ -410,7 +404,7 @@ class IdentityFilter():
|
|||
The filtered list of detected face objects, if saving filtered faces has not been
|
||||
selected or the full list of detected faces
|
||||
"""
|
||||
retval: T.List[DetectedFace] = []
|
||||
retval: list[DetectedFace] = []
|
||||
self._counts += sum(should_filter)
|
||||
for idx, face in enumerate(faces):
|
||||
fldr = sub_folders[idx]
|
||||
|
@ -429,8 +423,8 @@ class IdentityFilter():
|
|||
return retval
|
||||
|
||||
def __call__(self,
|
||||
faces: T.List[DetectedFace],
|
||||
sub_folders: T.List[T.Optional[str]]) -> T.List[DetectedFace]:
|
||||
faces: list[DetectedFace],
|
||||
sub_folders: list[str | None]) -> list[DetectedFace]:
|
||||
""" Call the identity filter function
|
||||
|
||||
Parameters
|
||||
|
@ -459,14 +453,14 @@ class IdentityFilter():
|
|||
logger.trace("All faces already filtered: %s", sub_folders) # type: ignore
|
||||
return faces
|
||||
|
||||
should_filter: T.List[np.ndarray] = []
|
||||
for f_type in get_args(Literal["filter", "nfilter"]):
|
||||
should_filter: list[np.ndarray] = []
|
||||
for f_type in T.get_args(T.Literal["filter", "nfilter"]):
|
||||
if not getattr(self, f"_{f_type}_enabled"):
|
||||
continue
|
||||
should_filter.append(self._get_matches(f_type, identities))
|
||||
|
||||
# If any of the filter or nfilter evaluate to 'should filter' then filter out face
|
||||
final_filter: T.List[bool] = np.array(should_filter).max(axis=0).tolist()
|
||||
final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist()
|
||||
logger.trace("should_filter: %s, final_filter: %s", # type: ignore
|
||||
should_filter, final_filter)
|
||||
return self._filter_faces(faces, sub_folders, final_filter)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
#!/usr/bin python3
|
||||
""" VGG_Face2 inference and sorting """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from typing import cast, Dict, Generator, List, Tuple, Optional
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
|
@ -15,11 +14,8 @@ from lib.model.session import KSession
|
|||
from lib.utils import FaceswapError
|
||||
from ._base import BatchType, RecogBatch, Identity
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
@ -64,7 +60,7 @@ class Recognition(Identity):
|
|||
def init_model(self) -> None:
|
||||
""" Initialize VGG Face 2 Model. """
|
||||
assert isinstance(self.model_path, str)
|
||||
model_kwargs = dict(custom_objects={'L2_normalize': L2_normalize})
|
||||
model_kwargs = {"custom_objects": {"L2_normalize": L2_normalize}}
|
||||
self.model = KSession(self.name,
|
||||
self.model_path,
|
||||
model_kwargs=model_kwargs,
|
||||
|
@ -76,7 +72,7 @@ class Recognition(Identity):
|
|||
def process_input(self, batch: BatchType) -> None:
|
||||
""" Compile the detected faces for prediction """
|
||||
assert isinstance(batch, RecogBatch)
|
||||
batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3]
|
||||
batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3]
|
||||
for feed in batch.feed_faces],
|
||||
dtype="float32") - self._average_img
|
||||
logger.trace("feed shape: %s", batch.feed.shape) # type:ignore
|
||||
|
@ -121,15 +117,15 @@ class Cluster(): # pylint: disable=too-few-public-methods
|
|||
|
||||
def __init__(self,
|
||||
predictions: np.ndarray,
|
||||
method: Literal["single", "centroid", "median", "ward"],
|
||||
threshold: Optional[float] = None) -> None:
|
||||
method: T.Literal["single", "centroid", "median", "ward"],
|
||||
threshold: float | None = None) -> None:
|
||||
logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)",
|
||||
self.__class__.__name__, predictions.shape, method, threshold)
|
||||
self._num_predictions = predictions.shape[0]
|
||||
|
||||
self._should_output_bins = threshold is not None
|
||||
self._threshold = 0.0 if threshold is None else threshold
|
||||
self._bins: Dict[int, int] = {}
|
||||
self._bins: dict[int, int] = {}
|
||||
self._iterator = self._integer_iterator()
|
||||
|
||||
self._result_linkage = self._do_linkage(predictions, method)
|
||||
|
@ -192,7 +188,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
|
|||
|
||||
def _do_linkage(self,
|
||||
predictions: np.ndarray,
|
||||
method: Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
|
||||
method: T.Literal["single", "centroid", "median", "ward"]) -> np.ndarray:
|
||||
""" Use FastCluster to perform vector or standard linkage
|
||||
|
||||
Parameters
|
||||
|
@ -218,7 +214,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
|
|||
|
||||
def _process_leaf_node(self,
|
||||
current_index: int,
|
||||
current_bin: int) -> List[Tuple[int, int]]:
|
||||
current_bin: int) -> list[tuple[int, int]]:
|
||||
""" Process the output when we have hit a leaf node """
|
||||
if not self._should_output_bins:
|
||||
return [(current_index, 0)]
|
||||
|
@ -263,7 +259,7 @@ class Cluster(): # pylint: disable=too-few-public-methods
                   tree: np.ndarray,
                   points: int,
                   current_index: int,
                   current_bin: int = 0) -> List[Tuple[int, int]]:
                   current_bin: int = 0) -> list[tuple[int, int]]:
        """ Seriation method for sorted similarity.

        Seriation computes the order implied by a hierarchical tree (dendrogram).

@ -298,7 +294,7 @@ class Cluster(): # pylint: disable=too-few-public-methods

        return serate_left + serate_right  # type: ignore

    def __call__(self) -> List[Tuple[int, int]]:
    def __call__(self) -> list[tuple[int, int]]:
        """ Process the linkages.

        Transforms a distance matrix into a sorted distance matrix according to the order implied

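The Cluster docstrings above describe seriation, ordering samples by the leaf order of a hierarchical-clustering dendrogram. A rough stand-alone sketch of the same idea using SciPy; the plugin itself builds its linkage with fastcluster and walks the tree manually, so this is only an approximation:

# Rough illustration of dendrogram seriation with SciPy; treat the data and the
# use of scipy (rather than fastcluster) as assumptions for demonstration only.
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage

rng = np.random.default_rng(0)
identities = rng.normal(size=(10, 512))     # fake identity embeddings

link = linkage(identities, method="ward")   # hierarchical clustering
order = leaves_list(link)                   # leaf order implied by the dendrogram

sorted_identities = identities[order]       # seriated ordering for sorting/binning
print(order)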
@ -1,13 +1,14 @@
|
|||
#!/usr/bin/env python3
|
||||
""" Plugin loader for Faceswap extract, training and convert tasks """
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from typing import Callable, List, Type, TYPE_CHECKING
|
||||
import typing as T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from importlib import import_module
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from plugins.extract.detect._base import Detector
|
||||
from plugins.extract.align._base import Aligner
|
||||
from plugins.extract.mask._base import Masker
|
||||
|
@ -15,11 +16,6 @@ if TYPE_CHECKING:
|
|||
from plugins.train.model._base import ModelBase
|
||||
from plugins.train.trainer._base import TrainerBase
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
@ -36,7 +32,7 @@ class PluginLoader():
|
|||
>>> aligner = PluginLoader.get_aligner('cv2-dnn')
|
||||
"""
|
||||
@staticmethod
|
||||
def get_detector(name: str, disable_logging: bool = False) -> Type["Detector"]:
|
||||
def get_detector(name: str, disable_logging: bool = False) -> type[Detector]:
|
||||
""" Return requested detector plugin
|
||||
|
||||
Parameters
|
||||
|
@ -55,7 +51,7 @@ class PluginLoader():
|
|||
return PluginLoader._import("extract.detect", name, disable_logging)
|
||||
|
||||
@staticmethod
|
||||
def get_aligner(name: str, disable_logging: bool = False) -> Type["Aligner"]:
|
||||
def get_aligner(name: str, disable_logging: bool = False) -> type[Aligner]:
|
||||
""" Return requested aligner plugin
|
||||
|
||||
Parameters
|
||||
|
@ -74,7 +70,7 @@ class PluginLoader():
|
|||
return PluginLoader._import("extract.align", name, disable_logging)
|
||||
|
||||
@staticmethod
|
||||
def get_masker(name: str, disable_logging: bool = False) -> Type["Masker"]:
|
||||
def get_masker(name: str, disable_logging: bool = False) -> type[Masker]:
|
||||
""" Return requested masker plugin
|
||||
|
||||
Parameters
|
||||
|
@ -93,7 +89,7 @@ class PluginLoader():
|
|||
return PluginLoader._import("extract.mask", name, disable_logging)
|
||||
|
||||
@staticmethod
|
||||
def get_recognition(name: str, disable_logging: bool = False) -> Type["Identity"]:
|
||||
def get_recognition(name: str, disable_logging: bool = False) -> type[Identity]:
|
||||
""" Return requested recognition plugin
|
||||
|
||||
Parameters
|
||||
|
@ -112,7 +108,7 @@ class PluginLoader():
|
|||
return PluginLoader._import("extract.recognition", name, disable_logging)
|
||||
|
||||
@staticmethod
|
||||
def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]:
|
||||
def get_model(name: str, disable_logging: bool = False) -> type[ModelBase]:
|
||||
""" Return requested training model plugin
|
||||
|
||||
Parameters
|
||||
|
@ -131,7 +127,7 @@ class PluginLoader():
|
|||
return PluginLoader._import("train.model", name, disable_logging)
|
||||
|
||||
@staticmethod
|
||||
def get_trainer(name: str, disable_logging: bool = False) -> Type["TrainerBase"]:
|
||||
def get_trainer(name: str, disable_logging: bool = False) -> type[TrainerBase]:
|
||||
""" Return requested training trainer plugin
|
||||
|
||||
Parameters
|
||||
|
@ -198,9 +194,9 @@ class PluginLoader():
|
|||
return getattr(module, ttl)
|
||||
|
||||
@staticmethod
|
||||
def get_available_extractors(extractor_type: Literal["align", "detect", "mask"],
|
||||
def get_available_extractors(extractor_type: T.Literal["align", "detect", "mask"],
|
||||
add_none: bool = False,
|
||||
extend_plugin: bool = False) -> List[str]:
|
||||
extend_plugin: bool = False) -> list[str]:
|
||||
""" Return a list of available extractors of the given type
|
||||
|
||||
Parameters
|
||||
|
@ -243,7 +239,7 @@ class PluginLoader():
|
|||
return extractors
|
||||
|
||||
@staticmethod
|
||||
def get_available_models() -> List[str]:
|
||||
def get_available_models() -> list[str]:
|
||||
""" Return a list of available training models
|
||||
|
||||
Returns
|
||||
|
@ -273,7 +269,7 @@ class PluginLoader():
|
|||
return 'original' if 'original' in models else models[0]
|
||||
|
||||
@staticmethod
|
||||
def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> List[str]:
|
||||
def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> list[str]:
|
||||
""" Return a list of available converter plugins in the given category
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -249,6 +249,31 @@ class Config(FaceswapConfig):
            "NB: The value given here is the 'exponent' to the epsilon. For example, "
            "choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the epsilon "
            "to 0.001 (1e-3)."))
        self.add_item(
            section=section,
            title="save_optimizer",
            datatype=str,
            group=_("optimizer"),
            default="exit",
            fixed=False,
            gui_radio=True,
            choices=["never", "always", "exit"],
            info=_(
                "When to save the Optimizer Weights. Saving the optimizer weights is not "
                "necessary and will increase the model file size 3x (and by extension the amount "
                "of time it takes to save the model). However, it can be useful to save these "
                "weights if you want to guarantee that a resumed model carries on exactly from "
                "where it left off, rather than spending a few hundred iterations catching up."
                "\n\t never - Don't save optimizer weights."
                "\n\t always - Save the optimizer weights at every save iteration. Model saving "
                "will take longer, due to the increased file size, but you will always have the "
                "last saved optimizer state in your model file."
                "\n\t exit - Only save the optimizer weights when explicitly terminating a "
                "model. This can be when the model is actively stopped or when the target "
                "iterations are met. Note: If the training session ends for another reason "
                "(e.g. power outage, Out of Memory Error, NaN detected) then the "
                "optimizer weights will NOT be saved."))
|
||||
|
||||
self.add_item(
|
||||
section=section,
|
||||
title="autoclip",
|
||||
|
|
|
@ -21,11 +21,6 @@ from tensorflow.keras.models import load_model, Model as KModel # noqa:E501 #
from lib.model.backup_restore import Backup
from lib.utils import FaceswapError

if sys.version_info < (3, 8):
    from typing_extensions import Literal
else:
    from typing import Literal

if T.TYPE_CHECKING:
    from tensorflow import keras
    from .model import ModelBase
|
@ -35,7 +30,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|||
|
||||
def get_all_sub_models(
|
||||
model: keras.models.Model,
|
||||
models: T.Optional[T.List[keras.models.Model]] = None) -> T.List[keras.models.Model]:
|
||||
models: list[keras.models.Model] | None = None) -> list[keras.models.Model]:
|
||||
""" For a given model, return all sub-models that occur (recursively) as children.
|
||||
|
||||
Parameters
|
||||
|
@ -85,12 +80,12 @@ class IO():
|
|||
plugin: ModelBase,
|
||||
model_dir: str,
|
||||
is_predict: bool,
|
||||
save_optimizer: Literal["never", "always", "exit"]) -> None:
|
||||
save_optimizer: T.Literal["never", "always", "exit"]) -> None:
|
||||
self._plugin = plugin
|
||||
self._is_predict = is_predict
|
||||
self._model_dir = model_dir
|
||||
self._save_optimizer = save_optimizer
|
||||
self._history: T.List[T.List[float]] = [[], []] # Loss histories per save iteration
|
||||
self._history: list[list[float]] = [[], []] # Loss histories per save iteration
|
||||
self._backup = Backup(self._model_dir, self._plugin.name)
|
||||
|
||||
@property
|
||||
|
@ -106,12 +101,12 @@ class IO():
|
|||
return os.path.isfile(self._filename)
|
||||
|
||||
@property
|
||||
def history(self) -> T.List[T.List[float]]:
|
||||
def history(self) -> list[list[float]]:
|
||||
""" list: list of loss histories per side for the current save iteration. """
|
||||
return self._history
|
||||
|
||||
@property
|
||||
def multiple_models_in_folder(self) -> T.Optional[T.List[str]]:
|
||||
def multiple_models_in_folder(self) -> list[str] | None:
|
||||
""" :list: or ``None`` If there are multiple model types in the requested folder, or model
|
||||
types that don't correspond to the requested plugin type, then returns the list of plugin
|
||||
names that exist in the folder, otherwise returns ``None`` """
|
||||
|
@ -210,7 +205,7 @@ class IO():
|
|||
msg += f" - Average loss since last save: {', '.join(lossmsg)}"
|
||||
logger.info(msg)
|
||||
|
||||
def _get_save_averages(self) -> T.List[float]:
|
||||
def _get_save_averages(self) -> list[float]:
|
||||
""" Return the average loss since the last save iteration and reset historical loss """
|
||||
logger.debug("Getting save averages")
|
||||
if not all(loss for loss in self._history):
|
||||
|
@ -222,7 +217,7 @@ class IO():
|
|||
logger.debug("Average losses since last save: %s", retval)
|
||||
return retval
|
||||
|
||||
def _should_backup(self, save_averages: T.List[float]) -> bool:
|
||||
def _should_backup(self, save_averages: list[float]) -> bool:
|
||||
""" Check whether the loss averages for this save iteration is the lowest that has been
|
||||
seen.
|
||||
|
||||
|
@ -301,7 +296,7 @@ class Weights():
|
|||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@classmethod
|
||||
def _check_weights_file(cls, weights_file: str) -> T.Optional[str]:
|
||||
def _check_weights_file(cls, weights_file: str) -> str | None:
|
||||
""" Validate that we have a valid path to a .h5 file.
|
||||
|
||||
Parameters
|
||||
|
@ -403,7 +398,7 @@ class Weights():
|
|||
"different settings than you have set for your current model.",
|
||||
skipped_ops)
|
||||
|
||||
def _get_weights_model(self) -> T.List[keras.models.Model]:
|
||||
def _get_weights_model(self) -> list[keras.models.Model]:
|
||||
""" Obtain a list of all sub-models contained within the weights model.
|
||||
|
||||
Returns
|
||||
|
@ -429,7 +424,7 @@ class Weights():
|
|||
def _load_layer_weights(self,
|
||||
layer: keras.layers.Layer,
|
||||
sub_weights: keras.layers.Layer,
|
||||
model_name: str) -> Literal[-1, 0, 1]:
|
||||
model_name: str) -> T.Literal[-1, 0, 1]:
|
||||
""" Load the weights for a single layer.
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -29,18 +29,12 @@ from plugins.train._config import Config
|
|||
from .io import IO, get_all_sub_models, Weights
|
||||
from .settings import Loss, Optimizer, Settings
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
import argparse
|
||||
from lib.config import ConfigValueType
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
_CONFIG: T.Dict[str, ConfigValueType] = {}
|
||||
_CONFIG: dict[str, ConfigValueType] = {}
|
||||
|
||||
|
||||
class ModelBase():
|
||||
|
@ -79,13 +73,13 @@ class ModelBase():
|
|||
self.__class__.__name__, model_dir, arguments, predict)
|
||||
|
||||
# Input shape must be set within the plugin after initializing
|
||||
self.input_shape: T.Tuple[int, ...] = ()
|
||||
self.input_shape: tuple[int, ...] = ()
|
||||
self.trainer = "original" # Override for plugin specific trainer
|
||||
self.color_order: Literal["bgr", "rgb"] = "bgr" # Override for image color channel order
|
||||
self.color_order: T.Literal["bgr", "rgb"] = "bgr" # Override for image color channel order
|
||||
|
||||
self._args = arguments
|
||||
self._is_predict = predict
|
||||
self._model: T.Optional[tf.keras.models.Model] = None
|
||||
self._model: tf.keras.models.Model | None = None
|
||||
|
||||
self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None
|
||||
self._load_config()
|
||||
|
@ -100,14 +94,7 @@ class ModelBase():
                    "use. Please select a mask or disable 'Learn Mask'.")

        self._mixed_precision = self.config["mixed_precision"]
        # self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
        # TODO - Re-enable saving of optimizer once this bug is fixed:
        # File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
        # File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
        # File "h5py/h5d.pyx", line 87, in h5py.h5d.create
        # ValueError: Unable to create dataset (name already exists)

        self._io = IO(self, model_dir, self._is_predict, "never")
        self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"])
        self._check_multiple_models()
|
||||
|
||||
self._state = State(model_dir,
|
||||
|
@ -175,16 +162,16 @@ class ModelBase():
|
|||
return self.name
|
||||
|
||||
@property
|
||||
def input_shapes(self) -> T.List[T.Tuple[None, int, int, int]]:
|
||||
def input_shapes(self) -> list[tuple[None, int, int, int]]:
|
||||
""" list: A flattened list corresponding to all of the inputs to the model. """
|
||||
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(inputs))
|
||||
shapes = [T.cast(tuple[None, int, int, int], K.int_shape(inputs))
|
||||
for inputs in self.model.inputs]
|
||||
return shapes
|
||||
|
||||
@property
|
||||
def output_shapes(self) -> T.List[T.Tuple[None, int, int, int]]:
|
||||
def output_shapes(self) -> list[tuple[None, int, int, int]]:
|
||||
""" list: A flattened list corresponding to all of the outputs of the model. """
|
||||
shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(output))
|
||||
shapes = [T.cast(tuple[None, int, int, int], K.int_shape(output))
|
||||
for output in self.model.outputs]
|
||||
return shapes
|
||||
|
||||
|
@ -333,7 +320,7 @@ class ModelBase():
|
|||
a list of 2 shape tuples of 3 dimensions. """
|
||||
assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple"
|
||||
|
||||
def _get_inputs(self) -> T.List[tf.keras.layers.Input]:
|
||||
def _get_inputs(self) -> list[tf.keras.layers.Input]:
|
||||
""" Obtain the standardized inputs for the model.
|
||||
|
||||
The inputs will be returned for the "A" and "B" sides in the shape as defined by
|
||||
|
@ -352,7 +339,7 @@ class ModelBase():
|
|||
logger.debug("inputs: %s", inputs)
|
||||
return inputs
|
||||
|
||||
def build_model(self, inputs: T.List[tf.keras.layers.Input]) -> tf.keras.models.Model:
|
||||
def build_model(self, inputs: list[tf.keras.layers.Input]) -> tf.keras.models.Model:
|
||||
""" Override for Model Specific autoencoder builds.
|
||||
|
||||
Parameters
|
||||
|
@ -427,7 +414,7 @@ class ModelBase():
|
|||
self._state.add_session_loss_names(self._loss.names)
|
||||
logger.debug("Compiled Model: %s", self.model)
|
||||
|
||||
def _legacy_mapping(self) -> T.Optional[dict]:
|
||||
def _legacy_mapping(self) -> dict | None:
|
||||
""" The mapping of separate model files to single model layers for transferring of legacy
|
||||
weights.
|
||||
|
||||
|
@ -439,7 +426,7 @@ class ModelBase():
|
|||
"""
|
||||
return None
|
||||
|
||||
def add_history(self, loss: T.List[float]) -> None:
|
||||
def add_history(self, loss: list[float]) -> None:
|
||||
""" Add the current iteration's loss history to :attr:`_io.history`.
|
||||
|
||||
Called from the trainer after each iteration, for tracking loss drop over time between
|
||||
|
@ -482,18 +469,18 @@ class State():
|
|||
self._filename = os.path.join(model_dir, filename)
|
||||
self._name = model_name
|
||||
self._iterations = 0
|
||||
self._mixed_precision_layers: T.List[str] = []
|
||||
self._mixed_precision_layers: list[str] = []
|
||||
self._rebuild_model = False
|
||||
self._sessions: T.Dict[int, dict] = {}
|
||||
self._lowest_avg_loss: T.Dict[str, float] = {}
|
||||
self._config: T.Dict[str, ConfigValueType] = {}
|
||||
self._sessions: dict[int, dict] = {}
|
||||
self._lowest_avg_loss: dict[str, float] = {}
|
||||
self._config: dict[str, ConfigValueType] = {}
|
||||
self._load(config_changeable_items)
|
||||
self._session_id = self._new_session_id()
|
||||
self._create_new_session(no_logs, config_changeable_items)
|
||||
logger.debug("Initialized %s:", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def loss_names(self) -> T.List[str]:
|
||||
def loss_names(self) -> list[str]:
|
||||
""" list: The loss names for the current session """
|
||||
return self._sessions[self._session_id]["loss_names"]
|
||||
|
||||
|
@ -518,7 +505,7 @@ class State():
|
|||
return self._session_id
|
||||
|
||||
@property
|
||||
def mixed_precision_layers(self) -> T.List[str]:
|
||||
def mixed_precision_layers(self) -> list[str]:
|
||||
"""list: Layers that can be switched between mixed-float16 and float32. """
|
||||
return self._mixed_precision_layers
|
||||
|
||||
|
@ -564,7 +551,7 @@ class State():
|
|||
"iterations": 0,
|
||||
"config": config_changeable_items}
|
||||
|
||||
def add_session_loss_names(self, loss_names: T.List[str]) -> None:
|
||||
def add_session_loss_names(self, loss_names: list[str]) -> None:
|
||||
""" Add the session loss names to the sessions dictionary.
|
||||
|
||||
The loss names are used for Tensorboard logging
|
||||
|
@ -593,7 +580,7 @@ class State():
|
|||
self._iterations += 1
|
||||
self._sessions[self._session_id]["iterations"] += 1
|
||||
|
||||
def add_mixed_precision_layers(self, layers: T.List[str]) -> None:
|
||||
def add_mixed_precision_layers(self, layers: list[str]) -> None:
|
||||
""" Add the list of model's layers that are compatible for mixed precision to the
|
||||
state dictionary """
|
||||
logger.debug("Storing mixed precision layers: %s", layers)
|
||||
|
@ -655,11 +642,11 @@ class State():
|
|||
legacy_update = self._update_legacy_config()
|
||||
# Add any new items to state config for legacy purposes where the new default may be
|
||||
# detrimental to an existing model.
|
||||
legacy_defaults: T.Dict[str, T.Union[str, int, bool]] = {"centering": "legacy",
|
||||
"mask_loss_function": "mse",
|
||||
"l2_reg_term": 100,
|
||||
"optimizer": "adam",
|
||||
"mixed_precision": False}
|
||||
legacy_defaults: dict[str, str | int | bool] = {"centering": "legacy",
|
||||
"mask_loss_function": "mse",
|
||||
"l2_reg_term": 100,
|
||||
"optimizer": "adam",
|
||||
"mixed_precision": False}
|
||||
for key, val in _CONFIG.items():
|
||||
if key not in self._config.keys():
|
||||
setting: ConfigValueType = legacy_defaults.get(key, val)
|
||||
|
@ -807,7 +794,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
|
|||
""" :class:`keras.models.Model`: The Faceswap model, compiled for inference. """
|
||||
return self._model
|
||||
|
||||
def _get_nodes(self, nodes: np.ndarray) -> T.List[T.Tuple[str, int]]:
|
||||
def _get_nodes(self, nodes: np.ndarray) -> list[tuple[str, int]]:
|
||||
""" Given in input list of nodes from a :attr:`keras.models.Model.get_config` dictionary,
|
||||
filters the layer name(s) and output index of the node, splitting to the correct output
|
||||
index in the event of multiple inputs.
|
||||
|
@ -849,7 +836,7 @@ class _Inference(): # pylint:disable=too-few-public-methods
|
|||
logger.debug("Compiling inference model. saved_model: %s", saved_model)
|
||||
struct = self._get_filtered_structure()
|
||||
model_inputs = self._get_inputs(saved_model.inputs)
|
||||
compiled_layers: T.Dict[str, tf.keras.layers.Layer] = {}
|
||||
compiled_layers: dict[str, tf.keras.layers.Layer] = {}
|
||||
for layer in saved_model.layers:
|
||||
if layer.name not in struct:
|
||||
logger.debug("Skipping unused layer: '%s'", layer.name)
|
||||
|
|
|
@ -14,7 +14,6 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
import typing as T
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
@ -28,12 +27,9 @@ from lib.model import losses, optimizers
|
|||
from lib.model.autoclip import AutoClipper
|
||||
from lib.utils import get_backend
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager as ContextManager
|
||||
from argparse import Namespace
|
||||
from .model import State
|
||||
|
||||
|
@ -58,9 +54,9 @@ class LossClass:
|
|||
kwargs: dict
|
||||
Any keyword arguments to supply to the loss function at initialization.
|
||||
"""
|
||||
function: T.Union[T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor], T.Any] = k_losses.mae
|
||||
function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | T.Any = k_losses.mae
|
||||
init: bool = True
|
||||
kwargs: T.Dict[str, T.Any] = field(default_factory=dict)
|
||||
kwargs: dict[str, T.Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class Loss():
|
||||
|
@ -73,13 +69,13 @@ class Loss():
|
|||
color_order: str
|
||||
Color order of the model. One of `"BGR"` or `"RGB"`
|
||||
"""
|
||||
def __init__(self, config: dict, color_order: Literal["bgr", "rgb"]) -> None:
|
||||
def __init__(self, config: dict, color_order: T.Literal["bgr", "rgb"]) -> None:
|
||||
logger.debug("Initializing %s: (color_order: %s)", self.__class__.__name__, color_order)
|
||||
self._config = config
|
||||
self._mask_channels = self._get_mask_channels()
|
||||
self._inputs: T.List[tf.keras.layers.Layer] = []
|
||||
self._names: T.List[str] = []
|
||||
self._funcs: T.Dict[str, T.Callable] = {}
|
||||
self._inputs: list[tf.keras.layers.Layer] = []
|
||||
self._names: list[str] = []
|
||||
self._funcs: dict[str, Callable] = {}
|
||||
|
||||
self._loss_dict = {"ffl": LossClass(function=losses.FocalFrequencyLoss),
|
||||
"flip": LossClass(function=losses.LDRFLIPLoss,
|
||||
|
@ -104,7 +100,7 @@ class Loss():
|
|||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def names(self) -> T.List[str]:
|
||||
def names(self) -> list[str]:
|
||||
""" list: The list of loss names for the model. """
|
||||
return self._names
|
||||
|
||||
|
@ -114,14 +110,14 @@ class Loss():
|
|||
return self._funcs
|
||||
|
||||
@property
|
||||
def _mask_inputs(self) -> T.Optional[list]:
|
||||
def _mask_inputs(self) -> list | None:
|
||||
""" list: The list of input tensors to the model that contain the mask. Returns ``None``
|
||||
if there is no mask input to the model. """
|
||||
mask_inputs = [inp for inp in self._inputs if inp.name.startswith("mask")]
|
||||
return None if not mask_inputs else mask_inputs
|
||||
|
||||
@property
|
||||
def _mask_shapes(self) -> T.Optional[T.List[tuple]]:
|
||||
def _mask_shapes(self) -> list[tuple] | None:
|
||||
""" list: The list of shape tuples for the mask input tensors for the model. Returns
|
||||
``None`` if there is no mask input. """
|
||||
if self._mask_inputs is None:
|
||||
|
@ -141,7 +137,7 @@ class Loss():
|
|||
self._set_loss_functions(model.output_names)
|
||||
self._names.insert(0, "total")
|
||||
|
||||
def _set_loss_names(self, outputs: T.List[tf.Tensor]) -> None:
|
||||
def _set_loss_names(self, outputs: list[tf.Tensor]) -> None:
|
||||
""" Name the losses based on model output.
|
||||
|
||||
This is used for correct naming in the state file, for display purposes only.
|
||||
|
@ -173,7 +169,7 @@ class Loss():
|
|||
self._names.append(f"{name}_{side}{suffix}")
|
||||
logger.debug(self._names)
|
||||
|
||||
def _get_function(self, name: str) -> T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
|
||||
def _get_function(self, name: str) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
|
||||
""" Obtain the requested Loss function
|
||||
|
||||
Parameters
|
||||
|
@ -191,7 +187,7 @@ class Loss():
|
|||
logger.debug("Obtained loss function `%s` (%s)", name, retval)
|
||||
return retval
|
||||
|
||||
def _set_loss_functions(self, output_names: T.List[str]):
|
||||
def _set_loss_functions(self, output_names: list[str]):
|
||||
""" Set the loss functions and their associated weights.
|
||||
|
||||
Adds the loss functions to the :attr:`functions` dictionary.
|
||||
|
@ -251,7 +247,7 @@ class Loss():
|
|||
mask_channel=mask_channel)
|
||||
channel_idx += 1
|
||||
|
||||
def _get_mask_channels(self) -> T.List[int]:
|
||||
def _get_mask_channels(self) -> list[int]:
|
||||
""" Obtain the channels from the face targets that the masks reside in from the training
|
||||
data generator.
|
||||
|
||||
|
@ -311,8 +307,8 @@ class Optimizer(): # pylint:disable=too-few-public-methods
                                      {"beta_1": 0.5, "beta_2": 0.99, "epsilon": epsilon}),
                            "rms-prop": (optimizers.RMSprop, {"epsilon": epsilon})}
        optimizer_info = valid_optimizers[optimizer]
        self._optimizer: T.Callable = optimizer_info[0]
        self._kwargs: T.Dict[str, T.Any] = optimizer_info[1]
        self._optimizer: Callable = optimizer_info[0]
        self._kwargs: dict[str, T.Any] = optimizer_info[1]

        self._configure(learning_rate, autoclip)
        logger.verbose("Using %s optimizer", optimizer.title())  # type:ignore[attr-defined]
|
||||
|
@ -411,7 +407,7 @@ class Settings():
|
|||
return mixedprecision.LossScaleOptimizer(optimizer) # pylint:disable=no-member
|
||||
|
||||
@classmethod
|
||||
def _set_tf_settings(cls, allow_growth: bool, exclude_devices: T.List[int]) -> None:
|
||||
def _set_tf_settings(cls, allow_growth: bool, exclude_devices: list[int]) -> None:
|
||||
""" Specify Devices to place operations on and Allow TensorFlow to manage VRAM growth.
|
||||
|
||||
Enables the Tensorflow allow_growth option if requested in the command line arguments
|
||||
|
@ -480,8 +476,8 @@ class Settings():
|
|||
return True
|
||||
|
||||
    def _get_strategy(self,
                      strategy: Literal["default", "central-storage", "mirrored"]
                      ) -> T.Optional[tf.distribute.Strategy]:
                      strategy: T.Literal["default", "central-storage", "mirrored"]
                      ) -> tf.distribute.Strategy | None:
        """ If we are running on Nvidia backend and the strategy is not ``None`` then return
        the correct tensorflow distribution strategy, otherwise return ``None``.

@ -565,7 +561,7 @@ class Settings():

        return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0")
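The _get_strategy docstring above notes that a distribution strategy is only returned on the Nvidia backend. A sketch of how the three option names could map to TensorFlow strategies; the mapping is assumed from the option names rather than taken verbatim from the plugin:

# Assumed mapping from the "default"/"central-storage"/"mirrored" options to
# TensorFlow distribution strategies; a sketch, not the plugin's exact logic.
import typing as T

import tensorflow as tf


def get_strategy(name: T.Literal["default", "central-storage", "mirrored"]
                 ) -> tf.distribute.Strategy | None:
    if name == "mirrored":
        return tf.distribute.MirroredStrategy()
    if name == "central-storage":
        return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0")
    return None  # "default": fall back to whatever scope TensorFlow already provides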
|
||||
|
||||
def _get_mixed_precision_layers(self, layers: T.List[dict]) -> T.List[str]:
|
||||
def _get_mixed_precision_layers(self, layers: list[dict]) -> list[str]:
|
||||
""" Obtain the names of the layers in a mixed precision model that have their dtype policy
|
||||
explicitly set to mixed-float16.
|
||||
|
||||
|
@ -595,7 +591,7 @@ class Settings():
|
|||
logger.debug("Skipping unsupported layer: %s %s", layer["name"], dtype)
|
||||
return retval
|
||||
|
||||
def _switch_precision(self, layers: T.List[dict], compatible: T.List[str]) -> None:
|
||||
def _switch_precision(self, layers: list[dict], compatible: list[str]) -> None:
|
||||
""" Switch a model's datatype between mixed-float16 and float32.
|
||||
|
||||
Parameters
|
||||
|
@ -624,9 +620,9 @@ class Settings():
|
|||
config["dtype"] = policy
|
||||
|
||||
def get_mixed_precision_layers(self,
|
||||
build_func: T.Callable[[T.List[tf.keras.layers.Layer]],
|
||||
tf.keras.models.Model],
|
||||
inputs: T.List[tf.keras.layers.Layer]) -> T.List[str]:
|
||||
build_func: Callable[[list[tf.keras.layers.Layer]],
|
||||
tf.keras.models.Model],
|
||||
inputs: list[tf.keras.layers.Layer]) -> list[str]:
|
||||
""" Get and store the mixed precision layers from a full precision enabled model.
|
||||
|
||||
Parameters
|
||||
|
@ -699,7 +695,7 @@ class Settings():
|
|||
del model
|
||||
return new_model
|
||||
|
||||
def strategy_scope(self) -> T.ContextManager:
|
||||
def strategy_scope(self) -> ContextManager:
|
||||
""" Return the strategy scope if we have set a strategy, otherwise return a null
|
||||
context.
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# pylint: disable=too-many-lines
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -27,16 +26,10 @@ from lib.utils import get_tf_version, FaceswapError
|
|||
|
||||
from ._base import ModelBase, get_all_sub_models
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from tensorflow import keras
|
||||
from tensorflow import Tensor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
@ -65,14 +58,14 @@ class _EncoderInfo:
|
|||
"""
|
||||
keras_name: str
|
||||
default_size: int
|
||||
tf_min: T.Tuple[int, int] = (2, 0)
|
||||
scaling: T.Tuple[int, int] = (0, 1)
|
||||
tf_min: tuple[int, int] = (2, 0)
|
||||
scaling: tuple[int, int] = (0, 1)
|
||||
min_size: int = 32
|
||||
enforce_for_weights: bool = False
|
||||
color_order: Literal["bgr", "rgb"] = "rgb"
|
||||
color_order: T.Literal["bgr", "rgb"] = "rgb"
|
||||
|
||||
|
||||
_MODEL_MAPPING: T.Dict[str, _EncoderInfo] = {
|
||||
_MODEL_MAPPING: dict[str, _EncoderInfo] = {
|
||||
"densenet121": _EncoderInfo(
|
||||
keras_name="DenseNet121", default_size=224),
|
||||
"densenet169": _EncoderInfo(
|
||||
|
@ -238,7 +231,7 @@ class Model(ModelBase):
|
|||
model = new_model
|
||||
return model
|
||||
|
||||
def _select_freeze_layers(self) -> T.List[str]:
|
||||
def _select_freeze_layers(self) -> list[str]:
|
||||
""" Process the selected frozen layers and replace the `keras_encoder` option with the
|
||||
actual keras model name
|
||||
|
||||
|
@ -262,7 +255,7 @@ class Model(ModelBase):
|
|||
logger.debug("Removing 'keras_encoder' for '%s'", arch)
|
||||
return retval
|
||||
|
||||
def _get_input_shape(self) -> T.Tuple[int, int, int]:
|
||||
def _get_input_shape(self) -> tuple[int, int, int]:
|
||||
""" Obtain the input shape for the model.
|
||||
|
||||
Input shape is calculated from the selected Encoder's input size, scaled to the user
|
||||
|
@ -316,7 +309,7 @@ class Model(ModelBase):
|
|||
f"minimum version required is {tf_min} whilst you have version "
|
||||
f"{tf_ver} installed.")
|
||||
|
||||
def build_model(self, inputs: T.List[Tensor]) -> keras.models.Model:
|
||||
def build_model(self, inputs: list[Tensor]) -> keras.models.Model:
|
||||
""" Create the model's structure.
|
||||
|
||||
Parameters
|
||||
|
@ -341,7 +334,7 @@ class Model(ModelBase):
|
|||
autoencoder = KModel(inputs, outputs, name=self.model_name)
|
||||
return autoencoder
|
||||
|
||||
def _build_encoders(self, inputs: T.List[Tensor]) -> T.Dict[str, keras.models.Model]:
|
||||
def _build_encoders(self, inputs: list[Tensor]) -> dict[str, keras.models.Model]:
|
||||
""" Build the encoders for Phaze-A
|
||||
|
||||
Parameters
|
||||
|
@ -362,7 +355,7 @@ class Model(ModelBase):
|
|||
|
||||
def _build_fully_connected(
|
||||
self,
|
||||
inputs: T.Dict[str, keras.models.Model]) -> T.Dict[str, T.List[keras.models.Model]]:
|
||||
inputs: dict[str, keras.models.Model]) -> dict[str, list[keras.models.Model]]:
|
||||
""" Build the fully connected layers for Phaze-A
|
||||
|
||||
Parameters
|
||||
|
@ -407,8 +400,8 @@ class Model(ModelBase):
|
|||
|
||||
def _build_g_blocks(
|
||||
self,
|
||||
inputs: T.Dict[str, T.List[keras.models.Model]]
|
||||
) -> T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]:
|
||||
inputs: dict[str, list[keras.models.Model]]
|
||||
) -> dict[str, list[keras.models.Model] | keras.models.Model]:
|
||||
""" Build the g-block layers for Phaze-A.
|
||||
|
||||
If a g-block has not been selected for this model, then the original `inters` models are
|
||||
|
@ -440,10 +433,9 @@ class Model(ModelBase):
|
|||
logger.debug("G-Blocks: %s", retval)
|
||||
return retval
|
||||
|
||||
def _build_decoders(
|
||||
self,
|
||||
inputs: T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]
|
||||
) -> T.Dict[str, keras.models.Model]:
|
||||
def _build_decoders(self,
|
||||
inputs: dict[str, list[keras.models.Model] | keras.models.Model]
|
||||
) -> dict[str, keras.models.Model]:
|
||||
""" Build the encoders for Phaze-A
|
||||
|
||||
Parameters
|
||||
|
@ -519,12 +511,12 @@ def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str)
|
|||
return var_x
|
||||
|
||||
|
||||
def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny", "upscale_fast",
|
||||
"upscale_hybrid", "upsample2d"],
|
||||
def _get_upscale_layer(method: T.Literal["resize_images", "subpixel", "upscale_dny",
|
||||
"upscale_fast", "upscale_hybrid", "upsample2d"],
|
||||
filters: int,
|
||||
activation: T.Optional[str] = None,
|
||||
upsamples: T.Optional[int] = None,
|
||||
interpolation: T.Optional[str] = None) -> keras.layers.Layer:
|
||||
activation: str | None = None,
|
||||
upsamples: int | None = None,
|
||||
interpolation: str | None = None) -> keras.layers.Layer:
|
||||
""" Obtain an instance of the requested upscale method.
|
||||
|
||||
Parameters
|
||||
|
@ -550,7 +542,7 @@ def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny
|
|||
The selected configured upscale layer
|
||||
"""
|
||||
if method == "upsample2d":
|
||||
kwargs: T.Dict[str, T.Union[str, int]] = {}
|
||||
kwargs: dict[str, str | int] = {}
|
||||
if upsamples:
|
||||
kwargs["size"] = upsamples
|
||||
if interpolation:
|
||||
|
@ -571,7 +563,7 @@ def _get_curve(start_y: int,
|
|||
end_y: int,
|
||||
num_points: int,
|
||||
scale: float,
|
||||
mode: Literal["full", "cap_max", "cap_min"] = "full") -> T.List[int]:
|
||||
mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]:
|
||||
""" Obtain a curve.
|
||||
|
||||
For the given start and end y values, return the y co-ordinates of a curve for the given
|
||||
|
@ -660,13 +652,13 @@ class Encoder(): # pylint:disable=too-few-public-methods
|
|||
config: dict
|
||||
The model configuration options
|
||||
"""
|
||||
def __init__(self, input_shape: T.Tuple[int, ...], config: dict) -> None:
|
||||
def __init__(self, input_shape: tuple[int, ...], config: dict) -> None:
|
||||
self.input_shape = input_shape
|
||||
self._config = config
|
||||
self._input_shape = input_shape
|
||||
|
||||
@property
|
||||
def _model_kwargs(self) -> T.Dict[str, T.Dict[str, T.Union[str, bool]]]:
|
||||
def _model_kwargs(self) -> dict[str, dict[str, str | bool]]:
|
||||
""" dict: Configuration option for architecture mapped to optional kwargs. """
|
||||
return {"mobilenet": {"alpha": self._config["mobilenet_width"],
|
||||
"depth_multiplier": self._config["mobilenet_depth"],
|
||||
|
@ -677,7 +669,7 @@ class Encoder(): # pylint:disable=too-few-public-methods
|
|||
"include_preprocessing": False}}
|
||||
|
||||
@property
|
||||
def _selected_model(self) -> T.Tuple[_EncoderInfo, dict]:
|
||||
def _selected_model(self) -> tuple[_EncoderInfo, dict]:
|
||||
""" tuple(dict, :class:`_EncoderInfo`): The selected encoder model and it's associated
|
||||
keyword arguments """
|
||||
arch = self._config["enc_architecture"]
|
||||
|
@ -832,7 +824,7 @@ class FullyConnected(): # pylint:disable=too-few-public-methods
|
|||
The user configuration dictionary
|
||||
"""
|
||||
def __init__(self,
|
||||
side: Literal["a", "b", "both", "gblock", "shared"],
|
||||
side: T.Literal["a", "b", "both", "gblock", "shared"],
|
||||
input_shape: tuple,
|
||||
config: dict) -> None:
|
||||
logger.debug("Initializing: %s (side: %s, input_shape: %s)",
|
||||
|
@ -992,12 +984,12 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods
|
|||
and the Decoder. ``None`` will generate the full Upscale chain. An end index of -1 will
|
||||
generate the layers from the starting index to the final upscale. Default: ``None``
|
||||
"""
|
||||
_filters: T.List[int] = []
|
||||
_filters: list[int] = []
|
||||
|
||||
def __init__(self,
|
||||
side: Literal["a", "b", "both", "shared"],
|
||||
side: T.Literal["a", "b", "both", "shared"],
|
||||
config: dict,
|
||||
layer_indicies: T.Optional[T.Tuple[int, int]] = None) -> None:
|
||||
layer_indicies: tuple[int, int] | None = None) -> None:
|
||||
logger.debug("Initializing: %s (side: %s, layer_indicies: %s)",
|
||||
self.__class__.__name__, side, layer_indicies)
|
||||
self._side = side
|
||||
|
@ -1126,7 +1118,7 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods
|
|||
relu_alpha=0.2)(var_x)
|
||||
return var_x
|
||||
|
||||
def __call__(self, inputs: T.Union[Tensor, T.List[Tensor]]) -> T.Union[Tensor, T.List[Tensor]]:
|
||||
def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]:
|
||||
""" Upscale Network.
|
||||
|
||||
Parameters
|
||||
|
@ -1203,8 +1195,8 @@ class GBlock(): # pylint:disable=too-few-public-methods
|
|||
The user configuration dictionary
|
||||
"""
|
||||
def __init__(self,
|
||||
side: Literal["a", "b", "both"],
|
||||
input_shapes: T.Union[list, tuple],
|
||||
side: T.Literal["a", "b", "both"],
|
||||
input_shapes: list | tuple,
|
||||
config: dict) -> None:
|
||||
logger.debug("Initializing: %s (side: %s, input_shapes: %s)",
|
||||
self.__class__.__name__, side, input_shapes)
|
||||
|
@ -1284,8 +1276,8 @@ class Decoder(): # pylint:disable=too-few-public-methods
|
|||
The user configuration dictionary
|
||||
"""
|
||||
def __init__(self,
|
||||
side: Literal["a", "b", "both"],
|
||||
input_shape: T.Tuple[int, int, int],
|
||||
side: T.Literal["a", "b", "both"],
|
||||
input_shape: tuple[int, int, int],
|
||||
config: dict) -> None:
|
||||
logger.debug("Initializing: %s (side: %s, input_shape: %s)",
|
||||
self.__class__.__name__, side, input_shape)
|
||||
|
|
|
@ -39,8 +39,6 @@
|
|||
" the value saved in the state file with the updated value in config. If not
|
||||
" provided this will default to True.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
|
||||
_HELPTEXT: str = (
|
||||
"Phaze-A Model by TorzDF, with thanks to BirbFakes.\n"
|
||||
|
@ -48,7 +46,7 @@ _HELPTEXT: str = (
|
|||
"inspiration from Nvidia's StyleGAN for the Decoder. It is highly recommended to research to "
|
||||
"understand the parameters better.")
|
||||
|
||||
_ENCODERS: List[str] = sorted([
|
||||
_ENCODERS: list[str] = sorted([
|
||||
"densenet121", "densenet169", "densenet201", "efficientnet_b0", "efficientnet_b1",
|
||||
"efficientnet_b2", "efficientnet_b3", "efficientnet_b4", "efficientnet_b5", "efficientnet_b6",
|
||||
"efficientnet_b7", "efficientnet_v2_b0", "efficientnet_v2_b1", "efficientnet_v2_b2",
|
||||
|
|
|
@ -9,7 +9,6 @@ with "original" unique code split out to the original plugin.
|
|||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import typing as T
|
||||
|
||||
|
@ -23,23 +22,19 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
|
|||
from lib.image import hex_to_rgb
|
||||
from lib.training import PreviewDataGenerator, TrainingDataGenerator
|
||||
from lib.training.generator import BatchType, DataGenerator
|
||||
from lib.utils import FaceswapError, get_folder, get_image_paths, get_tf_version
|
||||
from lib.utils import FaceswapError, get_folder, get_image_paths
|
||||
from plugins.train._config import Config
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
from collections.abc import Callable, Generator
|
||||
from plugins.train.model._base import ModelBase
|
||||
from lib.config import ConfigValueType
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import get_args, Literal
|
||||
else:
|
||||
from typing import get_args, Literal
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _get_config(plugin_name: str,
|
||||
configfile: T.Optional[str] = None) -> T.Dict[str, ConfigValueType]:
|
||||
configfile: str | None = None) -> dict[str, ConfigValueType]:
|
||||
""" Return the configuration for the requested trainer.
|
||||
|
||||
Parameters
|
||||
|
@ -80,9 +75,9 @@ class TrainerBase():
|
|||
|
||||
def __init__(self,
|
||||
model: ModelBase,
|
||||
images: T.Dict[Literal["a", "b"], T.List[str]],
|
||||
images: dict[T.Literal["a", "b"], list[str]],
|
||||
batch_size: int,
|
||||
configfile: T.Optional[str]) -> None:
|
||||
configfile: str | None) -> None:
|
||||
logger.debug("Initializing %s: (model: '%s', batch_size: %s)",
|
||||
self.__class__.__name__, model, batch_size)
|
||||
self._model = model
|
||||
|
@ -111,7 +106,7 @@ class TrainerBase():
|
|||
self._images)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def _get_config(self, configfile: T.Optional[str]) -> T.Dict[str, ConfigValueType]:
|
||||
def _get_config(self, configfile: str | None) -> dict[str, ConfigValueType]:
|
||||
""" Get the saved training config options. Override any global settings with the setting
|
||||
provided from the model's saved config.
|
||||
|
||||
|
@ -173,10 +168,9 @@ class TrainerBase():
|
|||
self._samples.toggle_mask_display()
|
||||
|
||||
def train_one_step(self,
|
||||
viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
|
||||
timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
|
||||
"input_b",
|
||||
"output"], str]]) -> None:
|
||||
viewer: Callable[[np.ndarray, str], None] | None,
|
||||
timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
|
||||
str] | None) -> None:
|
||||
""" Running training on a batch of images for each side.
|
||||
|
||||
Triggered from the training cycle in :class:`scripts.train.Train`.
|
||||
|
@ -217,7 +211,7 @@ class TrainerBase():
|
|||
model_inputs, model_targets = self._feeder.get_batch()
|
||||
|
||||
try:
|
||||
loss: T.List[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
|
||||
loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
|
||||
except tf_errors.ResourceExhaustedError as err:
|
||||
msg = ("You do not have enough GPU memory available to train the selected model at "
|
||||
"the selected settings. You can try a number of things:"
|
||||
|
@ -236,7 +230,7 @@ class TrainerBase():
|
|||
self._model.snapshot()
|
||||
self._update_viewers(viewer, timelapse_kwargs)
|
||||
|
||||
def _log_tensorboard(self, loss: T.List[float]) -> None:
|
||||
def _log_tensorboard(self, loss: list[float]) -> None:
|
||||
""" Log current loss to Tensorboard log files
|
||||
|
||||
Parameters
|
||||
|
@ -250,19 +244,18 @@ class TrainerBase():
        logs = {log[0]: log[1]
                for log in zip(self._model.state.loss_names, loss)}

        if get_tf_version() > (2, 7):
            # Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
            # ref: https://github.com/keras-team/keras/issues/16173
            with tf.summary.record_if(True), self._tensorboard._train_writer.as_default():  # noqa pylint:disable=protected-access,not-context-manager
                for name, value in logs.items():
                    tf.summary.scalar(
                        "batch_" + name,
                        value,
                        step=self._tensorboard._train_step)  # pylint:disable=protected-access
        else:
            self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
        # Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
        # ref: https://github.com/keras-team/keras/issues/16173
        with tf.summary.record_if(True), self._tensorboard._train_writer.as_default():  # noqa:E501 pylint:disable=protected-access,not-context-manager
            for name, value in logs.items():
                tf.summary.scalar(
                    "batch_" + name,
                    value,
                    step=self._tensorboard._train_step)  # pylint:disable=protected-access
        # TODO revert this code if fixed in tensorflow
        # self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
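The replacement block above works around the Keras batch-metrics regression by writing per-batch scalars through TensorBoard's writer directly. A self-contained sketch of the same pattern with a standalone summary writer; the log path and loss values are invented, and the real code reuses the callback's private writer and step counter (hence the protected-access comments):

# Stand-alone version of the pattern above: write per-batch scalars straight to
# a summary writer. The log directory and loss values here are made up.
import tensorflow as tf

writer = tf.summary.create_file_writer("logs/train")   # assumption: any writable path
logs = {"total_a": 0.041, "total_b": 0.039}

step = 1
with tf.summary.record_if(True), writer.as_default():
    for name, value in logs.items():
        tf.summary.scalar("batch_" + name, value, step=step)
writer.flush()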
|
||||
|
||||
def _collate_and_store_loss(self, loss: T.List[float]) -> T.List[float]:
|
||||
def _collate_and_store_loss(self, loss: list[float]) -> list[float]:
|
||||
""" Collate the loss into totals for each side.
|
||||
|
||||
The losses are summed into a total for each side. Loss totals are added to
|
||||
|
@ -298,7 +291,7 @@ class TrainerBase():
|
|||
logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore
|
||||
return combined_loss
|
||||
|
||||
def _print_loss(self, loss: T.List[float]) -> None:
|
||||
def _print_loss(self, loss: list[float]) -> None:
|
||||
""" Outputs the loss for the current iteration to the console.
|
||||
|
||||
Parameters
|
||||
|
@ -318,10 +311,9 @@ class TrainerBase():
|
|||
"line: %s, error: %s", output, str(err))
|
||||
|
||||
def _update_viewers(self,
|
||||
viewer: T.Optional[T.Callable[[np.ndarray, str], None]],
|
||||
timelapse_kwargs: T.Optional[T.Dict[Literal["input_a",
|
||||
"input_b",
|
||||
"output"], str]]) -> None:
|
||||
viewer: Callable[[np.ndarray, str], None] | None,
|
||||
timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"],
|
||||
str] | None) -> None:
|
||||
""" Update the preview viewer and timelapse output
|
||||
|
||||
Parameters
|
||||
|
@ -371,10 +363,10 @@ class _Feeder():
|
|||
The configuration for this trainer
|
||||
"""
|
||||
def __init__(self,
|
||||
images: T.Dict[Literal["a", "b"], T.List[str]],
|
||||
images: dict[T.Literal["a", "b"], list[str]],
|
||||
model: ModelBase,
|
||||
batch_size: int,
|
||||
config: T.Dict[str, ConfigValueType]) -> None:
|
||||
config: dict[str, ConfigValueType]) -> None:
|
||||
logger.debug("Initializing %s: num_images: %s, batch_size: %s, config: %s)",
|
||||
self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size,
|
||||
config)
|
||||
|
@ -383,16 +375,16 @@ class _Feeder():
|
|||
self._batch_size = batch_size
|
||||
self._config = config
|
||||
self._feeds = {side: self._load_generator(side, False).minibatch_ab()
|
||||
for side in get_args(Literal["a", "b"])}
|
||||
for side in T.get_args(T.Literal["a", "b"])}
|
||||
|
||||
self._display_feeds = {"preview": self._set_preview_feed(), "timelapse": {}}
|
||||
logger.debug("Initialized %s:", self.__class__.__name__)
|
||||
|
||||
def _load_generator(self,
|
||||
side: Literal["a", "b"],
|
||||
side: T.Literal["a", "b"],
|
||||
is_display: bool,
|
||||
batch_size: T.Optional[int] = None,
|
||||
images: T.Optional[T.List[str]] = None) -> DataGenerator:
|
||||
batch_size: int | None = None,
|
||||
images: list[str] | None = None) -> DataGenerator:
|
||||
""" Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder.
|
||||
|
||||
Parameters
|
||||
|
@ -424,7 +416,7 @@ class _Feeder():
|
|||
self._batch_size if batch_size is None else batch_size)
|
||||
return retval
|
||||
|
||||
def _set_preview_feed(self) -> T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]]:
|
||||
def _set_preview_feed(self) -> dict[T.Literal["a", "b"], Generator[BatchType, None, None]]:
|
||||
""" Set the preview feed for this feeder.
|
||||
|
||||
Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically
|
||||
|
@ -436,10 +428,10 @@ class _Feeder():
|
|||
The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as
|
||||
value.
|
||||
"""
|
||||
retval: T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]] = {}
|
||||
retval: dict[T.Literal["a", "b"], Generator[BatchType, None, None]] = {}
|
||||
num_images = self._config.get("preview_images", 14)
|
||||
assert isinstance(num_images, int)
|
||||
for side in get_args(Literal["a", "b"]):
|
||||
for side in T.get_args(T.Literal["a", "b"]):
|
||||
logger.debug("Setting preview feed: (side: '%s')", side)
|
||||
preview_images = min(max(num_images, 2), 16)
|
||||
batchsize = min(len(self._images[side]), preview_images)
|
||||
|
@ -448,7 +440,7 @@ class _Feeder():
|
|||
batch_size=batchsize).minibatch_ab()
|
||||
return retval
|
||||
|
||||
def get_batch(self) -> T.Tuple[T.List[T.List[np.ndarray]], ...]:
|
||||
def get_batch(self) -> tuple[list[list[np.ndarray]], ...]:
|
||||
""" Get the feed data and the targets for each training side for feeding into the model's
|
||||
train function.
|
||||
|
||||
|
@ -459,8 +451,8 @@ class _Feeder():
|
|||
model_targets: list
|
||||
The targets for the model for each side A and B
|
||||
"""
|
||||
model_inputs: T.List[T.List[np.ndarray]] = []
|
||||
model_targets: T.List[T.List[np.ndarray]] = []
|
||||
model_inputs: list[list[np.ndarray]] = []
|
||||
model_targets: list[list[np.ndarray]] = []
|
||||
for side in ("a", "b"):
|
||||
side_feed, side_targets = next(self._feeds[side])
|
||||
if self._model.config["learn_mask"]: # Add the face mask as it's own target
|
||||
|
@@ -473,7 +465,7 @@ class _Feeder():
return model_inputs, model_targets

def generate_preview(self, is_timelapse: bool = False
) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
""" Generate the images for preview window or timelapse

Parameters
@@ -490,15 +482,15 @@ class _Feeder():
"""
logger.debug("Generating preview (is_timelapse: %s)", is_timelapse)

batchsizes: T.List[int] = []
feed: T.Dict[Literal["a", "b"], np.ndarray] = {}
samples: T.Dict[Literal["a", "b"], np.ndarray] = {}
masks: T.Dict[Literal["a", "b"], np.ndarray] = {}
batchsizes: list[int] = []
feed: dict[T.Literal["a", "b"], np.ndarray] = {}
samples: dict[T.Literal["a", "b"], np.ndarray] = {}
masks: dict[T.Literal["a", "b"], np.ndarray] = {}

# MyPy can't recurse into nested dicts to get the type :(
iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
self._display_feeds["timelapse" if is_timelapse else "preview"])
for side in get_args(Literal["a", "b"]):
for side in T.get_args(T.Literal["a", "b"]):
side_feed, side_samples = next(iterator[side])
batchsizes.append(len(side_samples[0]))
samples[side] = side_samples[0]
@@ -513,10 +505,10 @@ class _Feeder():

def compile_sample(self,
image_count: int,
feed: T.Dict[Literal["a", "b"], np.ndarray],
samples: T.Dict[Literal["a", "b"], np.ndarray],
masks: T.Dict[Literal["a", "b"], np.ndarray]
) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]:
feed: dict[T.Literal["a", "b"], np.ndarray],
samples: dict[T.Literal["a", "b"], np.ndarray],
masks: dict[T.Literal["a", "b"], np.ndarray]
) -> dict[T.Literal["a", "b"], list[np.ndarray]]:
""" Compile the preview samples for display.

Parameters
@@ -542,8 +534,8 @@ class _Feeder():
num_images = self._config.get("preview_images", 14)
assert isinstance(num_images, int)
num_images = min(image_count, num_images)
retval: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
for side in get_args(Literal["a", "b"]):
retval: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
for side in T.get_args(T.Literal["a", "b"]):
logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images)
retval[side] = [feed[side][0:num_images],
samples[side][0:num_images],
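One detail in the `generate_preview` hunk above: `Generator` stays bare inside annotations but becomes a quoted string inside `T.cast`. A plausible reading (assuming, as the diff's style suggests, that `Generator` is imported from `collections.abc` only under `T.TYPE_CHECKING` and the module uses `from __future__ import annotations`) is sketched below with hypothetical names:

    # Sketch under the stated assumptions: annotations are never evaluated, so a bare
    # Generator[...] is fine there, but the first argument to T.cast IS evaluated at
    # runtime and must therefore remain a string when the name exists only for the
    # type checker.
    from __future__ import annotations

    import typing as T

    if T.TYPE_CHECKING:
        from collections.abc import Generator  # not imported at runtime

    def first_item(feeds: dict[str, Generator[int, None, None]]) -> int:
        feed = T.cast("Generator[int, None, None]", feeds["preview"])
        return next(feed)

    print(first_item({"preview": (i for i in range(3))}))   # 0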
@@ -552,7 +544,7 @@ class _Feeder():
return retval

def set_timelapse_feed(self,
images: T.Dict[Literal["a", "b"], T.List[str]],
images: dict[T.Literal["a", "b"], list[str]],
batch_size: int) -> None:
""" Set the time-lapse feed for this feeder.

@@ -570,10 +562,10 @@ class _Feeder():
images, batch_size)

# MyPy can't recurse into nested dicts to get the type :(
iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]],
iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"],
self._display_feeds["timelapse"])

for side in get_args(Literal["a", "b"]):
for side in T.get_args(T.Literal["a", "b"]):
imgs = images[side]
logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs))

@@ -615,7 +607,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
self.__class__.__name__, model, coverage_ratio, mask_opacity, mask_color)
self._model = model
self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"]
self.images: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {}
self.images: dict[T.Literal["a", "b"], list[np.ndarray]] = {}
self._coverage_ratio = coverage_ratio
self._mask_opacity = mask_opacity / 100.0
self._mask_color = np.array(hex_to_rgb(mask_color))[..., 2::-1] / 255.
@@ -639,8 +631,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
A compiled preview image ready for display or saving
"""
logger.debug("Showing sample")
feeds: T.Dict[Literal["a", "b"], np.ndarray] = {}
for idx, side in enumerate(get_args(Literal["a", "b"])):
feeds: dict[T.Literal["a", "b"], np.ndarray] = {}
for idx, side in enumerate(T.get_args(T.Literal["a", "b"])):
feed = self.images[side][0]
input_shape = self._model.model.input_shape[idx][1:]
if input_shape[0] / feed.shape[1] != 1.0:
@@ -653,7 +645,7 @@ class _Samples(): # pylint:disable=too-few-public-methods

@classmethod
def _resize_sample(cls,
side: Literal["a", "b"],
side: T.Literal["a", "b"],
sample: np.ndarray,
target_size: int) -> np.ndarray:
""" Resize a given image to the target size.
@@ -684,7 +676,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape)
return retval

def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> T.Dict[str, np.ndarray]:
def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> dict[str, np.ndarray]:
""" Feed the samples to the model and return predictions

Parameters
@@ -700,7 +692,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` of predictions received from the model
"""
logger.debug("Getting Predictions")
preds: T.Dict[str, np.ndarray] = {}
preds: dict[str, np.ndarray] = {}
standard = self._model.model.predict([feed_a, feed_b], verbose=0)
swapped = self._model.model.predict([feed_b, feed_a], verbose=0)

@@ -719,7 +711,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()})
return preds

def _compile_preview(self, predictions: T.Dict[str, np.ndarray]) -> np.ndarray:
def _compile_preview(self, predictions: dict[str, np.ndarray]) -> np.ndarray:
""" Compile predictions and images into the final preview image.

Parameters
@@ -732,8 +724,8 @@ class _Samples(): # pylint:disable=too-few-public-methods
:class:`numpy.ndarry`
A compiled preview image ready for display or saving
"""
figures: T.Dict[Literal["a", "b"], np.ndarray] = {}
headers: T.Dict[Literal["a", "b"], np.ndarray] = {}
figures: dict[T.Literal["a", "b"], np.ndarray] = {}
headers: dict[T.Literal["a", "b"], np.ndarray] = {}

for side, samples in self.images.items():
other_side = "a" if side == "b" else "b"
@@ -761,9 +753,9 @@ class _Samples(): # pylint:disable=too-few-public-methods
return np.clip(figure * 255, 0, 255).astype('uint8')

def _to_full_frame(self,
side: Literal["a", "b"],
samples: T.List[np.ndarray],
predictions: T.List[np.ndarray]) -> T.List[np.ndarray]:
side: T.Literal["a", "b"],
samples: list[np.ndarray],
predictions: list[np.ndarray]) -> list[np.ndarray]:
""" Patch targets and prediction images into images of model output size.

Parameters
@@ -803,10 +795,10 @@ class _Samples(): # pylint:disable=too-few-public-methods
return images

def _process_full(self,
side: Literal["a", "b"],
side: T.Literal["a", "b"],
images: np.ndarray,
prediction_size: int,
color: T.Tuple[float, float, float]) -> np.ndarray:
color: tuple[float, float, float]) -> np.ndarray:
""" Add a frame overlay to preview images indicating the region of interest.

This applies the red border that appears in the preview images.
@@ -847,7 +839,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
logger.debug("Overlayed background. Shape: %s", images.shape)
return images

def _compile_masked(self, faces: T.List[np.ndarray], masks: np.ndarray) -> T.List[np.ndarray]:
def _compile_masked(self, faces: list[np.ndarray], masks: np.ndarray) -> list[np.ndarray]:
""" Add the mask to the faces for masked preview.

Places an opaque red layer over areas of the face that are masked out.
@@ -866,7 +858,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
List of :class:`numpy.ndarray` faces with the opaque mask layer applied
"""
orig_masks = 1 - np.rint(masks)
masks3: T.Union[T.List[np.ndarray], np.ndarray] = []
masks3: list[np.ndarray] | np.ndarray = []

if faces[-1].shape[-1] == 4: # Mask contained in alpha channel of predictions
pred_masks = [1 - np.rint(face[..., -1])[..., None] for face in faces[-2:]]
@@ -875,7 +867,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
else:
masks3 = np.repeat(np.expand_dims(orig_masks, axis=0), 3, axis=0)

retval: T.List[np.ndarray] = []
retval: list[np.ndarray] = []
alpha = 1.0 - self._mask_opacity
for previews, compiled_masks in zip(faces, masks3):
overlays = previews.copy()
@@ -910,7 +902,7 @@ class _Samples(): # pylint:disable=too-few-public-methods
return backgrounds

@classmethod
def _get_headers(cls, side: Literal["a", "b"], width: int) -> np.ndarray:
def _get_headers(cls, side: T.Literal["a", "b"], width: int) -> np.ndarray:
""" Set header row for the final preview frame

Parameters
@@ -958,8 +950,8 @@ class _Samples(): # pylint:disable=too-few-public-methods

@classmethod
def _duplicate_headers(cls,
headers: T.Dict[Literal["a", "b"], np.ndarray],
columns: int) -> T.Dict[Literal["a", "b"], np.ndarray]:
headers: dict[T.Literal["a", "b"], np.ndarray],
columns: int) -> dict[T.Literal["a", "b"], np.ndarray]:
""" Duplicate headers for the number of columns displayed for each side.

Parameters
@@ -1008,7 +1000,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
mask_opacity: int,
mask_color: str,
feeder: _Feeder,
image_paths: T.Dict[Literal["a", "b"], T.List[str]]) -> None:
image_paths: dict[T.Literal["a", "b"], list[str]]) -> None:
logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, "
"mask_opacity: %s, mask_color: %s, feeder: %s, image_paths: %s)",
self.__class__.__name__, model, coverage_ratio, image_count, mask_opacity,
@@ -1042,8 +1034,8 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
logger.debug("Time-lapse output set to '%s'", self._output_file)

# Rewrite paths to pull from the training images so mask and face data can be accessed
images: T.Dict[Literal["a", "b"], T.List[str]] = {}
for side, input_ in zip(get_args(Literal["a", "b"]), (input_a, input_b)):
images: dict[T.Literal["a", "b"], list[str]] = {}
for side, input_ in zip(T.get_args(T.Literal["a", "b"]), (input_a, input_b)):
training_path = os.path.dirname(self._image_paths[side][0])
images[side] = [os.path.join(training_path, os.path.basename(pth))
for pth in get_image_paths(input_)]
@@ -1054,7 +1046,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
self._feeder.set_timelapse_feed(images, batchsize)
logger.debug("Set up time-lapse")

def output_timelapse(self, timelapse_kwargs: T.Dict[Literal["input_a",
def output_timelapse(self, timelapse_kwargs: dict[T.Literal["input_a",
"input_b",
"output"], str]) -> None:
""" Generate the time-lapse samples and output the created time-lapse to the specified
@@ -1068,7 +1060,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods
"""
logger.debug("Ouputting time-lapse")
if not self._output_file:
self._setup(**T.cast(T.Dict[str, str], timelapse_kwargs))
self._setup(**T.cast(dict[str, str], timelapse_kwargs))

logger.debug("Getting time-lapse samples")
self._samples.images = self._feeder.generate_preview(is_timelapse=True)
@@ -1,19 +1,14 @@
tqdm>=4.64
# TESTED WITH PY3.10
tqdm>=4.65
psutil>=5.9.0
numexpr>=2.7.3; python_version < '3.9' # >=2.8.0 conflicts in Conda
numexpr>=2.8.3; python_version >= '3.9'
opencv-python>=4.6.0.0
pillow>=9.2.0
scikit-learn==1.0.2; python_version < '3.9' # AMD needs version 1.0.2 and 1.1.0 not available in Python 3.7
scikit-learn>=1.1.0; python_version >= '3.9'
numexpr>=2.8.4
numpy>=1.25.0
opencv-python>=4.7.0.0
pillow>=9.4.0
scikit-learn>=1.2.2
fastcluster>=1.2.6
matplotlib>=3.4.3,<3.6.0; python_version < '3.9' # >=3.5.0 conflicts in Conda
matplotlib>=3.5.1,<3.6.0; python_version >= '3.9'
imageio>=2.19.3
imageio-ffmpeg>=0.4.7
matplotlib>=3.7.1
imageio>=2.26.0
imageio-ffmpeg>=0.4.8
ffmpy>=0.3.0
# Exclude badly numbered Python2 version of nvidia-ml-py
nvidia-ml-py>=11.515,<300
tensorflow-probability<0.17
typing-extensions>=4.0.0
pywin32>=228 ; sys_platform == "win32"
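The `; python_version ...` suffixes removed above are PEP 508 environment markers, which pip evaluates against the running interpreter before installing a line. A small illustration, using the third-party `packaging` library (not a dependency introduced by this change), of how one of the dropped markers evaluates once the floor is Python 3.10:

    # Illustration only: evaluate one of the removed markers with packaging
    # (assumes the "packaging" package is available in the environment).
    from packaging.markers import Marker

    old_line = "numexpr>=2.7.3; python_version < '3.9'"
    requirement, marker = (part.strip() for part in old_line.split(";", 1))

    # On a 3.10+ interpreter the marker can never be true, which is why the
    # pre-3.9 pins are removed outright rather than kept behind markers.
    print(requirement, Marker(marker).evaluate())   # numexpr>=2.7.3 False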
@@ -1,11 +1,7 @@
protobuf>= 3.19.0,<3.20.0 # TF has started pulling in incompatible protobuf
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-macos>=2.8.0,<2.11.0
tensorflow-deps>=2.8.0,<2.11.0
tensorflow-metal>=0.4.0,<0.7.0
libblas # Conda only
-r _requirements_base.txt
tensorflow-macos>=2.10.0,<2.11.0
tensorflow-deps>=2.10.0,<2.11.0
tensorflow-metal>=0.6.0,<0.7.0
# These next 2 should have been installed, but some users complain of errors
decorator
cloudpickle

@@ -1,5 +1,2 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.7.0,<2.11.0
tensorflow-cpu>=2.10.0,<2.11.0

@@ -1,7 +1,4 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-cpu>=2.10.0,<2.11.0
tensorflow-directml-plugin
comtypes

@@ -1,6 +1,5 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-gpu>=2.7.0,<2.11.0
# Exclude badly numbered Python2 version of nvidia-ml-py
nvidia-ml-py>=11.525,<300
pynvx==1.0.0 ; sys_platform == "darwin"
tensorflow>=2.10.0,<2.11.0

@@ -1,5 +1,2 @@
-r _requirements_base.txt
# Pinned TF probability doesn't work with numpy >= 1.24
numpy>=1.21.0,<1.24.0; python_version < '3.8'
numpy>=1.22.0,<1.24.0; python_version >= '3.8'
tensorflow-rocm>=2.10.0,<2.11.0
@@ -26,13 +26,9 @@ from lib.utils import FaceswapError, get_folder, get_image_paths
from plugins.extract.pipeline import Extractor, ExtractMedia
from plugins.plugin_loader import PluginLoader

if sys.version_info < (3, 8):
from typing_extensions import get_args, Literal
else:
from typing import get_args, Literal

if T.TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Callable
from plugins.convert.writer._base import Output
from plugins.train.model._base import ModelBase
from lib.align.aligned_face import CenteringType
@@ -61,8 +57,8 @@ class ConvertItem:
The swapped faces returned from the model's predict function
"""
inbound: ExtractMedia
feed_faces: T.List[AlignedFace] = field(default_factory=list)
reference_faces: T.List[AlignedFace] = field(default_factory=list)
feed_faces: list[AlignedFace] = field(default_factory=list)
reference_faces: list[AlignedFace] = field(default_factory=list)
swapped_faces: np.ndarray = np.array([])


@@ -307,8 +303,8 @@ class DiskIO():
# Extractor for on the fly detection
self._extractor = self._load_extractor()

self._queues: T.Dict[Literal["load", "save"], EventQueue] = {}
self._threads: T.Dict[Literal["load", "save"], MultiThread] = {}
self._queues: dict[T.Literal["load", "save"], EventQueue] = {}
self._threads: dict[T.Literal["load", "save"], MultiThread] = {}
self._init_threads()
logger.debug("Initialized %s", self.__class__.__name__)

@@ -324,13 +320,13 @@ class DiskIO():
return self._writer.config.get("draw_transparent", False)

@property
def pre_encode(self) -> T.Optional[T.Callable[[np.ndarray], T.List[bytes]]]:
def pre_encode(self) -> Callable[[np.ndarray], list[bytes]] | None:
""" python function: Selected writer's pre-encode function, if it has one,
otherwise ``None`` """
dummy = np.zeros((20, 20, 3), dtype="uint8")
test = self._writer.pre_encode(dummy)
retval: T.Optional[T.Callable[[np.ndarray],
T.List[bytes]]] = None if test is None else self._writer.pre_encode
retval: Callable[[np.ndarray],
list[bytes]] | None = None if test is None else self._writer.pre_encode
logger.debug("Writer pre_encode function: %s", retval)
return retval

@@ -384,7 +380,7 @@ class DiskIO():
return PluginLoader.get_converter("writer", self._args.writer)(*args,
configfile=configfile)

def _get_frame_ranges(self) -> T.Optional[T.List[T.Tuple[int, int]]]:
def _get_frame_ranges(self) -> list[tuple[int, int]] | None:
""" Obtain the frame ranges that are to be converted.

If frame ranges have been specified, then split the command line formatted arguments into
@@ -422,7 +418,7 @@ class DiskIO():
logger.debug("frame ranges: %s", retval)
return retval

def _load_extractor(self) -> T.Optional[Extractor]:
def _load_extractor(self) -> Extractor | None:
""" Load the CV2-DNN Face Extractor Chain.

For On-The-Fly conversion we use a CPU based extractor to avoid stacking the GPU.
@@ -467,12 +463,12 @@ class DiskIO():
Creates the load and save queues and the load and save threads. Starts the threads.
"""
logger.debug("Initializing DiskIO Threads")
for task in get_args(Literal["load", "save"]):
for task in T.get_args(T.Literal["load", "save"]):
self._add_queue(task)
self._start_thread(task)
logger.debug("Initialized DiskIO Threads")

def _add_queue(self, task: Literal["load", "save"]) -> None:
def _add_queue(self, task: T.Literal["load", "save"]) -> None:
""" Add the queue to queue_manager and to :attr:`self._queues` for the given task.

Parameters
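`T.get_args(T.Literal["load", "save"])` in the `_init_threads` hunk above returns the literal's values as a tuple, so the run-time loop and the static `Literal` annotations are driven by a single definition. A minimal sketch with hypothetical names:

    # Sketch only: get_args() on a Literal alias yields its values, keeping the loop
    # and the annotations in sync from one definition.
    import typing as T

    Task = T.Literal["load", "save"]

    queues: dict[Task, list[str]] = {}
    for task in T.get_args(Task):      # ("load", "save")
        queues[task] = []

    print(queues)                      # {'load': [], 'save': []}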
@@ -490,7 +486,7 @@ class DiskIO():
self._queues[task] = queue_manager.get_queue(q_name)
logger.debug("Added queue for task: '%s'", task)

def _start_thread(self, task: Literal["load", "save"]) -> None:
def _start_thread(self, task: T.Literal["load", "save"]) -> None:
""" Create the thread for the given task, add it it :attr:`self._threads` and start it.

Parameters
@@ -571,7 +567,7 @@ class DiskIO():
logger.trace("idx: %s, skipframe: %s", idx, skipframe) # type: ignore
return skipframe

def _get_detected_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
def _get_detected_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
""" Return the detected faces for the given image.

If we have an alignments file, then the detected faces are created from that file. If
@@ -597,7 +593,7 @@ class DiskIO():
logger.trace("Got %s faces for: '%s'", len(detected_faces), filename) # type:ignore
return detected_faces

def _alignments_faces(self, frame_name: str, image: np.ndarray) -> T.List[DetectedFace]:
def _alignments_faces(self, frame_name: str, image: np.ndarray) -> list[DetectedFace]:
""" Return detected faces from an alignments file.

Parameters
@@ -644,7 +640,7 @@ class DiskIO():
tqdm.write(f"No alignment found for {frame_name}, skipping")
return have_alignments

def _detect_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]:
def _detect_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]:
""" Extract the face from a frame for On-The-Fly conversion.

Pulls detected faces out of the Extraction pipeline.
@@ -779,7 +775,7 @@ class Predict():
""" int: The size in pixels of the Faceswap model output. """
return self._sizes["output"]

def _get_io_sizes(self) -> T.Dict[str, int]:
def _get_io_sizes(self) -> dict[str, int]:
""" Obtain the input size and output size of the model.

Returns
@@ -896,9 +892,9 @@ class Predict():
"""
faces_seen = 0
consecutive_no_faces = 0
batch: T.List[ConvertItem] = []
batch: list[ConvertItem] = []
while True:
item: T.Union[Literal["EOF"], ConvertItem] = self._in_queue.get()
item: T.Literal["EOF"] | ConvertItem = self._in_queue.get()
if item == "EOF":
logger.debug("EOF Received")
if batch: # Process out any remaining items
@@ -938,7 +934,7 @@ class Predict():
self._out_queue.put("EOF")
logger.debug("Load queue complete")

def _process_batch(self, batch: T.List[ConvertItem], faces_seen: int):
def _process_batch(self, batch: list[ConvertItem], faces_seen: int):
""" Predict faces on the given batch of images and queue out to patch thread

Parameters
@@ -1001,7 +997,7 @@ class Predict():
logger.trace("Loaded aligned faces: '%s'", item.inbound.filename) # type:ignore

@staticmethod
def _compile_feed_faces(feed_faces: T.List[AlignedFace]) -> np.ndarray:
def _compile_feed_faces(feed_faces: list[AlignedFace]) -> np.ndarray:
""" Compile a batch of faces for feeding into the Predictor.

Parameters
@@ -1020,7 +1016,7 @@ class Predict():
logger.trace("Compiled Feed faces. Shape: %s", retval.shape) # type:ignore
return retval

def _predict(self, feed_faces: np.ndarray, batch_size: T.Optional[int] = None) -> np.ndarray:
def _predict(self, feed_faces: np.ndarray, batch_size: int | None = None) -> np.ndarray:
""" Run the Faceswap models' prediction function.

Parameters
@@ -1045,7 +1041,7 @@ class Predict():
logger.trace("Input shape(s): %s", [item.shape for item in feed]) # type:ignore

inbound = self._model.model.predict(feed, verbose=0, batch_size=batch_size)
predicted: T.List[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]
predicted: list[np.ndarray] = inbound if isinstance(inbound, list) else [inbound]

if self._model.color_order.lower() == "rgb":
predicted[0] = predicted[0][..., ::-1]
@@ -1062,7 +1058,7 @@ class Predict():
logger.trace("Final shape: %s", retval.shape) # type:ignore
return retval

def _queue_out_frames(self, batch: T.List[ConvertItem], swapped_faces: np.ndarray) -> None:
def _queue_out_frames(self, batch: list[ConvertItem], swapped_faces: np.ndarray) -> None:
""" Compile the batch back to original frames and put to the Out Queue.

For batching, faces are split away from their frames. This compiles all detected faces
@@ -1108,7 +1104,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
"""
def __init__(self,
arguments: Namespace,
input_images: T.List[np.ndarray],
input_images: list[np.ndarray],
alignments: Alignments) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._args = arguments
@@ -1131,7 +1127,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
self._alignments.filter_faces(accept_dict, filter_out=False)
logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count)

def _get_face_metadata(self) -> T.Dict[str, T.List[int]]:
def _get_face_metadata(self) -> dict[str, list[int]]:
""" Check for the existence of an aligned directory for identifying which faces in the
target frames should be swapped. If it exists, scan the folder for face's metadata

@@ -1140,7 +1136,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods
dict
Dictionary of source frame names with a list of associated face indices to be skipped
"""
retval: T.Dict[str, T.List[int]] = {}
retval: dict[str, list[int]] = {}
input_aligned_dir = self._args.input_aligned_dir

if input_aligned_dir is None:
@@ -2,13 +2,13 @@
""" Main entry point to the extract process of FaceSwap """

from __future__ import annotations

import logging
import os
import sys
import typing as T

from argparse import Namespace
from multiprocessing import Process
from typing import List, Dict, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
from tqdm import tqdm
@@ -20,7 +20,7 @@ from lib.utils import get_folder, _image_extensions, _video_extensions
from plugins.extract.pipeline import Extractor, ExtractMedia
from scripts.fsmedia import Alignments, PostProcess, finalize

if TYPE_CHECKING:
if T.TYPE_CHECKING:
from lib.align.alignments import PNGHeaderAlignmentsDict

# tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning # TODO?
@@ -75,7 +75,7 @@ class Extract(): # pylint:disable=too-few-public-methods
self._args.nfilter,
self._extractor)

def _get_input_locations(self) -> List[str]:
def _get_input_locations(self) -> list[str]:
""" Obtain the full path to input locations. Will be a list of locations if batch mode is
selected, or a containing a single location if batch mode is not selected.

@@ -194,8 +194,8 @@ class Filter():
"""
def __init__(self,
threshold: float,
filter_files: Optional[List[str]],
nfilter_files: Optional[List[str]],
filter_files: list[str] | None,
nfilter_files: list[str] | None,
extractor: Extractor) -> None:
logger.debug("Initializing %s: (threshold: %s, filter_files: %s, nfilter_files: %s "
"extractor: %s)", self.__class__.__name__, threshold, filter_files,
@@ -208,8 +208,8 @@ class Filter():
logger.debug("Filter not selected. Exiting %s", self.__class__.__name__)
return

self._embeddings: List[np.ndarray] = [np.array([]) for _ in self._filter_files]
self._nembeddings: List[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
self._embeddings: list[np.ndarray] = [np.array([]) for _ in self._filter_files]
self._nembeddings: list[np.ndarray] = [np.array([]) for _ in self._nfilter_files]
self._extractor = extractor

self._get_embeddings()
@@ -243,7 +243,7 @@ class Filter():
return retval

@classmethod
def _files_from_folder(cls, input_location: List[str]) -> List[str]:
def _files_from_folder(cls, input_location: list[str]) -> list[str]:
""" Test whether the input location is a folder and if so, return the list of contained
image files, otherwise return the original input location

@@ -274,8 +274,8 @@ class Filter():
return retval

def _validate_inputs(self,
filter_files: Optional[List[str]],
nfilter_files: Optional[List[str]]) -> Tuple[List[str], List[str]]:
filter_files: list[str] | None,
nfilter_files: list[str] | None) -> tuple[list[str], list[str]]:
""" Validates that the given filter/nfilter files exist, are image files and are unique

Parameters
@@ -293,7 +293,7 @@ class Filter():
List of full paths to nfilter files
"""
error = False
retval: List[List[str]] = []
retval: list[list[str]] = []

for files in (filter_files, nfilter_files):
filt_files = [] if files is None else self._files_from_folder(files)
@@ -322,7 +322,7 @@ class Filter():
return filters, nfilters

@classmethod
def _identity_from_extracted(cls, filename) -> Tuple[np.ndarray, bool]:
def _identity_from_extracted(cls, filename) -> tuple[np.ndarray, bool]:
""" Test whether the given image is a faceswap extracted face and contains identity
information. If so, return the identity embedding

@@ -404,7 +404,7 @@ class Filter():
embeddings[idx] = identities
return

def _identity_from_extractor(self, file_list: List[str], aligned: List[str]) -> None:
def _identity_from_extractor(self, file_list: list[str], aligned: list[str]) -> None:
""" Obtain the identity embeddings from the extraction pipeline

Parameters
@@ -425,7 +425,7 @@ class Filter():

for phase in range(self._extractor.passes):
is_final = self._extractor.final_pass
detected_faces: Dict[str, ExtractMedia] = {}
detected_faces: dict[str, ExtractMedia] = {}
self._extractor.launch()
desc = "Obtaining reference face Identity"
if self._extractor.passes > 1:
@@ -450,8 +450,8 @@ class Filter():

def _get_embeddings(self) -> None:
""" Obtain the embeddings for the given filter lists """
needs_extraction: List[str] = []
aligned: List[str] = []
needs_extraction: list[str] = []
aligned: list[str] = []

for files, embed in zip((self._filter_files, self._nfilter_files),
(self._embeddings, self._nembeddings)):
@@ -494,14 +494,14 @@ class PipelineLoader():
image files that exist in :attr:`path` that are aligned faceswap images
"""
def __init__(self,
path: Union[str, List[str]],
path: str | list[str],
extractor: Extractor,
aligned_filenames: Optional[List[str]] = None) -> None:
aligned_filenames: list[str] | None = None) -> None:
logger.debug("Initializing %s: (path: %s, extractor: %s, aligned_filenames: %s)",
self.__class__.__name__, path, extractor, aligned_filenames)
self._images = ImagesLoader(path, fast_count=True)
self._extractor = extractor
self._threads: List[MultiThread] = []
self._threads: list[MultiThread] = []
self._aligned_filenames = [] if aligned_filenames is None else aligned_filenames
logger.debug("Initialized %s", self.__class__.__name__)

@@ -512,7 +512,7 @@ class PipelineLoader():
return self._images.is_video

@property
def file_list(self) -> List[str]:
def file_list(self) -> list[str]:
""" list: A full list of files in the source location. If the input is a video
then this is a list of dummy filenames as corresponding to an alignments file """
return self._images.file_list
@@ -523,7 +523,7 @@ class PipelineLoader():
items that are to be skipped from the :attr:`skip_list`)"""
return self._images.process_count

def add_skip_list(self, skip_list: List[int]) -> None:
def add_skip_list(self, skip_list: list[int]) -> None:
""" Add a skip list to the :class:`ImagesLoader`

Parameters
@@ -538,7 +538,7 @@ class PipelineLoader():
""" Launch the image loading pipeline """
self._threaded_redirector("load")

def reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
def reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
""" Reload images for multiple pipeline passes """
self._threaded_redirector("reload", (detected_faces, ))

@@ -552,7 +552,7 @@ class PipelineLoader():
for thread in self._threads:
thread.join()

def _threaded_redirector(self, task: str, io_args: Optional[tuple] = None) -> None:
def _threaded_redirector(self, task: str, io_args: tuple | None = None) -> None:
""" Redirect image input/output tasks to relevant queues in background thread

Parameters
@@ -587,7 +587,7 @@ class PipelineLoader():
load_queue.put("EOF")
logger.debug("Load Images: Complete")

def _reload(self, detected_faces: Dict[str, ExtractMedia]) -> None:
def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None:
""" Reload the images and pair to detected face

When the extraction pipeline is running in serial mode, images are reloaded from disk,
@@ -652,7 +652,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
logger.debug("Initialized %s", self.__class__.__name__)

@property
def _save_interval(self) -> Optional[int]:
def _save_interval(self) -> int | None:
""" int: The number of frames to be processed between each saving of the alignments file if
it has been provided, otherwise ``None`` """
if hasattr(self._args, "save_interval"):
@@ -718,7 +718,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
as_bytes=True)
for phase in range(self._extractor.passes):
is_final = self._extractor.final_pass
detected_faces: Dict[str, ExtractMedia] = {}
detected_faces: dict[str, ExtractMedia] = {}
self._extractor.launch()
self._loader.check_thread_error()
ph_desc = "Extraction" if self._extractor.passes == 1 else self._extractor.phase_text
@@ -774,7 +774,7 @@ class _Extract(): # pylint:disable=too-few-public-methods
if not self._verify_output and faces_count > 1:
self._verify_output = True

def _output_faces(self, saver: Optional[ImagesSaver], extract_media: ExtractMedia) -> None:
def _output_faces(self, saver: ImagesSaver | None, extract_media: ExtractMedia) -> None:
""" Output faces to save thread

Set the face filename based on the frame name and put the face to the
@@ -798,14 +798,14 @@ class _Extract(): # pylint:disable=too-few-public-methods
output_filename = f"{filename}_{real_face_id}.png"
aligned = face.aligned.face
assert aligned is not None
meta: PNGHeaderDict = dict(
alignments=face.to_png_meta(),
source=dict(alignments_version=self._alignments.version,
original_filename=output_filename,
face_index=real_face_id,
source_filename=os.path.basename(extract_media.filename),
source_is_video=self._loader.is_video,
source_frame_dims=extract_media.image_size))
meta: PNGHeaderDict = {
"alignments": face.to_png_meta(),
"source": {"alignments_version": self._alignments.version,
"original_filename": output_filename,
"face_index": real_face_id,
"source_filename": os.path.basename(extract_media.filename),
"source_is_video": self._loader.is_video,
"source_frame_dims": extract_media.image_size}}
image = encode_image(aligned, ".png", metadata=meta)

sub_folder = extract_media.sub_folders[face_id]
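The `meta` rewrite above swaps `dict(...)` keyword construction for a dict literal. Assuming `PNGHeaderDict` is a TypedDict (it is annotated like one here, and defined elsewhere in the repository), a rough illustration with a hypothetical TypedDict of why the literal spelling is friendlier to type checkers:

    # Rough illustration only (SourceDict is hypothetical, not the repository's
    # PNGHeaderDict): type checkers match dict literals (or the TypedDict
    # constructor) against a TypedDict annotation, while a plain dict(...) call is
    # inferred as an ordinary dict and flagged as incompatible.
    import typing as T

    class SourceDict(T.TypedDict):
        original_filename: str
        face_index: int

    meta_ok: SourceDict = {"original_filename": "frame_0.png", "face_index": 0}
    # meta_bad: SourceDict = dict(original_filename="frame_0.png", face_index=0)
    #   ^ mypy reports an incompatible assignment here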
@@ -820,6 +820,6 @@ class _Extract(): # pylint:disable=too-few-public-methods
continue
final_faces.append(face.to_alignment())

self._alignments.data[os.path.basename(extract_media.filename)] = dict(faces=final_faces,
video_meta={})
self._alignments.data[os.path.basename(extract_media.filename)] = {"faces": final_faces,
"video_meta": {}}
del extract_media
Some files were not shown because too many files have changed in this diff.