From 6a3b674bef8b585347a51ffdd8765d0a1eaa15fc Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Tue, 27 Jun 2023 11:27:47 +0100 Subject: [PATCH] Rebase code (#1326) * Remove tensorflow_probability requirement * setup.py - fix progress bars * requirements.txt: Remove pre python 3.9 packages * update apple requirements.txt * update INSTALL.md * Remove python<3.9 code * setup.py - fix Windows Installer * typing: python3.9 compliant * Update pytest and readthedocs python versions * typing fixes * Python Version updates - Reduce max version to 3.10 - Default to 3.10 in installers - Remove incompatible 3.11 tests * Update dependencies * Downgrade imageio dep for Windows * typing: merge optional unions and fixes * Updates - min python version 3.10 - typing to python 3.10 spec - remove pre-tf2.10 code - Add conda tests * train: re-enable optimizer saving * Update dockerfiles * Update setup.py - Apple Conda deps to setup.py - Better Cuda + dependency handling * bugfix: Patch logging to prevent Autograph errors * Update dockerfiles * Setup.py - Setup.py - stdout to utf-8 * Add more OSes to github Actions * suppress mac-os end to end test --- .github/workflows/pytest.yml | 87 ++- .install/linux/faceswap_setup_x64.sh | 4 +- .install/windows/install.nsi | 2 +- .readthedocs.yml | 4 +- Dockerfile.cpu | 24 +- Dockerfile.gpu | 42 +- INSTALL.md | 217 ++---- docs/conf.py | 2 +- docs/sphinx_requirements.txt | 36 +- faceswap.py | 5 +- lib/align/aligned_face.py | 364 +++++------ lib/align/alignments.py | 114 ++-- lib/align/detected_face.py | 116 ++-- lib/cli/actions.py | 18 +- lib/cli/args.py | 41 +- lib/cli/launcher.py | 22 +- lib/config.py | 54 +- lib/convert.py | 61 +- lib/gpu_stats/_base.py | 33 +- lib/gpu_stats/apple_silicon.py | 10 +- lib/gpu_stats/cpu.py | 16 +- lib/gpu_stats/directml.py | 36 +- lib/gpu_stats/nvidia.py | 9 +- lib/gpu_stats/nvidia_apple.py | 8 +- lib/gpu_stats/rocm.py | 15 +- lib/gui/analysis/event_reader.py | 84 ++- lib/gui/analysis/stats.py | 79 +-- lib/gui/control_helper.py | 64 +- lib/gui/display_command.py | 18 +- lib/gui/display_graph.py | 45 +- lib/gui/menu.py | 12 +- lib/gui/popup_configure.py | 21 +- lib/gui/popup_session.py | 63 +- lib/gui/utils/config.py | 63 +- lib/gui/utils/file_handler.py | 171 +++-- lib/gui/utils/image.py | 62 +- lib/gui/utils/misc.py | 26 +- lib/image.py | 16 +- lib/keras_utils.py | 2 +- lib/logger.py | 24 +- lib/model/autoclip.py | 55 +- lib/model/losses/feature_loss.py | 21 +- lib/model/losses/loss.py | 20 +- lib/model/losses/perceptual_loss.py | 23 +- lib/model/nets.py | 8 +- lib/model/nn_blocks.py | 36 +- lib/model/session.py | 19 +- lib/multithreading.py | 39 +- lib/queue_manager.py | 3 +- lib/sysinfo.py | 24 +- lib/training/__init__.py | 8 +- lib/training/augmentation.py | 11 +- lib/training/cache.py | 60 +- lib/training/generator.py | 108 ++- lib/training/preview_cv.py | 38 +- lib/training/preview_tk.py | 47 +- lib/utils.py | 53 +- locales/plugins.train._config.pot | 118 ++-- .../ru/LC_MESSAGES/plugins.train._config.mo | Bin 53422 -> 56243 bytes .../ru/LC_MESSAGES/plugins.train._config.po | 107 ++- plugins/convert/mask/mask_blend.py | 29 +- plugins/convert/writer/_base.py | 13 +- plugins/convert/writer/ffmpeg.py | 31 +- plugins/convert/writer/gif.py | 22 +- plugins/convert/writer/opencv.py | 10 +- plugins/convert/writer/pillow.py | 10 +- plugins/extract/_base.py | 87 ++- plugins/extract/align/_base/aligner.py | 50 +- plugins/extract/align/_base/processing.py | 68 +- plugins/extract/align/cv2_dnn.py | 33 +- 
plugins/extract/align/fan.py | 21 +- plugins/extract/detect/_base.py | 57 +- plugins/extract/detect/mtcnn.py | 50 +- plugins/extract/detect/s3fd.py | 20 +- plugins/extract/mask/_base.py | 23 +- plugins/extract/mask/bisenet_fp.py | 10 +- plugins/extract/mask/components.py | 9 +- plugins/extract/mask/extended.py | 9 +- plugins/extract/mask/unet_dfl.py | 4 +- plugins/extract/mask/vgg_clear.py | 6 +- plugins/extract/mask/vgg_obstructed.py | 4 +- plugins/extract/pipeline.py | 158 +++-- plugins/extract/recognition/_base.py | 46 +- plugins/extract/recognition/vgg_face2.py | 30 +- plugins/plugin_loader.py | 36 +- plugins/train/_config.py | 25 + plugins/train/model/_base/io.py | 25 +- plugins/train/model/_base/model.py | 69 +- plugins/train/model/_base/settings.py | 56 +- plugins/train/model/phaze_a.py | 74 +-- plugins/train/model/phaze_a_defaults.py | 4 +- plugins/train/trainer/_base.py | 166 +++-- requirements/_requirements_base.txt | 25 +- requirements/requirements_apple_silicon.txt | 12 +- requirements/requirements_cpu.txt | 5 +- requirements/requirements_directml.txt | 3 - requirements/requirements_nvidia.txt | 7 +- requirements/requirements_rocm.txt | 3 - scripts/convert.py | 58 +- scripts/extract.py | 76 +-- scripts/fsmedia.py | 56 +- scripts/train.py | 75 +-- setup.cfg | 2 - setup.py | 618 ++++++++++-------- tests/lib/gpu_stats/_base_test.py | 7 +- tests/lib/gui/stats/event_reader_test.py | 13 +- tests/lib/model/optimizers_test.py | 7 - tests/lib/sysinfo_test.py | 6 +- tests/lib/utils_test.py | 39 +- tests/simple_tests.py | 7 +- tests/tools/alignments/media_test.py | 67 +- tests/tools/preview/viewer_test.py | 28 +- tests/utils.py | 7 - tools.py | 8 +- tools/alignments/alignments.py | 36 +- tools/alignments/cli.py | 5 +- tools/alignments/jobs.py | 69 +- tools/alignments/jobs_faces.py | 84 ++- tools/alignments/jobs_frames.py | 66 +- tools/alignments/media.py | 87 ++- tools/mask/mask.py | 54 +- tools/model/cli.py | 4 +- tools/model/model.py | 4 +- tools/preview/cli.py | 4 +- tools/preview/control_panels.py | 73 ++- tools/preview/preview.py | 63 +- tools/preview/viewer.py | 35 +- tools/sort/sort.py | 48 +- tools/sort/sort_methods.py | 127 ++-- tools/sort/sort_methods_aligned.py | 32 +- 130 files changed, 3035 insertions(+), 3028 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 7fc82d3c..badd8f4f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -8,17 +8,85 @@ on: - "**/README.md" jobs: - build_linux: + build_conda: + name: conda (${{ matrix.os }}, ${{ matrix.backend }}) + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest", "macos-latest", "windows-latest"] + backend: ["nvidia", "cpu"] + include: + - os: "ubuntu-latest" + backend: "rocm" + - os: "windows-latest" + backend: "directml" + exclude: + # pynvx does not currently build for Python3.10 and without CUDA it may not build at all + - os: "macos-latest" + backend: "nvidia" + steps: + - uses: actions/checkout@v3 + - name: Set cache date + run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV + - name: Cache conda + uses: actions/cache@v3 + env: + # Increase this value to manually reset cache + CACHE_NUMBER: 1 + REQ_FILE: ./requirements/requirements_${{ matrix.backend }}.txt + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-${{ matrix.backend }}-conda-${{ env.CACHE_NUMBER }}-${{ env.DATE }}-${{ hashFiles('./requirements/requirements.txt', env.REQ_FILE) }} + - name: Set up Conda + uses: 
conda-incubator/setup-miniconda@v2 + with: + python-version: "3.10" + auto-update-conda: true + activate-environment: faceswap + - name: Conda info + run: conda info && conda list + - name: Install + run: | + python setup.py --installer --${{ matrix.backend }} + pip install flake8 pylint mypy pytest pytest-mock wheel pytest-xvfb + pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --select=E9,F63,F7,F82 --show-source + flake8 . --exit-zero + - name: MyPy Typing + continue-on-error: true + run: | + mypy . + - name: SysInfo + run: python -c "from lib.sysinfo import sysinfo ; print(sysinfo)" + - name: Simple Tests + # These backends will fail as GPU drivers not available + if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml' + run: | + FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/; + - name: End to End Tests + # These backends will fail as GPU drivers not available + # macOS fails on first extract test with 'died with ' + if: matrix.backend != 'rocm' && matrix.backend != 'nvidia' && matrix.backend != 'directml' && matrix.os != 'macos-latest' + run: | + FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py; + build_linux: + name: "pip (ubuntu-latest, ${{ matrix.backend }})" runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.10"] backend: ["cpu"] include: - - kbackend: "tensorflow" - backend: "cpu" + - backend: "cpu" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -33,6 +101,8 @@ jobs: pip install flake8 pylint mypy pytest pytest-mock pytest-xvfb wheel pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools pip install -r ./requirements/requirements_${{ matrix.backend }}.txt + - name: List installed packages + run: pip freeze - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -45,17 +115,18 @@ jobs: mypy . 
- name: Simple Tests run: | - FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" py.test -v tests/; + FACESWAP_BACKEND="${{ matrix.backend }}" py.test -v tests/; - name: End to End Tests run: | - FACESWAP_BACKEND="${{ matrix.backend }}" KERAS_BACKEND="${{ matrix.kbackend }}" python tests/simple_tests.py; + FACESWAP_BACKEND="${{ matrix.backend }}" python tests/simple_tests.py; build_windows: + name: "pip (windows-latest, ${{ matrix.backend }})" runs-on: windows-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9"] + python-version: ["3.10"] backend: ["cpu", "directml"] include: - backend: "cpu" @@ -74,6 +145,8 @@ jobs: pip install flake8 pylint mypy pytest pytest-mock wheel pip install types-attrs types-cryptography types-pyOpenSSL types-PyYAML types-setuptools pip install -r ./requirements/requirements_${{ matrix.backend }}.txt + - name: List installed packages + run: pip freeze - name: Set Backend EnvVar run: echo "FACESWAP_BACKEND=${{ matrix.backend }}" | Out-File -FilePath $env:GITHUB_ENV -Append - name: Lint with flake8 diff --git a/.install/linux/faceswap_setup_x64.sh b/.install/linux/faceswap_setup_x64.sh index 1e60b41c..21bbec63 100644 --- a/.install/linux/faceswap_setup_x64.sh +++ b/.install/linux/faceswap_setup_x64.sh @@ -12,7 +12,7 @@ DIR_CONDA="$HOME/miniconda3" CONDA_EXECUTABLE="${DIR_CONDA}/bin/conda" CONDA_TO_PATH=false ENV_NAME="faceswap" -PYENV_VERSION="3.9" +PYENV_VERSION="3.10" DIR_FACESWAP="$HOME/faceswap" VERSION="nvidia" @@ -363,7 +363,7 @@ delete_env() { } create_env() { - # Create Python 3.8 env for faceswap + # Create Python 3.10 env for faceswap delete_env info "Creating Conda Virtual Environment..." yellow ; "$CONDA_EXECUTABLE" create -n "$ENV_NAME" -q python="$PYENV_VERSION" -y diff --git a/.install/windows/install.nsi b/.install/windows/install.nsi index d70ddb5b..d4078078 100644 --- a/.install/windows/install.nsi +++ b/.install/windows/install.nsi @@ -22,7 +22,7 @@ InstallDir $PROFILE\faceswap # Install cli flags !define flagsConda "/S /RegisterPython=0 /AddToPath=0 /D=$PROFILE\MiniConda3" !define flagsRepo "--depth 1 --no-single-branch ${wwwRepo}" -!define flagsEnv "-y python=3.9" +!define flagsEnv "-y python=3.10" # Folders Var ProgramData diff --git a/.readthedocs.yml b/.readthedocs.yml index 2aa3c993..8d199514 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 8b9d2977..0c27ec9b 100755 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,19 +1,19 @@ -FROM tensorflow/tensorflow:2.8.2 +FROM ubuntu:22.04 # To disable tzdata and others from asking for input ENV DEBIAN_FRONTEND noninteractive +ENV FACESWAP_BACKEND cpu -RUN apt-get update -qq -y \ - && apt-get install -y software-properties-common \ - && add-apt-repository -y ppa:jonathonf/ffmpeg-4 \ - && apt-get update -qq -y \ - && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk ffmpeg git \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* +RUN apt-get update -qq -y +RUN apt-get upgrade -y +RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git -COPY ./requirements/_requirements_base.txt /opt/ -RUN pip3 install --upgrade pip -RUN pip3 --no-cache-dir install -r /opt/_requirements_base.txt && rm /opt/_requirements_base.txt +RUN ln -s $(which 
python3) /usr/local/bin/python + +RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git +WORKDIR "/faceswap" + +RUN python -m pip install --upgrade pip +RUN python -m pip --no-cache-dir install -r ./requirements/requirements_cpu.txt -WORKDIR "/srv" CMD ["/bin/bash"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu index 078875f5..d62e010f 100755 --- a/Dockerfile.gpu +++ b/Dockerfile.gpu @@ -1,29 +1,19 @@ -FROM nvidia/cuda:11.7.0-runtime-ubuntu18.04 -ARG DEBIAN_FRONTEND=noninteractive +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 -#install python3.8 -RUN apt-get update -RUN apt-get install software-properties-common -y -RUN add-apt-repository ppa:deadsnakes/ppa -y -RUN apt-get update -RUN apt-get install python3.8 -y -RUN apt-get install python3.8-distutils -y -RUN apt-get install python3.8-tk -y -RUN apt-get install curl -y -RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py -RUN python3.8 get-pip.py -RUN rm get-pip.py +ENV DEBIAN_FRONTEND=noninteractive +ENV FACESWAP_BACKEND nvidia -# install requirements -RUN apt-get install ffmpeg git -y -COPY ./requirements/_requirements_base.txt /opt/ -COPY ./requirements/requirements_nvidia.txt /opt/ -RUN python3.8 -m pip --no-cache-dir install -r /opt/requirements_nvidia.txt && rm /opt/_requirements_base.txt && rm /opt/requirements_nvidia.txt +RUN apt-get update -qq -y +RUN apt-get upgrade -y +RUN apt-get install -y libgl1 libglib2.0-0 python3 python3-pip python3-tk git -RUN python3.8 -m pip install jupyter matplotlib tqdm -RUN python3.8 -m pip install jupyter_http_over_ws -RUN jupyter serverextension enable --py jupyter_http_over_ws -RUN alias python=python3.8 -RUN echo "alias python=python3.8" >> /root/.bashrc -WORKDIR "/notebooks" -CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"] +RUN ln -s $(which python3) /usr/local/bin/python + +RUN git clone --depth 1 --no-single-branch https://github.com/deepfakes/faceswap.git +WORKDIR "/faceswap" + +RUN python -m pip install --upgrade pip +RUN python -m pip install --upgrade pip +RUN python -m pip --no-cache-dir install -r ./requirements/requirements_nvidia.txt + +CMD ["/bin/bash"] diff --git a/INSTALL.md b/INSTALL.md index cae8acdd..24807eb7 100755 --- a/INSTALL.md +++ b/INSTALL.md @@ -39,12 +39,9 @@ - [Setup](#setup-2) - [About some of the options](#about-some-of-the-options) - [Docker Install Guide](#docker-install-guide) - - [Docker General](#docker-general) - - [CUDA with Docker in 20 minutes.](#cuda-with-docker-in-20-minutes) - - [CUDA with Docker on Arch Linux](#cuda-with-docker-on-arch-linux) - - [Install docker](#install-docker) - - [A successful setup log, without docker.](#a-successful-setup-log-without-docker) - - [Run the project](#run-the-project) + - [Docker CPU](#docker-cpu) + - [Docker Nvidia](#docker-nvidia) +- [Run the project](#run-the-project) - [Notes](#notes) # Prerequisites @@ -115,7 +112,7 @@ Reboot your PC, so that everything you have just installed gets registered. 
- Select "Create" at the bottom - In the pop up: - Give it the name: faceswap - - **IMPORTANT**: Select python version 3.8 + - **IMPORTANT**: Select python version 3.10 - Hit "Create" (NB: This may take a while as it will need to download Python) ![Anaconda virtual env setup](https://i.imgur.com/CLIDDfa.png) @@ -195,7 +192,7 @@ $ source ~/miniforge3/bin/activate ## Setup ### Create and Activate the Environment ```sh -$ conda create --name faceswap python=3.9 +$ conda create --name faceswap python=3.10 $ conda activate faceswap ``` @@ -225,7 +222,7 @@ Obtain git for your distribution from the [git website](https://git-scm.com/down The recommended install method is to use a Conda3 Environment as this will handle the installation of Nvidia's CUDA and cuDNN straight into your Conda Environment. This is by far the easiest and most reliable way to setup the project. - MiniConda3 is recommended: [MiniConda3](https://docs.conda.io/en/latest/miniconda.html) -Alternatively you can install Python (>= 3.7-3.9 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn). for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu). +Alternatively you can install Python (3.10 64-bit) for your distribution (links below.) If you go down this route and are using an Nvidia GPU you should install CUDA (https://developer.nvidia.com/cuda-zone) and cuDNN (https://developer.nvidia.com/cudnn). for your system. If you do not plan to build Tensorflow yourself, make sure you install the correct Cuda and cuDNN package for the currently installed version of Tensorflow (Current release: Tensorflow 2.9. Release v1.0: Tensorflow 1.15). You can check for the compatible versions here: (https://www.tensorflow.org/install/source#gpu). - Python distributions: - apt/yum install python3 (Linux) - [Installer](https://www.python.org/downloads/release/python-368/) (Windows) @@ -260,153 +257,83 @@ If setup fails for any reason you can still manually install the packages listed # Docker Install Guide -## Docker General -
- Click to expand! +This Faceswap repo contains Docker build scripts for CPU and Nvidia backends. The scripts will set up a Docker container for you and install the latest version of the Faceswap software. - ### CUDA with Docker in 20 minutes. - - 1. Install Docker - https://www.docker.com/community-edition +You must first ensure that Docker is installed and running on your system. Follow the guide for downloading and installing Docker from their website: - 2. Install Nvidia-Docker & Restart Docker Service - https://github.com/NVIDIA/nvidia-docker + - https://www.docker.com/get-started - 3. Build Docker Image For faceswap - - ```bash - docker build -t deepfakes-gpu -f Dockerfile.gpu . - ``` +Once Docker is installed and running, follow the relevant steps for your chosen backend +## Docker CPU +To run the CPU version of Faceswap follow these steps: - 4. Mount faceswap volume and Run it - a). without `gui.tools.py` gui not working. - - ```bash - nvidia-docker run --rm -it -p 8888:8888 \ - --hostname faceswap-gpu --name faceswap-gpu \ - -v /opt/faceswap:/srv \ - deepfakes-gpu - ``` - - b). with gui. tools.py gui working. - -Enable local access to X11 server - -```bash -xhost +local: +1. Build the Docker image For faceswap: ``` - -Enable nvidia device if working under bumblebee - -```bash -echo ON > /proc/acpi/bbswitch +docker build \ +-t faceswap-cpu \ +https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.cpu ``` +2. Launch and enter the Faceswap container: -Create container -```bash -nvidia-docker run -p 8888:8888 \ - --hostname faceswap-gpu --name faceswap-gpu \ - -v /opt/faceswap:/srv \ - -v /tmp/.X11-unix:/tmp/.X11-unix \ - -e DISPLAY=unix$DISPLAY \ - -e AUDIO_GID=`getent group audio | cut -d: -f3` \ - -e VIDEO_GID=`getent group video | cut -d: -f3` \ - -e GID=`id -g` \ - -e UID=`id -u` \ - deepfakes-gpu + a. For the **headless/command line** version of Faceswap run: + ``` + docker run --rm -it faceswap-cpu + ``` + You can then execute faceswap the standard way: + ``` + python faceswap.py --help + ``` + b. For the **GUI** version of Faceswap run: + ``` + xhost +local: && \ + docker run --rm -it \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY=${DISPLAY} \ + faceswap-cpu + ``` + You can then launch the GUI with + ``` + python faceswap.py gui + ``` + ## Docker Nvidia +To build the NVIDIA GPU version of Faceswap, follow these steps: +1. Nvidia Docker builds need extra resources to provide the Docker container with access to your GPU. + + a. Follow the instructions to install and apply the `Nvidia Container Toolkit` for your distribution from: + - https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html + + b. If Docker is already running, restart it to pick up the changes made by the Nvidia Container Toolkit. + +2. Build the Docker image For faceswap ``` - -Open a new terminal to interact with the project - -```bash -docker exec -it deepfakes-gpu /bin/bash +docker build \ +-t faceswap-gpu \ +https://raw.githubusercontent.com/deepfakes/faceswap/master/Dockerfile.gpu ``` +1. Launch and enter the Faceswap container: -Launch deepfakes gui (Answer 3 for NVIDIA at the prompt) - -```bash -python3.8 /srv/faceswap.py gui -``` -
- -## CUDA with Docker on Arch Linux - -
- Click to expand! - -### Install docker - -```bash -sudo pacman -S docker -``` - -The steps are same but Arch linux doesn't use nvidia-docker - -create container - -```bash -docker run -p 8888:8888 --gpus all --privileged -v /dev:/dev \ - --hostname faceswap-gpu --name faceswap-gpu \ - -v /mnt/hdd2/faceswap:/srv \ - -v /tmp/.X11-unix:/tmp/.X11-unix \ - -e DISPLAY=unix$DISPLAY \ - -e AUDIO_GID=`getent group audio | cut -d: -f3` \ - -e VIDEO_GID=`getent group video | cut -d: -f3` \ - -e GID=`id -g` \ - -e UID=`id -u` \ - deepfakes-gpu -``` - -Open a new terminal to interact with the project - -```bash -docker exec -it deepfakes-gpu /bin/bash -``` - -Launch deepfakes gui (Answer 3 for NVIDIA at the prompt) - -**With `gui.tools.py` gui working.** - Enable local access to X11 server - - ```bash -xhost +local: -``` - - ```bash - python3.8 /srv/faceswap.py gui - ``` - -
- ---- -## A successful setup log, without docker. -``` -INFO The tool provides tips for installation - and installs required python packages -INFO Setup in Linux 4.14.39-1-MANJARO -INFO Installed Python: 3.7.5 64bit -INFO Installed PIP: 10.0.1 -Enable Docker? [Y/n] n -INFO Docker Disabled -Enable CUDA? [Y/n] -INFO CUDA Enabled -INFO CUDA version: 9.1 -INFO cuDNN version: 7 -WARNING Tensorflow has no official prebuild for CUDA 9.1 currently. - To continue, You have to build your own tensorflow-gpu. - Help: https://www.tensorflow.org/install/install_sources -Are System Dependencies met? [y/N] y -INFO Installing Missing Python Packages... -INFO Installing tensorflow-gpu -...... -INFO Installing tqdm -INFO Installing matplotlib -INFO All python3 dependencies are met. - You are good to go. -``` - -## Run the project + a. For the **headless/command line** version of Faceswap run: + ``` + docker run --runtime=nvidia --rm -it faceswap-gpu + ``` + You can then execute faceswap the standard way: + ``` + python faceswap.py --help + ``` + b. For the **GUI** version of Faceswap run: + ``` + xhost +local: && \ + docker run --runtime=nvidia --rm -it \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY=${DISPLAY} \ + faceswap-gpu + ``` + You can then launch the GUI with + ``` + python faceswap.py gui + ``` +# Run the project Once all these requirements are installed, you can attempt to run the faceswap tools. Use the `-h` or `--help` options for a list of options. ```bash diff --git a/docs/conf.py b/docs/conf.py index 35c7a226..f5476558 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath('../')) sys.setrecursionlimit(1500) -MOCK_MODULES = ["plaidml", "pynvx", "ctypes.windll", "comtypes"] +MOCK_MODULES = ["pynvx", "ctypes.windll", "comtypes"] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/docs/sphinx_requirements.txt b/docs/sphinx_requirements.txt index b4cb179c..5c28e7c7 100755 --- a/docs/sphinx_requirements.txt +++ b/docs/sphinx_requirements.txt @@ -1,25 +1,21 @@ # NB Do not install from this requirements file # It is for documentation purposes only -sphinx==5.0.2 -sphinx_rtd_theme==1.0.0 -tqdm==4.64 -psutil==5.8.0 -numexpr>=2.8.3 -numpy>=1.18.0 -opencv-python>=4.5.5.0 -pillow==8.3.1 -scikit-learn>=1.0.2 -fastcluster>=1.2.4 -matplotlib==3.5.1 -numexpr -imageio==2.9.0 -imageio-ffmpeg==0.4.7 -ffmpy==0.2.3 -nvidia-ml-py<11.515 -plaidml==0.7.0 +sphinx==7.0.1 +sphinx_rtd_theme==1.2.2 +tqdm==4.65 +psutil==5.9.0 +numexpr>=2.8.4 +numpy>=1.25.0 +opencv-python>=4.7.0.0 +pillow==9.4.0 +scikit-learn>=1.2.2 +fastcluster>=1.2.6 +matplotlib==3.7.1 +imageio==2.31.1 +imageio-ffmpeg==0.4.8 +ffmpy==0.3.0 +nvidia-ml-py==11.525 pytest==7.2.0 pytest-mock==3.10.0 -tensorflow>=2.8.0,<2.9.0 -tensorflow_probability<0.17 -typing-extensions>=4.0.0 +tensorflow>=2.10.0,<2.11.0 diff --git a/faceswap.py b/faceswap.py index 3b6777e8..1189f2e5 100755 --- a/faceswap.py +++ b/faceswap.py @@ -16,9 +16,8 @@ from lib.config import generate_configs # pylint:disable=wrong-import-position _LANG = gettext.translation("faceswap", localedir="locales", fallback=True) _ = _LANG.gettext -if sys.version_info < (3, 7): - raise ValueError("This program requires at least python3.7") - +if sys.version_info < (3, 10): + raise ValueError("This program requires at least python 3.10") _PARSER = cli_args.FullHelpArgumentParser() diff --git a/lib/align/aligned_face.py b/lib/align/aligned_face.py index 5c9ff88a..a8e997fe 100644 --- a/lib/align/aligned_face.py +++ 
b/lib/align/aligned_face.py @@ -3,22 +3,14 @@ from dataclasses import dataclass, field import logging -import sys +import typing as T from threading import Lock -from typing import cast, Dict, Optional, Tuple - import cv2 import numpy as np -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -CenteringType = Literal["face", "head", "legacy"] +CenteringType = T.Literal["face", "head", "legacy"] _MEAN_FACE = np.array([[0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748], [0.300643, 0.034489], [0.403270, 0.077391], [0.596729, 0.077391], @@ -65,10 +57,10 @@ _MEAN_FACE_3D = np.array([[4.056931, -11.432347, 1.636229], # 8 chin LL [0.0, -8.601736, 6.097667], # 45 mouth bottom C [0.589441, -8.443925, 6.109526]]) # 44 mouth bottom L -_EXTRACT_RATIOS = dict(legacy=0.375, face=0.5, head=0.625) +_EXTRACT_RATIOS = {"legacy": 0.375, "face": 0.5, "head": 0.625} -def get_matrix_scaling(matrix: np.ndarray) -> Tuple[int, int]: +def get_matrix_scaling(matrix: np.ndarray) -> tuple[int, int]: """ Given a matrix, return the cv2 Interpolation method and inverse interpolation method for applying the matrix on an image. @@ -213,6 +205,149 @@ def get_centered_size(source_centering: CenteringType, return retval +class PoseEstimate(): + """ Estimates pose from a generic 3D head model for the given 2D face landmarks. + + Parameters + ---------- + landmarks: :class:`numpy.ndarry` + The original 68 point landmarks aligned to 0.0 - 1.0 range + + References + ---------- + Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/ + 3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp + """ + def __init__(self, landmarks: np.ndarray) -> None: + self._distortion_coefficients = np.zeros((4, 1)) # Assuming no lens distortion + self._xyz_2d: np.ndarray | None = None + + self._camera_matrix = self._get_camera_matrix() + self._rotation, self._translation = self._solve_pnp(landmarks) + self._offset = self._get_offset() + self._pitch_yaw_roll: tuple[float, float, float] = (0, 0, 0) + + @property + def xyz_2d(self) -> np.ndarray: + """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a + constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """ + if self._xyz_2d is None: + xyz = cv2.projectPoints(np.array([[6., 0., -2.3], + [0., 6., -2.3], + [0., 0., 3.7]]).astype("float32"), + self._rotation, + self._translation, + self._camera_matrix, + self._distortion_coefficients)[0].squeeze() + self._xyz_2d = xyz - self._offset["head"] + return self._xyz_2d + + @property + def offset(self) -> dict[CenteringType, np.ndarray]: + """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a + from the center of the face (between the eyes) or center of the head (middle of skull) + rather than the nose area. 
""" + return self._offset + + @property + def pitch(self) -> float: + """ float: The pitch of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[0] + + @property + def yaw(self) -> float: + """ float: The yaw of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[1] + + @property + def roll(self) -> float: + """ float: The roll of the aligned face in eular angles """ + if not any(self._pitch_yaw_roll): + self._get_pitch_yaw_roll() + return self._pitch_yaw_roll[2] + + def _get_pitch_yaw_roll(self) -> None: + """ Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """ + proj_matrix = np.zeros((3, 4), dtype="float32") + proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0] + euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1] + self._pitch_yaw_roll = T.cast(tuple[float, float, float], tuple(euler.squeeze())) + logger.trace("yaw_pitch: %s", self._pitch_yaw_roll) # type: ignore + + @classmethod + def _get_camera_matrix(cls) -> np.ndarray: + """ Obtain an estimate of the camera matrix based off the original frame dimensions. + + Returns + ------- + :class:`numpy.ndarray` + An estimated camera matrix + """ + focal_length = 4 + camera_matrix = np.array([[focal_length, 0, 0.5], + [0, focal_length, 0.5], + [0, 0, 1]], dtype="double") + logger.trace("camera_matrix: %s", camera_matrix) # type: ignore + return camera_matrix + + def _solve_pnp(self, landmarks: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ Solve the Perspective-n-Point for the given landmarks. + + Takes 2D landmarks in world space and estimates the rotation and translation vectors + in 3D space. + + Parameters + ---------- + landmarks: :class:`numpy.ndarry` + The original 68 point landmark co-ordinates relating to the original frame + + Returns + ------- + rotation: :class:`numpy.ndarray` + The solved rotation vector + translation: :class:`numpy.ndarray` + The solved translation vector + """ + points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34, + 35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]] + _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D, + points, + self._camera_matrix, + self._distortion_coefficients, + flags=cv2.SOLVEPNP_ITERATIVE) + logger.trace("points: %s, rotation: %s, translation: %s", # type: ignore + points, rotation, translation) + return rotation, translation + + def _get_offset(self) -> dict[CenteringType, np.ndarray]: + """ Obtain the offset between the original center of the extracted face to the new center + of the head in 2D space. + + Returns + ------- + :class:`numpy.ndarray` + The x, y offset of the new center from the old center. + """ + offset: dict[CenteringType, np.ndarray] = {"legacy": np.array([0.0, 0.0])} + points: dict[T.Literal["face", "head"], tuple[float, ...]] = {"head": (0.0, 0.0, -2.3), + "face": (0.0, -1.5, 4.2)} + + for key, pnts in points.items(): + center = cv2.projectPoints(np.array([pnts]).astype("float32"), + self._rotation, + self._translation, + self._camera_matrix, + self._distortion_coefficients)[0].squeeze() + logger.trace("center %s: %s", key, center) # type: ignore + offset[key] = center - (0.5, 0.5) + logger.trace("offset: %s", offset) # type: ignore + return offset + + @dataclass class _FaceCache: # pylint:disable=too-many-instance-attributes """ Cache for storing items related to a single aligned face. 
@@ -251,19 +386,19 @@ class _FaceCache: # pylint:disable=too-many-instance-attributes cropped_slices: dict, optional The slices for an input full head image and output cropped image. Default: `{}` """ - pose: Optional["PoseEstimate"] = None - original_roi: Optional[np.ndarray] = None - landmarks: Optional[np.ndarray] = None - landmarks_normalized: Optional[np.ndarray] = None + pose: PoseEstimate | None = None + original_roi: np.ndarray | None = None + landmarks: np.ndarray | None = None + landmarks_normalized: np.ndarray | None = None average_distance: float = 0.0 relative_eye_mouth_position: float = 0.0 - adjusted_matrix: Optional[np.ndarray] = None - interpolators: Tuple[int, int] = (0, 0) - cropped_roi: Dict[CenteringType, np.ndarray] = field(default_factory=dict) - cropped_slices: Dict[CenteringType, Dict[Literal["in", "out"], - Tuple[slice, slice]]] = field(default_factory=dict) + adjusted_matrix: np.ndarray | None = None + interpolators: tuple[int, int] = (0, 0) + cropped_roi: dict[CenteringType, np.ndarray] = field(default_factory=dict) + cropped_slices: dict[CenteringType, dict[T.Literal["in", "out"], + tuple[slice, slice]]] = field(default_factory=dict) - _locks: Dict[str, Lock] = field(default_factory=dict) + _locks: dict[str, Lock] = field(default_factory=dict) def __post_init__(self): """ Initialize the locks for the class parameters """ @@ -322,11 +457,11 @@ class AlignedFace(): """ def __init__(self, landmarks: np.ndarray, - image: Optional[np.ndarray] = None, + image: np.ndarray | None = None, centering: CenteringType = "face", size: int = 64, coverage_ratio: float = 1.0, - dtype: Optional[str] = None, + dtype: str | None = None, is_aligned: bool = False, is_legacy: bool = False) -> None: logger.trace("Initializing: %s (image shape: %s, centering: '%s', " # type: ignore @@ -340,9 +475,9 @@ class AlignedFace(): self._dtype = dtype self._is_aligned = is_aligned self._source_centering: CenteringType = "legacy" if is_legacy and is_aligned else "head" - self._matrices = dict(legacy=_umeyama(landmarks[17:], _MEAN_FACE, True)[0:2], - face=np.array([]), - head=np.array([])) + self._matrices = {"legacy": _umeyama(landmarks[17:], _MEAN_FACE, True)[0:2], + "face": np.array([]), + "head": np.array([])} self._padding = self._padding_from_coverage(size, coverage_ratio) self._cache = _FaceCache() @@ -353,7 +488,7 @@ class AlignedFace(): self._face if self._face is None else self._face.shape) @property - def centering(self) -> Literal["legacy", "head", "face"]: + def centering(self) -> T.Literal["legacy", "head", "face"]: """ str: The centering of the Aligned Face. One of `"legacy"`, `"head"`, `"face"`. """ return self._centering @@ -382,7 +517,7 @@ class AlignedFace(): return self._matrices[self._centering] @property - def pose(self) -> "PoseEstimate": + def pose(self) -> PoseEstimate: """ :class:`lib.align.PoseEstimate`: The estimated pose in 3D space. """ with self._cache.lock("pose"): if self._cache.pose is None: @@ -405,7 +540,7 @@ class AlignedFace(): return self._cache.adjusted_matrix @property - def face(self) -> Optional[np.ndarray]: + def face(self) -> np.ndarray | None: """ :class:`numpy.ndarray`: The aligned face at the given :attr:`size` at the specified :attr:`coverage` in the given :attr:`dtype`. If an :attr:`image` has not been provided then an the attribute will return ``None``. 
""" @@ -450,7 +585,7 @@ class AlignedFace(): return self._cache.landmarks_normalized @property - def interpolators(self) -> Tuple[int, int]: + def interpolators(self) -> tuple[int, int]: """ tuple: (`interpolator` and `reverse interpolator`) for the :attr:`adjusted matrix`. """ with self._cache.lock("interpolators"): if not any(self._cache.interpolators): @@ -487,7 +622,7 @@ class AlignedFace(): return self._cache.relative_eye_mouth_position @classmethod - def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> Dict[CenteringType, int]: + def _padding_from_coverage(cls, size: int, coverage_ratio: float) -> dict[CenteringType, int]: """ Return the image padding for a face from coverage_ratio set against a pre-padded training image. @@ -504,7 +639,7 @@ class AlignedFace(): The padding required, in pixels for 'head', 'face' and 'legacy' face types """ retval = {_type: round((size * (coverage_ratio - (1 - _EXTRACT_RATIOS[_type]))) / 2) - for _type in get_args(Literal["legacy", "face", "head"])} + for _type in T.get_args(T.Literal["legacy", "face", "head"])} logger.trace(retval) # type: ignore return retval @@ -532,7 +667,7 @@ class AlignedFace(): invert, points, retval) return retval - def extract_face(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]: + def extract_face(self, image: np.ndarray | None) -> np.ndarray | None: """ Extract the face from a source image and populate :attr:`face`. If an image is not provided then ``None`` is returned. @@ -605,7 +740,7 @@ class AlignedFace(): def _get_cropped_slices(self, image_size: int, target_size: int, - ) -> Dict[Literal["in", "out"], Tuple[slice, slice]]: + ) -> dict[T.Literal["in", "out"], tuple[slice, slice]]: """ Obtain the slices to turn a full head extract into an alternatively centered extract. Parameters @@ -676,149 +811,6 @@ class AlignedFace(): return self._cache.cropped_roi[centering] -class PoseEstimate(): - """ Estimates pose from a generic 3D head model for the given 2D face landmarks. - - Parameters - ---------- - landmarks: :class:`numpy.ndarry` - The original 68 point landmarks aligned to 0.0 - 1.0 range - - References - ---------- - Head Pose Estimation using OpenCV and Dlib - https://www.learnopencv.com/tag/solvepnp/ - 3D Model points - http://aifi.isr.uc.pt/Downloads/OpenGL/glAnthropometric3DModel.cpp - """ - def __init__(self, landmarks: np.ndarray) -> None: - self._distortion_coefficients = np.zeros((4, 1)) # Assuming no lens distortion - self._xyz_2d: Optional[np.ndarray] = None - - self._camera_matrix = self._get_camera_matrix() - self._rotation, self._translation = self._solve_pnp(landmarks) - self._offset = self._get_offset() - self._pitch_yaw_roll: Tuple[float, float, float] = (0, 0, 0) - - @property - def xyz_2d(self) -> np.ndarray: - """ :class:`numpy.ndarray` projected (x, y) coordinates for each x, y, z point at a - constant distance from adjusted center of the skull (0.5, 0.5) in the 2D space. """ - if self._xyz_2d is None: - xyz = cv2.projectPoints(np.array([[6., 0., -2.3], - [0., 6., -2.3], - [0., 0., 3.7]]).astype("float32"), - self._rotation, - self._translation, - self._camera_matrix, - self._distortion_coefficients)[0].squeeze() - self._xyz_2d = xyz - self._offset["head"] - return self._xyz_2d - - @property - def offset(self) -> Dict[CenteringType, np.ndarray]: - """ dict: The amount to offset a standard 0.0 - 1.0 umeyama transformation matrix for a - from the center of the face (between the eyes) or center of the head (middle of skull) - rather than the nose area. 
""" - return self._offset - - @property - def pitch(self) -> float: - """ float: The pitch of the aligned face in eular angles """ - if not any(self._pitch_yaw_roll): - self._get_pitch_yaw_roll() - return self._pitch_yaw_roll[0] - - @property - def yaw(self) -> float: - """ float: The yaw of the aligned face in eular angles """ - if not any(self._pitch_yaw_roll): - self._get_pitch_yaw_roll() - return self._pitch_yaw_roll[1] - - @property - def roll(self) -> float: - """ float: The roll of the aligned face in eular angles """ - if not any(self._pitch_yaw_roll): - self._get_pitch_yaw_roll() - return self._pitch_yaw_roll[2] - - def _get_pitch_yaw_roll(self) -> None: - """ Obtain the yaw, roll and pitch from the :attr:`_rotation` in eular angles. """ - proj_matrix = np.zeros((3, 4), dtype="float32") - proj_matrix[:3, :3] = cv2.Rodrigues(self._rotation)[0] - euler = cv2.decomposeProjectionMatrix(proj_matrix)[-1] - self._pitch_yaw_roll = cast(Tuple[float, float, float], tuple(euler.squeeze())) - logger.trace("yaw_pitch: %s", self._pitch_yaw_roll) # type: ignore - - @classmethod - def _get_camera_matrix(cls) -> np.ndarray: - """ Obtain an estimate of the camera matrix based off the original frame dimensions. - - Returns - ------- - :class:`numpy.ndarray` - An estimated camera matrix - """ - focal_length = 4 - camera_matrix = np.array([[focal_length, 0, 0.5], - [0, focal_length, 0.5], - [0, 0, 1]], dtype="double") - logger.trace("camera_matrix: %s", camera_matrix) # type: ignore - return camera_matrix - - def _solve_pnp(self, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ Solve the Perspective-n-Point for the given landmarks. - - Takes 2D landmarks in world space and estimates the rotation and translation vectors - in 3D space. - - Parameters - ---------- - landmarks: :class:`numpy.ndarry` - The original 68 point landmark co-ordinates relating to the original frame - - Returns - ------- - rotation: :class:`numpy.ndarray` - The solved rotation vector - translation: :class:`numpy.ndarray` - The solved translation vector - """ - points = landmarks[[6, 7, 8, 9, 10, 17, 21, 22, 26, 31, 32, 33, 34, - 35, 36, 39, 42, 45, 48, 50, 51, 52, 54, 56, 57, 58]] - _, rotation, translation = cv2.solvePnP(_MEAN_FACE_3D, - points, - self._camera_matrix, - self._distortion_coefficients, - flags=cv2.SOLVEPNP_ITERATIVE) - logger.trace("points: %s, rotation: %s, translation: %s", # type: ignore - points, rotation, translation) - return rotation, translation - - def _get_offset(self) -> Dict[CenteringType, np.ndarray]: - """ Obtain the offset between the original center of the extracted face to the new center - of the head in 2D space. - - Returns - ------- - :class:`numpy.ndarray` - The x, y offset of the new center from the old center. - """ - offset: Dict[CenteringType, np.ndarray] = dict(legacy=np.array([0.0, 0.0])) - points: Dict[Literal["face", "head"], Tuple[float, ...]] = dict(head=(0.0, 0.0, -2.3), - face=(0.0, -1.5, 4.2)) - - for key, pnts in points.items(): - center = cv2.projectPoints(np.array([pnts]).astype("float32"), - self._rotation, - self._translation, - self._camera_matrix, - self._distortion_coefficients)[0].squeeze() - logger.trace("center %s: %s", key, center) # type: ignore - offset[key] = center - (0.5, 0.5) - logger.trace("offset: %s", offset) # type: ignore - return offset - - def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray: """Estimate N-D similarity transformation with or without scaling. 
@@ -866,24 +858,24 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) if np.linalg.det(A) < 0: d[dim - 1] = -1 - T = np.eye(dim + 1, dtype=np.double) + retval = np.eye(dim + 1, dtype=np.double) U, S, V = np.linalg.svd(A) # Eq. (40) and (43). rank = np.linalg.matrix_rank(A) if rank == 0: - return np.nan * T + return np.nan * retval if rank == dim - 1: if np.linalg.det(U) * np.linalg.det(V) > 0: - T[:dim, :dim] = U @ V + retval[:dim, :dim] = U @ V else: s = d[dim - 1] d[dim - 1] = -1 - T[:dim, :dim] = U @ np.diag(d) @ V + retval[:dim, :dim] = U @ np.diag(d) @ V d[dim - 1] = s else: - T[:dim, :dim] = U @ np.diag(d) @ V + retval[:dim, :dim] = U @ np.diag(d) @ V if estimate_scale: # Eq. (41) and (42). @@ -891,7 +883,7 @@ def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) else: scale = 1.0 - T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) - T[:dim, :dim] *= scale + retval[:dim, dim] = dst_mean - scale * (retval[:dim, :dim] @ src_mean.T) + retval[:dim, :dim] *= scale - return T + return retval diff --git a/lib/align/alignments.py b/lib/align/alignments.py index cc1c52ea..33ff2f61 100644 --- a/lib/align/alignments.py +++ b/lib/align/alignments.py @@ -1,24 +1,19 @@ #!/usr/bin/env python3 """ Alignments file functions for reading, writing and manipulating the data stored in a serialized alignments file. """ - +from __future__ import annotations import logging import os -import sys +import typing as T from datetime import datetime -from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union import numpy as np from lib.serializer import get_serializer, get_serializer_from_filename from lib.utils import FaceswapError -if sys.version_info < (3, 8): - from typing_extensions import TypedDict -else: - from typing import TypedDict - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from .aligned_face import CenteringType logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -35,49 +30,49 @@ _VERSION = 2.3 # TODO Convert these to Dataclasses -class MaskAlignmentsFileDict(TypedDict): +class MaskAlignmentsFileDict(T.TypedDict): """ Typed Dictionary for storing Masks. """ mask: bytes - affine_matrix: Union[List[float], np.ndarray] + affine_matrix: list[float] | np.ndarray interpolator: int stored_size: int - stored_centering: "CenteringType" + stored_centering: CenteringType -class PNGHeaderAlignmentsDict(TypedDict): +class PNGHeaderAlignmentsDict(T.TypedDict): """ Base Dictionary for storing a single faces' Alignment Information in Alignments files and PNG Headers. """ x: int y: int w: int h: int - landmarks_xy: Union[List[float], np.ndarray] - mask: Dict[str, MaskAlignmentsFileDict] - identity: Dict[str, List[float]] + landmarks_xy: list[float] | np.ndarray + mask: dict[str, MaskAlignmentsFileDict] + identity: dict[str, list[float]] class AlignmentFileDict(PNGHeaderAlignmentsDict): """ Typed Dictionary for storing a single faces' Alignment Information in alignments files. 
""" - thumb: Optional[np.ndarray] + thumb: np.ndarray | None -class PNGHeaderSourceDict(TypedDict): +class PNGHeaderSourceDict(T.TypedDict): """ Dictionary for storing additional meta information in PNG headers """ alignments_version: float original_filename: str face_index: int source_filename: str source_is_video: bool - source_frame_dims: Optional[Tuple[int, int]] + source_frame_dims: tuple[int, int] | None -class AlignmentDict(TypedDict): +class AlignmentDict(T.TypedDict): """ Dictionary for holding all of the alignment information within a single alignment file """ - faces: List[AlignmentFileDict] - video_meta: Dict[str, Union[float, int]] + faces: list[AlignmentFileDict] + video_meta: dict[str, float | int] -class PNGHeaderDict(TypedDict): +class PNGHeaderDict(T.TypedDict): """ Dictionary for storing all alignment and meta information in PNG Headers """ alignments: PNGHeaderAlignmentsDict source: PNGHeaderSourceDict @@ -135,7 +130,7 @@ class Alignments(): return self._io.file @property - def data(self) -> Dict[str, AlignmentDict]: + def data(self) -> dict[str, AlignmentDict]: """ dict: The loaded alignments :attr:`file` in dictionary form. """ return self._data @@ -146,7 +141,7 @@ class Alignments(): return self._io.have_alignments_file @property - def hashes_to_frame(self) -> Dict[str, Dict[str, int]]: + def hashes_to_frame(self) -> dict[str, dict[str, int]]: """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame that the hash corresponds to. @@ -158,7 +153,7 @@ class Alignments(): return self._legacy.hashes_to_frame @property - def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]: + def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]: """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash corresponds to. The structure of the dictionary is: @@ -170,10 +165,10 @@ class Alignments(): return self._legacy.hashes_to_alignment @property - def mask_summary(self) -> Dict[str, int]: + def mask_summary(self) -> dict[str, int]: """ dict: The mask type names stored in the alignments :attr:`data` as key with the number of faces which possess the mask type as value. """ - masks: Dict[str, int] = {} + masks: dict[str, int] = {} for val in self._data.values(): for face in val["faces"]: if face.get("mask", None) is None: @@ -183,21 +178,20 @@ class Alignments(): return masks @property - def video_meta_data(self) -> Dict[str, Optional[Union[List[int], List[float]]]]: + def video_meta_data(self) -> dict[str, list[int] | list[float] | None]: """ dict: The frame meta data stored in the alignments file. If data does not exist in the alignments file then ``None`` is returned for each Key """ - retval: Dict[str, Optional[Union[List[int], - List[float]]]] = dict(pts_time=None, keyframes=None) - pts_time: List[float] = [] - keyframes: List[int] = [] + retval: dict[str, list[int] | list[float] | None] = {"pts_time": None, "keyframes": None} + pts_time: list[float] = [] + keyframes: list[int] = [] for idx, key in enumerate(sorted(self.data)): if not self.data[key].get("video_meta", {}): return retval meta = self.data[key]["video_meta"] - pts_time.append(cast(float, meta["pts_time"])) + pts_time.append(T.cast(float, meta["pts_time"])) if meta["keyframe"]: keyframes.append(idx) - retval = dict(pts_time=pts_time, keyframes=keyframes) + retval = {"pts_time": pts_time, "keyframes": keyframes} return retval @property @@ -211,7 +205,7 @@ class Alignments(): """ float: The alignments file version number. 
""" return self._io.version - def _load(self) -> Dict[str, AlignmentDict]: + def _load(self) -> dict[str, AlignmentDict]: """ Load the alignments data from the serialized alignments :attr:`file`. Populates :attr:`_version` with the alignment file's loaded version as well as returning @@ -238,7 +232,7 @@ class Alignments(): """ return self._io.backup() - def save_video_meta_data(self, pts_time: List[float], keyframes: List[int]) -> None: + def save_video_meta_data(self, pts_time: list[float], keyframes: list[int]) -> None: """ Save video meta data to the alignments file. If the alignments file does not have an entry for every frame (e.g. if Extract Every N @@ -262,10 +256,10 @@ class Alignments(): logger.info("Saving video meta information to Alignments file") for idx, pts in enumerate(pts_time): - meta: Dict[str, Union[float, int]] = dict(pts_time=pts, keyframe=idx in keyframes) + meta: dict[str, float | int] = {"pts_time": pts, "keyframe": idx in keyframes} key = f"{basename}_{idx + 1:06d}.png" if key not in self.data: - self.data[key] = dict(video_meta=meta, faces=[]) + self.data[key] = {"video_meta": meta, "faces": []} else: self.data[key]["video_meta"] = meta @@ -285,8 +279,8 @@ class Alignments(): self._io.save() @classmethod - def _pad_leading_frames(cls, pts_time: List[float], keyframes: List[int]) -> Tuple[List[float], - List[int]]: + def _pad_leading_frames(cls, pts_time: list[float], keyframes: list[int]) -> tuple[list[float], + list[int]]: """ Calculate the number of frames to pad the video by when the first frame is not a key frame. @@ -310,7 +304,7 @@ class Alignments(): """ start_pts = pts_time[0] logger.debug("Video not cut on keyframe. Start pts: %s", start_pts) - gaps: List[float] = [] + gaps: list[float] = [] prev_time = None for item in pts_time: if prev_time is not None: @@ -360,7 +354,7 @@ class Alignments(): ``True`` if the given frame_name exists within the alignments :attr:`data` and has at least 1 face associated with it, otherwise ``False`` """ - frame_data = self._data.get(frame_name, cast(AlignmentDict, {})) + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) retval = bool(frame_data.get("faces", [])) logger.trace("'%s': %s", frame_name, retval) # type:ignore return retval @@ -384,7 +378,7 @@ class Alignments(): if not frame_name: retval = False else: - frame_data = self._data.get(frame_name, cast(AlignmentDict, {})) + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) retval = bool(len(frame_data.get("faces", [])) > 1) logger.trace("'%s': %s", frame_name, retval) # type:ignore return retval @@ -414,7 +408,7 @@ class Alignments(): return retval # << DATA >> # - def get_faces_in_frame(self, frame_name: str) -> List[AlignmentFileDict]: + def get_faces_in_frame(self, frame_name: str) -> list[AlignmentFileDict]: """ Obtain the faces from :attr:`data` associated with a given frame_name. Parameters @@ -429,8 +423,8 @@ class Alignments(): The list of face dictionaries that appear within the requested frame_name """ logger.trace("Getting faces for frame_name: '%s'", frame_name) # type:ignore - frame_data = self._data.get(frame_name, cast(AlignmentDict, {})) - return frame_data.get("faces", cast(List[AlignmentFileDict], [])) + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) + return frame_data.get("faces", T.cast(list[AlignmentFileDict], [])) def _count_faces_in_frame(self, frame_name: str) -> int: """ Return number of faces that appear within :attr:`data` for the given frame_name. 
@@ -446,7 +440,7 @@ class Alignments(): int The number of faces that appear in the given frame_name """ - frame_data = self._data.get(frame_name, cast(AlignmentDict, {})) + frame_data = self._data.get(frame_name, T.cast(AlignmentDict, {})) retval = len(frame_data.get("faces", [])) logger.trace(retval) # type:ignore return retval @@ -497,7 +491,7 @@ class Alignments(): """ logger.debug("Adding face to frame_name: '%s'", frame_name) if frame_name not in self._data: - self._data[frame_name] = dict(faces=[], video_meta={}) + self._data[frame_name] = {"faces": [], "video_meta": {}} self._data[frame_name]["faces"].append(face) retval = self._count_faces_in_frame(frame_name) - 1 logger.debug("Returning new face index: %s", retval) @@ -520,7 +514,7 @@ class Alignments(): logger.debug("Updating face %s for frame_name '%s'", face_index, frame_name) self._data[frame_name]["faces"][face_index] = face - def filter_faces(self, filter_dict: Dict[str, List[int]], filter_out: bool = False) -> None: + def filter_faces(self, filter_dict: dict[str, list[int]], filter_out: bool = False) -> None: """ Remove faces from :attr:`data` based on a given filter list. Parameters @@ -549,7 +543,7 @@ class Alignments(): del frame_data["faces"][face_idx] # << GENERATORS >> # - def yield_faces(self) -> Generator[Tuple[str, List[AlignmentFileDict], int, str], None, None]: + def yield_faces(self) -> Generator[tuple[str, list[AlignmentFileDict], int, str], None, None]: """ Generator to obtain all faces with meta information from :attr:`data`. The results are yielded by frame. @@ -715,7 +709,7 @@ class _IO(): logger.info("Updating alignments file to version %s", self._version) self.save() - def load(self) -> Dict[str, AlignmentDict]: + def load(self) -> dict[str, AlignmentDict]: """ Load the alignments data from the serialized alignments :attr:`file`. Populates :attr:`_version` with the alignment file's loaded version as well as returning @@ -732,7 +726,7 @@ class _IO(): logger.info("Reading alignments from: '%s'", self._file) data = self._serializer.load(self._file) - meta = data.get("__meta__", dict(version=1.0)) + meta = data.get("__meta__", {"version": 1.0}) self._version = meta["version"] data = data.get("__data__", data) logger.debug("Loaded alignments") @@ -743,8 +737,8 @@ class _IO(): the location :attr:`file`. """ logger.debug("Saving alignments") logger.info("Writing alignments to: '%s'", self._file) - data = dict(__meta__=dict(version=self._version), - __data__=self._alignments.data) + data = {"__meta__": {"version": self._version}, + "__data__": self._alignments.data} self._serializer.save(self._file, data) logger.debug("Saved alignments") @@ -928,7 +922,7 @@ class _FileStructure(_Updater): for key, val in self._alignments.data.items(): if not isinstance(val, list): continue - self._alignments.data[key] = dict(faces=val) + self._alignments.data[key] = {"faces": val} updated += 1 return updated @@ -1078,11 +1072,11 @@ class _Legacy(): """ def __init__(self, alignments: Alignments) -> None: self._alignments = alignments - self._hashes_to_frame: Dict[str, Dict[str, int]] = {} - self._hashes_to_alignment: Dict[str, AlignmentFileDict] = {} + self._hashes_to_frame: dict[str, dict[str, int]] = {} + self._hashes_to_alignment: dict[str, AlignmentFileDict] = {} @property - def hashes_to_frame(self) -> Dict[str, Dict[str, int]]: + def hashes_to_frame(self) -> dict[str, dict[str, int]]: """ dict: The SHA1 hash of the face mapped to the frame(s) and face index within the frame that the hash corresponds to. 
The structure of the dictionary is: @@ -1105,7 +1099,7 @@ class _Legacy(): return self._hashes_to_frame @property - def hashes_to_alignment(self) -> Dict[str, AlignmentFileDict]: + def hashes_to_alignment(self) -> dict[str, AlignmentFileDict]: """ dict: The SHA1 hash of the face mapped to the alignment for the face that the hash corresponds to. The structure of the dictionary is: diff --git a/lib/align/detected_face.py b/lib/align/detected_face.py index 74ca9a95..bf48dfb0 100644 --- a/lib/align/detected_face.py +++ b/lib/align/detected_face.py @@ -1,12 +1,11 @@ #!/usr/bin python3 """ Face and landmarks detection for faceswap.py """ - +from __future__ import annotations import logging -import sys import os +import typing as T from hashlib import sha1 -from typing import cast, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from zlib import compress, decompress import cv2 @@ -18,14 +17,10 @@ from .alignments import (Alignments, AlignmentFileDict, MaskAlignmentsFileDict, PNGHeaderAlignmentsDict, PNGHeaderDict, PNGHeaderSourceDict) from . import AlignedFace, get_adjusted_center, get_centered_size -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Callable from .aligned_face import CenteringType -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -85,14 +80,14 @@ class DetectedFace(): dict of {**name** (`str`): :class:`Mask`}. """ def __init__(self, - image: Optional[np.ndarray] = None, - left: Optional[int] = None, - width: Optional[int] = None, - top: Optional[int] = None, - height: Optional[int] = None, - landmarks_xy: Optional[np.ndarray] = None, - mask: Optional[Dict[str, "Mask"]] = None, - filename: Optional[str] = None) -> None: + image: np.ndarray | None = None, + left: int | None = None, + width: int | None = None, + top: int | None = None, + height: int | None = None, + landmarks_xy: np.ndarray | None = None, + mask: dict[str, "Mask"] | None = None, + filename: str | None = None) -> None: logger.trace("Initializing %s: (image: %s, left: %s, width: %s, top: %s, " # type: ignore "height: %s, landmarks_xy: %s, mask: %s, filename: %s)", self.__class__.__name__, @@ -104,12 +99,12 @@ class DetectedFace(): self.top = top self.height = height self._landmarks_xy = landmarks_xy - self._identity: Dict[str, np.ndarray] = {} - self.thumbnail: Optional[np.ndarray] = None + self._identity: dict[str, np.ndarray] = {} + self.thumbnail: np.ndarray | None = None self.mask = {} if mask is None else mask - self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None + self._training_masks: tuple[bytes, tuple[int, int, int]] | None = None - self._aligned: Optional[AlignedFace] = None + self._aligned: AlignedFace | None = None logger.trace("Initialized %s", self.__class__.__name__) # type: ignore @property @@ -137,7 +132,7 @@ class DetectedFace(): return self.top + self.height @property - def identity(self) -> Dict[str, np.ndarray]: + def identity(self) -> dict[str, np.ndarray]: """ dict: Identity mechanism as key, identity embedding as value. 
""" return self._identity @@ -147,7 +142,7 @@ class DetectedFace(): affine_matrix: np.ndarray, interpolator: int, storage_size: int = 128, - storage_centering: "CenteringType" = "face") -> None: + storage_centering: CenteringType = "face") -> None: """ Add a :class:`Mask` to this detected face The mask should be the original output from :mod:`plugins.extract.mask` @@ -209,7 +204,7 @@ class DetectedFace(): self._identity[name] = embedding def get_landmark_mask(self, - area: Literal["eye", "face", "mouth"], + area: T.Literal["eye", "face", "mouth"], blur_kernel: int, dilation: int) -> np.ndarray: """ Add a :class:`LandmarksMask` to this detected face @@ -235,7 +230,7 @@ class DetectedFace(): """ # TODO Face mask generation from landmarks logger.trace("area: %s, dilation: %s", area, dilation) # type: ignore - areas = dict(mouth=[slice(48, 60)], eye=[slice(36, 42), slice(42, 48)]) + areas = {"mouth": [slice(48, 60)], "eye": [slice(36, 42), slice(42, 48)]} points = [self.aligned.landmarks[zone] for zone in areas[area]] @@ -250,7 +245,7 @@ class DetectedFace(): return lmmask.mask def store_training_masks(self, - masks: List[Optional[np.ndarray]], + masks: list[np.ndarray | None], delete_masks: bool = False) -> None: """ Concatenate and compress the given training masks and store for retrieval. @@ -273,7 +268,7 @@ class DetectedFace(): combined = np.concatenate(valid, axis=-1) self._training_masks = (compress(combined), combined.shape) - def get_training_masks(self) -> Optional[np.ndarray]: + def get_training_masks(self) -> np.ndarray | None: """ Obtain the decompressed combined training masks. Returns @@ -312,7 +307,7 @@ class DetectedFace(): return alignment def from_alignment(self, alignment: AlignmentFileDict, - image: Optional[np.ndarray] = None, with_thumb: bool = False) -> None: + image: np.ndarray | None = None, with_thumb: bool = False) -> None: """ Set the attributes of this class from an alignments file and optionally load the face into the ``image`` attribute. 
@@ -342,7 +337,7 @@ class DetectedFace(): landmarks = alignment["landmarks_xy"] if not isinstance(landmarks, np.ndarray): landmarks = np.array(landmarks, dtype="float32") - self._identity = {cast(Literal["vggface2"], k): np.array(v, dtype="float32") + self._identity = {T.cast(T.Literal["vggface2"], k): np.array(v, dtype="float32") for k, v in alignment.get("identity", {}).items()} self._landmarks_xy = landmarks.copy() @@ -403,7 +398,7 @@ class DetectedFace(): self._identity = {} for key, val in alignment.get("identity", {}).items(): assert key in ["vggface2"] - self._identity[cast(Literal["vggface2"], key)] = np.array(val, dtype="float32") + self._identity[T.cast(T.Literal["vggface2"], key)] = np.array(val, dtype="float32") logger.trace("Created from png exif header: (left: %s, width: %s, top: %s " # type: ignore " height: %s, landmarks: %s, mask: %s, identity: %s)", self.left, self.width, self.top, self.height, self.landmarks_xy, self.mask, @@ -417,10 +412,10 @@ class DetectedFace(): # <<< Aligned Face methods and properties >>> # def load_aligned(self, - image: Optional[np.ndarray], + image: np.ndarray | None, size: int = 256, - dtype: Optional[str] = None, - centering: "CenteringType" = "head", + dtype: str | None = None, + centering: CenteringType = "head", coverage_ratio: float = 1.0, force: bool = False, is_aligned: bool = False, @@ -507,22 +502,22 @@ class Mask(): """ def __init__(self, storage_size: int = 128, - storage_centering: "CenteringType" = "face") -> None: + storage_centering: CenteringType = "face") -> None: logger.trace("Initializing: %s (storage_size: %s, storage_centering: %s)", # type: ignore self.__class__.__name__, storage_size, storage_centering) self.stored_size = storage_size self.stored_centering = storage_centering - self._mask: Optional[bytes] = None - self._affine_matrix: Optional[np.ndarray] = None - self._interpolator: Optional[int] = None + self._mask: bytes | None = None + self._affine_matrix: np.ndarray | None = None + self._interpolator: int | None = None - self._blur_type: Optional[Literal["gaussian", "normalized"]] = None + self._blur_type: T.Literal["gaussian", "normalized"] | None = None self._blur_passes: int = 0 - self._blur_kernel: Union[float, int] = 0 + self._blur_kernel: float | int = 0 self._threshold = 0.0 self._sub_crop_size = 0 - self._sub_crop_slices: Dict[Literal["in", "out"], List[slice]] = {} + self._sub_crop_slices: dict[T.Literal["in", "out"], list[slice]] = {} self.set_blur_and_threshold() logger.trace("Initialized: %s", self.__class__.__name__) # type: ignore @@ -648,7 +643,7 @@ class Mask(): def set_blur_and_threshold(self, blur_kernel: int = 0, - blur_type: Optional[Literal["gaussian", "normalized"]] = "gaussian", + blur_type: T.Literal["gaussian", "normalized"] | None = "gaussian", blur_passes: int = 1, threshold: int = 0) -> None: """ Set the internal blur kernel and threshold amount for returned masks @@ -679,7 +674,7 @@ class Mask(): def set_sub_crop(self, source_offset: np.ndarray, target_offset: np.ndarray, - centering: "CenteringType", + centering: CenteringType, coverage_ratio: float = 1.0) -> None: """ Set the internal crop area of the mask to be returned. @@ -831,9 +826,9 @@ class LandmarksMask(Mask): The amount of dilation to apply to the mask. `0` for none. 
Default: `0` """ def __init__(self, - points: List[np.ndarray], + points: list[np.ndarray], storage_size: int = 128, - storage_centering: "CenteringType" = "face", + storage_centering: CenteringType = "face", dilation: int = 0) -> None: super().__init__(storage_size=storage_size, storage_centering=storage_centering) self._points = points @@ -907,9 +902,9 @@ class BlurMask(): # pylint:disable=too-few-public-methods (128, 128, 1) """ def __init__(self, - blur_type: Literal["gaussian", "normalized"], + blur_type: T.Literal["gaussian", "normalized"], mask: np.ndarray, - kernel: Union[int, float], + kernel: int | float, is_ratio: bool = False, passes: int = 1) -> None: logger.trace("Initializing %s: (blur_type: '%s', mask_shape: %s, " # type: ignore @@ -943,33 +938,30 @@ class BlurMask(): # pylint:disable=too-few-public-methods def _multipass_factor(self) -> float: """ For multiple passes the kernel must be scaled down. This value is different for box filter and gaussian """ - factor = dict(gaussian=0.8, normalized=0.5) + factor = {"gaussian": 0.8, "normalized": 0.5} return factor[self._blur_type] @property - def _sigma(self) -> Literal[0]: + def _sigma(self) -> T.Literal[0]: """ int: The Sigma for Gaussian Blur. Returns 0 to force calculation from kernel size. """ return 0 @property - def _func_mapping(self) -> Dict[Literal["gaussian", "normalized"], Callable]: + def _func_mapping(self) -> dict[T.Literal["gaussian", "normalized"], Callable]: """ dict: :attr:`_blur_type` mapped to cv2 Function name. """ - return dict(gaussian=cv2.GaussianBlur, # pylint: disable = no-member - normalized=cv2.blur) # pylint: disable = no-member + return {"gaussian": cv2.GaussianBlur, "normalized": cv2.blur} @property - def _kwarg_requirements(self) -> Dict[Literal["gaussian", "normalized"], List[str]]: + def _kwarg_requirements(self) -> dict[T.Literal["gaussian", "normalized"], list[str]]: """ dict: :attr:`_blur_type` mapped to cv2 Function required keyword arguments. """ - return dict(gaussian=["ksize", "sigmaX"], - normalized=["ksize"]) + return {"gaussian": ['ksize', 'sigmaX'], "normalized": ['ksize']} @property - def _kwarg_mapping(self) -> Dict[str, Union[int, Tuple[int, int]]]: + def _kwarg_mapping(self) -> dict[str, int | tuple[int, int]]: """ dict: cv2 function keyword arguments mapped to their parameters. """ - return dict(ksize=self._kernel_size, - sigmaX=self._sigma) + return {"ksize": self._kernel_size, "sigmaX": self._sigma} - def _get_kernel_size(self, kernel: Union[int, float], is_ratio: bool) -> int: + def _get_kernel_size(self, kernel: int | float, is_ratio: bool) -> int: """ Set the kernel size to absolute value. If :attr:`is_ratio` is ``True`` then the kernel size is calculated from the given ratio and @@ -999,7 +991,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods return kernel_size @staticmethod - def _get_kernel_tuple(kernel_size: int) -> Tuple[int, int]: + def _get_kernel_tuple(kernel_size: int) -> tuple[int, int]: """ Make sure kernel_size is odd and return it as a tuple. 
Parameters @@ -1017,7 +1009,7 @@ class BlurMask(): # pylint:disable=too-few-public-methods logger.trace(retval) # type: ignore return retval - def _get_kwargs(self) -> Dict[str, Union[int, Tuple[int, int]]]: + def _get_kwargs(self) -> dict[str, int | tuple[int, int]]: """ dict: the valid keyword arguments for the requested :attr:`_blur_type` """ retval = {kword: self._kwarg_mapping[kword] for kword in self._kwarg_requirements[self._blur_type]} @@ -1025,11 +1017,11 @@ class BlurMask(): # pylint:disable=too-few-public-methods return retval -_HASHES_SEEN: Dict[str, Dict[str, int]] = {} +_HASHES_SEEN: dict[str, dict[str, int]] = {} def update_legacy_png_header(filename: str, alignments: Alignments - ) -> Optional[PNGHeaderDict]: + ) -> PNGHeaderDict | None: """ Update a legacy extracted face from pre v2.1 alignments by placing the alignment data for the face in the png exif header for the given filename with the given alignment data. diff --git a/lib/cli/actions.py b/lib/cli/actions.py index 7c03caed..4b89b35c 100644 --- a/lib/cli/actions.py +++ b/lib/cli/actions.py @@ -7,7 +7,7 @@ as well as adding a mechanism for indicating to the GUI how specific options sho import argparse import os -from typing import Any, List, Optional, Tuple, Union +import typing as T # << FILE HANDLING >> @@ -69,7 +69,7 @@ class FileFullPaths(_FullPaths): >>> filetypes="video))" """ # pylint: disable=too-few-public-methods - def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None: + def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) self.filetypes = filetypes @@ -111,7 +111,7 @@ class FilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods >>> filetypes="image", >>> nargs="+")) """ - def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None: + def __init__(self, *args, filetypes: str | None = None, **kwargs) -> None: if kwargs.get("nargs", None) is None: opt = kwargs["option_strings"] raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}") @@ -250,8 +250,8 @@ class ContextFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods, too-many-arguments def __init__(self, *args, - filetypes: Optional[str] = None, - action_option: Optional[str] = None, + filetypes: str | None = None, + action_option: str | None = None, **kwargs) -> None: opt = kwargs["option_strings"] if kwargs.get("nargs", None) is not None: @@ -263,7 +263,7 @@ class ContextFullPaths(FileFullPaths): super().__init__(*args, filetypes=filetypes, **kwargs) self.action_option = action_option - def _get_kwargs(self) -> List[Tuple[str, Any]]: + def _get_kwargs(self) -> list[tuple[str, T.Any]]: names = ["option_strings", "dest", "nargs", @@ -382,8 +382,8 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods """ def __init__(self, *args, - min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None, - rounding: Optional[int] = None, + min_max: tuple[int, int] | tuple[float, float] | None = None, + rounding: int | None = None, **kwargs) -> None: opt = kwargs["option_strings"] if kwargs.get("nargs", None) is not None: @@ -401,7 +401,7 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods self.min_max = min_max self.rounding = rounding - def _get_kwargs(self) -> List[Tuple[str, Any]]: + def _get_kwargs(self) -> list[tuple[str, T.Any]]: names = ["option_strings", "dest", "nargs", diff --git a/lib/cli/args.py b/lib/cli/args.py index 2b9143a8..0d2de626 100644 --- a/lib/cli/args.py +++ 
b/lib/cli/args.py @@ -8,8 +8,7 @@ import logging import re import sys import textwrap - -from typing import Any, Dict, List, NoReturn, Optional +import typing as T from lib.utils import get_backend from lib.gpu_stats import GPUStats @@ -30,7 +29,7 @@ _ = _LANG.gettext class FullHelpArgumentParser(argparse.ArgumentParser): """ Extends :class:`argparse.ArgumentParser` to output full help on bad arguments. """ - def error(self, message: str) -> NoReturn: + def error(self, message: str) -> T.NoReturn: self.print_help(sys.stderr) self.exit(2, f"{self.prog}: error: {message}\n") @@ -51,11 +50,11 @@ class SmartFormatter(argparse.HelpFormatter): prog: str, indent_increment: int = 2, max_help_position: int = 24, - width: Optional[int] = None) -> None: + width: int | None = None) -> None: super().__init__(prog, indent_increment, max_help_position, width) self._whitespace_matcher_limited = re.compile(r'[ \r\f\v]+', re.ASCII) - def _split_lines(self, text: str, width: int) -> List[str]: + def _split_lines(self, text: str, width: int) -> list[str]: """ Split the given text by the given display width. If the text is not prefixed with "R|" then the standard @@ -138,7 +137,7 @@ class FaceSwapArgs(): return "" @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Returns the argument list for the current command. The argument list should be a list of dictionaries pertaining to each option for a command. @@ -152,11 +151,11 @@ class FaceSwapArgs(): list The list of command line options for the given command """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] return argument_list @staticmethod - def get_optional_arguments() -> List[Dict[str, Any]]: + def get_optional_arguments() -> list[dict[str, T.Any]]: """ Returns the optional argument list for the current command. The optional arguments list is not always required, but is used when there are shared @@ -167,11 +166,11 @@ class FaceSwapArgs(): list The list of optional command line options for the given command """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] return argument_list @staticmethod - def _get_global_arguments() -> List[Dict[str, Any]]: + def _get_global_arguments() -> list[dict[str, T.Any]]: """ Returns the global Arguments list that are required for ALL commands in Faceswap. This method should NOT be overridden. @@ -181,7 +180,7 @@ class FaceSwapArgs(): list The list of global command line options for all Faceswap commands. """ - global_args: List[Dict[str, Any]] = [] + global_args: list[dict[str, T.Any]] = [] if _GPUS: global_args.append(dict( opts=("-X", "--exclude-gpus"), @@ -302,7 +301,7 @@ class ExtractConvertArgs(FaceSwapArgs): """ @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Returns the argument list for shared Extract and Convert arguments. 
Returns @@ -310,7 +309,7 @@ class ExtractConvertArgs(FaceSwapArgs): list The list of command line options for the given Extract and Convert """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] argument_list.append(dict( opts=("-i", "--input-dir"), action=DirOrFileFullPaths, @@ -362,7 +361,7 @@ class ExtractArgs(ExtractConvertArgs): "Extraction plugins can be configured in the 'Settings' Menu") @staticmethod - def get_optional_arguments() -> List[Dict[str, Any]]: + def get_optional_arguments() -> list[dict[str, T.Any]]: """ Returns the argument list unique to the Extract command. Returns @@ -377,7 +376,7 @@ class ExtractArgs(ExtractConvertArgs): default_detector = "s3fd" default_aligner = "fan" - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] argument_list.append(dict( opts=("-b", "--batch-mode"), action="store_true", @@ -658,7 +657,7 @@ class ConvertArgs(ExtractConvertArgs): "Conversion plugins can be configured in the 'Settings' Menu") @staticmethod - def get_optional_arguments() -> List[Dict[str, Any]]: + def get_optional_arguments() -> list[dict[str, T.Any]]: """ Returns the argument list unique to the Convert command. Returns @@ -667,7 +666,7 @@ class ConvertArgs(ExtractConvertArgs): The list of optional command line options for the Convert command """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] argument_list.append(dict( opts=("-ref", "--reference-video"), action=FileFullPaths, @@ -915,7 +914,7 @@ class TrainArgs(FaceSwapArgs): "Model plugins can be configured in the 'Settings' Menu") @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Returns the argument list for Train arguments. Returns @@ -923,7 +922,7 @@ class TrainArgs(FaceSwapArgs): list The list of command line options for training """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] argument_list.append(dict( opts=("-A", "--input-A"), action=DirFullPaths, @@ -1180,7 +1179,7 @@ class GuiArgs(FaceSwapArgs): """ Creates the command line arguments for the GUI. """ @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Returns the argument list for GUI arguments. 
Returns @@ -1188,7 +1187,7 @@ class GuiArgs(FaceSwapArgs): list The list of command line options for the GUI """ - argument_list: List[Dict[str, Any]] = [] + argument_list: list[dict[str, T.Any]] = [] argument_list.append(dict( opts=("-d", "--debug"), action="store_true", diff --git a/lib/cli/launcher.py b/lib/cli/launcher.py index 92193464..9dfa3c13 100644 --- a/lib/cli/launcher.py +++ b/lib/cli/launcher.py @@ -1,20 +1,22 @@ #!/usr/bin/env python3 """ Launches the correct script with the given Command Line Arguments """ +from __future__ import annotations import logging import os import platform import sys +import typing as T from importlib import import_module -from typing import Callable, TYPE_CHECKING from lib.gpu_stats import set_exclude_devices, GPUStats from lib.logger import crash_log, log_setup from lib.utils import (FaceswapError, get_backend, get_tf_version, safe_shutdown, set_backend, set_system_verbosity) -if TYPE_CHECKING: +if T.TYPE_CHECKING: import argparse + from collections.abc import Callable logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -99,8 +101,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods FaceswapError If Tensorflow is not found, or is not between versions 2.4 and 2.9 """ - directml_ver = rocm_ver = (2, 10) - min_ver = (2, 7) + min_ver = (2, 10) max_ver = (2, 10) try: import tensorflow as tf # noqa pylint:disable=import-outside-toplevel,unused-import @@ -120,7 +121,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods self._handle_import_error(msg) tf_ver = get_tf_version() - backend = get_backend() if tf_ver < min_ver: msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version " f"{tf_ver} installed. Please upgrade Tensorflow.") @@ -129,14 +129,6 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version " f"{tf_ver} installed. Please downgrade Tensorflow.") self._handle_import_error(msg) - if backend == "directml" and tf_ver != directml_ver: - msg = (f"The supported Tensorflow version for DirectML cards is {directml_ver} but " - f"you have version {tf_ver} installed. Please install the correct version.") - self._handle_import_error(msg) - if backend == "rocm" and tf_ver != rocm_ver: - msg = (f"The supported Tensorflow version for ROCm cards is {rocm_ver} but " - f"you have version {tf_ver} installed. Please install the correct version.") - self._handle_import_error(msg) logger.debug("Installed Tensorflow Version: %s", tf_ver) @classmethod @@ -209,7 +201,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods "See https://support.apple.com/en-gb/HT201341") raise FaceswapError("No display detected. GUI mode has been disabled.") - def execute_script(self, arguments: "argparse.Namespace") -> None: + def execute_script(self, arguments: argparse.Namespace) -> None: """ Performs final set up and launches the requested :attr:`_command` with the given command line arguments. @@ -250,7 +242,7 @@ class ScriptExecutor(): # pylint:disable=too-few-public-methods finally: safe_shutdown(got_error=not success) - def _configure_backend(self, arguments: "argparse.Namespace") -> None: + def _configure_backend(self, arguments: argparse.Namespace) -> None: """ Configure the backend. Exclude any GPUs for use by Faceswap when requested. 
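
The hunks above repeat one typing pattern across the whole patch: `from __future__ import annotations`, a single `import typing as T`, builtin generics (`list`, `dict`, `tuple`) in place of `typing.List`/`Dict`/`Tuple`, PEP 604 unions (`X | None`) in place of `Optional`/`Union`, and annotation-only imports such as `collections.abc.Callable` deferred behind `T.TYPE_CHECKING`. As a hedged illustration only (the function and names below are hypothetical, not faceswap code), the target style looks like this:

    # Illustrative sketch of the post-rebase typing style; not faceswap code.
    from __future__ import annotations

    import typing as T

    if T.TYPE_CHECKING:
        from collections.abc import Callable  # needed for annotations only

    def run_jobs(jobs: list[str],
                 callback: Callable[[str], None] | None = None) -> dict[str, bool]:
        """ Hypothetical helper: run an optional callback per job and record it. """
        results: dict[str, bool] = {}
        for job in jobs:
            if callback is not None:
                callback(job)
            results[job] = callback is not None
        return results

Because annotations are postponed, the `Callable` import is only needed by static type checkers, which is why it can sit under `T.TYPE_CHECKING` without affecting runtime imports.
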
diff --git a/lib/config.py b/lib/config.py index c588026d..3619166f 100644 --- a/lib/config.py +++ b/lib/config.py @@ -13,7 +13,6 @@ from collections import OrderedDict from configparser import ConfigParser from dataclasses import dataclass from importlib import import_module -from typing import Dict, List, Optional, Tuple, Union from lib.utils import full_path_split @@ -21,16 +20,11 @@ from lib.utils import full_path_split _LANG = gettext.translation("lib.config", localedir="locales", fallback=True) _ = _LANG.gettext -# Can't type OrderedDict fully on Python 3.8 or lower -if sys.version_info < (3, 9): - OrderedDictSectionType = OrderedDict - OrderedDictItemType = OrderedDict -else: - OrderedDictSectionType = OrderedDict[str, "ConfigSection"] - OrderedDictItemType = OrderedDict[str, "ConfigItem"] +OrderedDictSectionType = OrderedDict[str, "ConfigSection"] +OrderedDictItemType = OrderedDict[str, "ConfigItem"] logger = logging.getLogger(__name__) # pylint: disable=invalid-name -ConfigValueType = Union[bool, int, float, List[str], str, None] +ConfigValueType = bool | int | float | list[str] | str | None @dataclass @@ -60,11 +54,11 @@ class ConfigItem: helptext: str datatype: type rounding: int - min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] - choices: Union[str, List[str]] + min_max: tuple[int, int] | tuple[float, float] | None + choices: str | list[str] gui_radio: bool fixed: bool - group: Optional[str] + group: str | None @dataclass @@ -84,7 +78,7 @@ class ConfigSection: class FaceswapConfig(): """ Config Items """ - def __init__(self, section: Optional[str], configfile: Optional[str] = None) -> None: + def __init__(self, section: str | None, configfile: str | None = None) -> None: """ Init Configuration Parameters @@ -106,11 +100,11 @@ class FaceswapConfig(): logger.debug("Initialized: %s", self.__class__.__name__) @property - def changeable_items(self) -> Dict[str, ConfigValueType]: + def changeable_items(self) -> dict[str, ConfigValueType]: """ Training only. Return a dict of config items with their set values for items that can be altered after the model has been created """ - retval: Dict[str, ConfigValueType] = {} + retval: dict[str, ConfigValueType] = {} sections = [sect for sect in self.config.sections() if sect.startswith("global")] all_sections = sections if self.section is None else sections + [self.section] for sect in all_sections: @@ -189,10 +183,10 @@ class FaceswapConfig(): logger.debug("Added defaults: %s", section) @property - def config_dict(self) -> Dict[str, ConfigValueType]: + def config_dict(self) -> dict[str, ConfigValueType]: """ dict: Collate global options and requested section into a dictionary with the correct data types """ - conf: Dict[str, ConfigValueType] = {} + conf: dict[str, ConfigValueType] = {} sections = [sect for sect in self.config.sections() if sect.startswith("global")] if self.section is not None: sections.append(self.section) @@ -240,7 +234,7 @@ class FaceswapConfig(): logger.debug("Returning item: (type: %s, value: %s)", datatype, retval) return retval - def _parse_list(self, section: str, option: str) -> List[str]: + def _parse_list(self, section: str, option: str) -> list[str]: """ Parse options that are stored as lists in the config file. These can be space or comma-separated items in the config file. 
They will be returned as a list of strings, regardless of what the final data type should be, so conversion from strings to other @@ -268,7 +262,7 @@ class FaceswapConfig(): raw_option, retval, section, option) return retval - def _get_config_file(self, configfile: Optional[str]) -> str: + def _get_config_file(self, configfile: str | None) -> str: """ Return the config file from the calling folder or the provided file Parameters @@ -309,17 +303,17 @@ class FaceswapConfig(): self.defaults[title] = ConfigSection(helptext=info, items=OrderedDict()) def add_item(self, - section: Optional[str] = None, - title: Optional[str] = None, + section: str | None = None, + title: str | None = None, datatype: type = str, default: ConfigValueType = None, - info: Optional[str] = None, - rounding: Optional[int] = None, - min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None, - choices: Optional[Union[str, List[str]]] = None, + info: str | None = None, + rounding: int | None = None, + min_max: tuple[int, int] | tuple[float, float] | None = None, + choices: str | list[str] | None = None, gui_radio: bool = False, fixed: bool = True, - group: Optional[str] = None) -> None: + group: str | None = None) -> None: """ Add a default item to a config section For int or float values, rounding and min_max must be set @@ -382,10 +376,10 @@ class FaceswapConfig(): @classmethod def _expand_helptext(cls, helptext: str, - choices: Union[str, List[str]], + choices: str | list[str], default: ConfigValueType, datatype: type, - min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]], + min_max: tuple[int, int] | tuple[float, float] | None, fixed: bool) -> str: """ Add extra helptext info from parameters """ helptext += "\n" @@ -437,7 +431,7 @@ class FaceswapConfig(): def insert_config_section(self, section: str, helptext: str, - config: Optional[ConfigParser] = None) -> None: + config: ConfigParser | None = None) -> None: """ Insert a section into the config Parameters @@ -464,7 +458,7 @@ class FaceswapConfig(): item: str, default: ConfigValueType, option: ConfigItem, - config: Optional[ConfigParser] = None) -> None: + config: ConfigParser | None = None) -> None: """ Insert an item into a config section Parameters diff --git a/lib/convert.py b/lib/convert.py index 885e4df1..fddecf20 100644 --- a/lib/convert.py +++ b/lib/convert.py @@ -1,23 +1,18 @@ #!/usr/bin/env python3 """ Converter for Faceswap """ - +from __future__ import annotations import logging -import sys +import typing as T from dataclasses import dataclass -from typing import Callable, cast, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 import numpy as np from plugins.plugin_loader import PluginLoader -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from argparse import Namespace + from collections.abc import Callable from lib.align.aligned_face import AlignedFace, CenteringType from lib.align.detected_face import DetectedFace from lib.config import FaceswapConfig @@ -46,10 +41,10 @@ class Adjustments: sharpening: :class:`~plugins.scaling._base.Adjustment`, Optional The selected mask processing plugin. 
Default: `None` """ - color: Optional["ColorAdjust"] = None - mask: Optional["MaskAdjust"] = None - seamless: Optional["SeamlessAdjust"] = None - sharpening: Optional["ScalingAdjust"] = None + color: ColorAdjust | None = None + mask: MaskAdjust | None = None + seamless: SeamlessAdjust | None = None + sharpening: ScalingAdjust | None = None class Converter(): @@ -81,11 +76,11 @@ class Converter(): def __init__(self, output_size: int, coverage_ratio: float, - centering: "CenteringType", + centering: CenteringType, draw_transparent: bool, - pre_encode: Optional[Callable[[np.ndarray], List[bytes]]], - arguments: "Namespace", - configfile: Optional[str] = None) -> None: + pre_encode: Callable[[np.ndarray], list[bytes]] | None, + arguments: Namespace, + configfile: str | None = None) -> None: logger.debug("Initializing %s: (output_size: %s, coverage_ratio: %s, centering: %s, " "draw_transparent: %s, pre_encode: %s, arguments: %s, configfile: %s)", self.__class__.__name__, output_size, coverage_ratio, centering, @@ -105,12 +100,12 @@ class Converter(): logger.debug("Initialized %s", self.__class__.__name__) @property - def cli_arguments(self) -> "Namespace": + def cli_arguments(self) -> Namespace: """:class:`argparse.Namespace`: The command line arguments passed to the convert process """ return self._args - def reinitialize(self, config: "FaceswapConfig") -> None: + def reinitialize(self, config: FaceswapConfig) -> None: """ Reinitialize this :class:`Converter`. Called as part of the :mod:`~tools.preview` tool. Resets all adjustments then loads the @@ -127,7 +122,7 @@ class Converter(): logger.debug("Reinitialized converter") def _load_plugins(self, - config: Optional["FaceswapConfig"] = None, + config: FaceswapConfig | None = None, disable_logging: bool = False) -> None: """ Load the requested adjustment plugins. @@ -169,7 +164,7 @@ class Converter(): self._adjustments.sharpening = sharpening logger.debug("Loaded plugins: %s", self._adjustments) - def process(self, in_queue: "EventQueue", out_queue: "EventQueue"): + def process(self, in_queue: EventQueue, out_queue: EventQueue): """ Main convert process. Takes items from the in queue, runs the relevant adjustments, patches faces to final frame @@ -188,7 +183,7 @@ class Converter(): in_queue, out_queue) log_once = False while True: - inbound: Union[Literal["EOF"], "ConvertItem", List["ConvertItem"]] = in_queue.get() + inbound: T.Literal["EOF"] | ConvertItem | list[ConvertItem] = in_queue.get() if inbound == "EOF": logger.debug("EOF Received") logger.debug("Patch queue finished") @@ -218,7 +213,7 @@ class Converter(): out_queue.put((item.inbound.filename, image)) logger.debug("Completed convert process") - def _patch_image(self, predicted: "ConvertItem") -> Union[np.ndarray, List[bytes]]: + def _patch_image(self, predicted: ConvertItem) -> np.ndarray | list[bytes]: """ Patch a swapped face onto a frame. Run selected adjustments and swap the faces in a frame. 
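
The `process` loop above relies on a string sentinel: the inbound queue yields either a `ConvertItem`, a list of them, or the literal `"EOF"` to signal that the patch queue is finished. A minimal sketch of that sentinel pattern, using hypothetical item types rather than faceswap's own:

    # Sketch of the "EOF" sentinel pattern used by Converter.process above.
    # The queue and item types here are illustrative only.
    from __future__ import annotations

    import queue
    import typing as T

    def drain(in_queue: queue.Queue[T.Literal["EOF"] | int | list[int]]) -> list[int]:
        """ Pull work items until the EOF sentinel arrives, flattening lists. """
        items: list[int] = []
        while True:
            inbound = in_queue.get()
            if inbound == "EOF":
                break
            items.extend(inbound if isinstance(inbound, list) else [inbound])
        return items
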
@@ -246,15 +241,15 @@ class Converter(): out=np.empty(patched_face.shape, dtype="uint8"), casting='unsafe') if self._writer_pre_encode is None: - retval: Union[np.ndarray, List[bytes]] = patched_face + retval: np.ndarray | list[bytes] = patched_face else: retval = self._writer_pre_encode(patched_face) logger.trace("Patched image: '%s'", predicted.inbound.filename) # type: ignore return retval def _get_new_image(self, - predicted: "ConvertItem", - frame_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: + predicted: ConvertItem, + frame_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]: """ Get the new face from the predictor and apply pre-warp manipulations. Applies any requested adjustments to the raw output of the Faceswap model @@ -308,9 +303,9 @@ class Converter(): def _pre_warp_adjustments(self, new_face: np.ndarray, - detected_face: "DetectedFace", - reference_face: "AlignedFace", - predicted_mask: Optional[np.ndarray]) -> np.ndarray: + detected_face: DetectedFace, + reference_face: AlignedFace, + predicted_mask: np.ndarray | None) -> np.ndarray: """ Run any requested adjustments that can be performed on the raw output from the Faceswap model. @@ -337,7 +332,7 @@ class Converter(): """ logger.trace("new_face shape: %s, predicted_mask shape: %s", # type: ignore new_face.shape, predicted_mask.shape if predicted_mask is not None else None) - old_face = cast(np.ndarray, reference_face.face)[..., :3] / 255.0 + old_face = T.cast(np.ndarray, reference_face.face)[..., :3] / 255.0 new_face, raw_mask = self._get_image_mask(new_face, detected_face, predicted_mask, @@ -351,9 +346,9 @@ class Converter(): def _get_image_mask(self, new_face: np.ndarray, - detected_face: "DetectedFace", - predicted_mask: Optional[np.ndarray], - reference_face: "AlignedFace") -> Tuple[np.ndarray, np.ndarray]: + detected_face: DetectedFace, + predicted_mask: np.ndarray | None, + reference_face: AlignedFace) -> tuple[np.ndarray, np.ndarray]: """ Return any selected image mask Places the requested mask into the new face's Alpha channel. diff --git a/lib/gpu_stats/_base.py b/lib/gpu_stats/_base.py index b45bd011..8953f134 100644 --- a/lib/gpu_stats/_base.py +++ b/lib/gpu_stats/_base.py @@ -5,11 +5,10 @@ from the :class:`_GPUStats` class contained here. """ import logging from dataclasses import dataclass -from typing import List, Optional from lib.utils import get_backend -_EXCLUDE_DEVICES: List[int] = [] +_EXCLUDE_DEVICES: list[int] = [] @dataclass @@ -29,11 +28,11 @@ class GPUInfo(): devices_active: list[int] List of integers representing the indices of the active GPU devices. """ - vram: List[int] - vram_free: List[int] + vram: list[int] + vram_free: list[int] driver: str - devices: List[str] - devices_active: List[int] + devices: list[str] + devices_active: list[int] @dataclass @@ -57,7 +56,7 @@ class BiggestGPUInfo(): total: float -def set_exclude_devices(devices: List[int]) -> None: +def set_exclude_devices(devices: list[int]) -> None: """ Add any explicitly selected GPU devices to the global list of devices to be excluded from use by Faceswap. 
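
The hunk above only shows the signature and docstring of `set_exclude_devices`; its body is outside this diff. As an assumption about what such a helper typically does (appending to the module-level `_EXCLUDE_DEVICES` list declared earlier in the file), a sketch and call site:

    # Assumption: minimal form of the exclusion helper described above. The
    # real body in lib/gpu_stats/_base.py is outside this hunk and may differ.
    _EXCLUDE_DEVICES: list[int] = []

    def set_exclude_devices(devices: list[int]) -> None:
        """ Record GPU indices that Faceswap must not use. """
        _EXCLUDE_DEVICES.extend(devices)

    # e.g. honouring "-X" / "--exclude-gpus 1 2" from the command line
    set_exclude_devices([1, 2])
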
@@ -89,19 +88,19 @@ class _GPUStats(): def __init__(self, log: bool = True) -> None: # Logger is held internally, as we don't want to log when obtaining system stats on crash # or when querying the backend for command line options - self._logger: Optional[logging.Logger] = logging.getLogger(__name__) if log else None + self._logger: logging.Logger | None = logging.getLogger(__name__) if log else None self._log("debug", f"Initializing {self.__class__.__name__}") self._is_initialized = False self._initialize() self._device_count: int = self._get_device_count() - self._active_devices: List[int] = self._get_active_devices() + self._active_devices: list[int] = self._get_active_devices() self._handles: list = self._get_handles() self._driver: str = self._get_driver() - self._device_names: List[str] = self._get_device_names() - self._vram: List[int] = self._get_vram() - self._vram_free: List[int] = self._get_free_vram() + self._device_names: list[str] = self._get_device_names() + self._vram: list[int] = self._get_vram() + self._vram_free: list[int] = self._get_free_vram() if get_backend() != "cpu" and not self._active_devices: self._log("warning", "No GPU detected") @@ -115,7 +114,7 @@ class _GPUStats(): return self._device_count @property - def cli_devices(self) -> List[str]: + def cli_devices(self) -> list[str]: """ list[str]: Formatted index: name text string for each GPU """ return [f"{idx}: {device}" for idx, device in enumerate(self._device_names)] @@ -167,7 +166,7 @@ class _GPUStats(): """ raise NotImplementedError() - def _get_active_devices(self) -> List[int]: + def _get_active_devices(self) -> list[int]: """ Obtain the indices of active GPUs (those that have not been explicitly excluded in the command line arguments). @@ -204,7 +203,7 @@ class _GPUStats(): """ raise NotImplementedError() - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Override to obtain the names of all connected GPUs. The quality of this information depends on the backend and OS being used, but it should be sufficient for identifying cards. @@ -217,7 +216,7 @@ class _GPUStats(): """ raise NotImplementedError() - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Override to obtain the total VRAM in Megabytes for each connected GPU. Returns @@ -228,7 +227,7 @@ class _GPUStats(): """ raise NotImplementedError() - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Override to obtain the amount of VRAM that is available, in Megabytes, for each connected GPU. diff --git a/lib/gpu_stats/apple_silicon.py b/lib/gpu_stats/apple_silicon.py index 11fcfab1..a8b08150 100644 --- a/lib/gpu_stats/apple_silicon.py +++ b/lib/gpu_stats/apple_silicon.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ Collects and returns Information on available Apple Silicon SoCs in Apple Macs. """ -from typing import Any, List +import typing as T import os import psutil @@ -35,7 +35,7 @@ class AppleSiliconStats(_GPUStats): """ def __init__(self, log: bool = True) -> None: # Following attribute set in :func:``_initialize`` - self._tf_devices: List[Any] = [] + self._tf_devices: list[T.Any] = [] super().__init__(log=log) @@ -142,7 +142,7 @@ class AppleSiliconStats(_GPUStats): self._log("debug", f"GPU Driver: {driver}") return driver - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of available Apple Silicon SoC(s) as identified in :attr:`_handles`. 
@@ -155,7 +155,7 @@ class AppleSiliconStats(_GPUStats): self._log("debug", f"GPU Devices: {names}") return names - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the VRAM in Megabytes for each available Apple Silicon SoC(s) as identified in :attr:`_handles`. @@ -175,7 +175,7 @@ class AppleSiliconStats(_GPUStats): self._log("debug", f"SoC RAM: {vram}") return vram - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of VRAM that is available, in Megabytes, for each available Apple Silicon SoC. diff --git a/lib/gpu_stats/cpu.py b/lib/gpu_stats/cpu.py index 23090828..ae20c96c 100644 --- a/lib/gpu_stats/cpu.py +++ b/lib/gpu_stats/cpu.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 """ Dummy functions for running faceswap on CPU. """ - - -from typing import List - from ._base import _GPUStats @@ -65,7 +61,7 @@ class CPUStats(_GPUStats): self._log("debug", f"GPU Driver: {driver}") return driver - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`. Returns @@ -73,11 +69,11 @@ class CPUStats(_GPUStats): list An empty list for CPU backends """ - names: List[str] = [] + names: list[str] = [] self._log("debug", f"GPU Devices: {names}") return names - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the RAM in Megabytes for the running system. Returns @@ -85,11 +81,11 @@ class CPUStats(_GPUStats): list An empty list for CPU backends """ - vram: List[int] = [] + vram: list[int] = [] self._log("debug", f"GPU VRAM: {vram}") return vram - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of RAM that is available, in Megabytes, for the running system. Returns @@ -97,6 +93,6 @@ class CPUStats(_GPUStats): list An empty list for CPU backends """ - vram: List[int] = [] + vram: list[int] = [] self._log("debug", f"GPU VRAM free: {vram}") return vram diff --git a/lib/gpu_stats/directml.py b/lib/gpu_stats/directml.py index d29fac04..17364bbb 100644 --- a/lib/gpu_stats/directml.py +++ b/lib/gpu_stats/directml.py @@ -1,19 +1,23 @@ #!/usr/bin/env python3 """ Collects and returns Information on DirectX 12 hardware devices for DirectML. """ +from __future__ import annotations import os import sys +import typing as T assert sys.platform == "win32" import ctypes from ctypes import POINTER, Structure, windll from dataclasses import dataclass from enum import Enum, IntEnum -from typing import Any, Callable, cast, List from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error from ._base import _GPUStats +if T.TYPE_CHECKING: + from collections.abc import Callable + # Monkey patch default ctypes.c_uint32 value to Enum ctypes property for easier tracking of types # We can't just subclass as the attribute will be assumed to be part of the Enumeration, so we # attach it directly and suck up the typing errors. 
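
The comment closing the hunk above describes attaching a default ctypes type to the enumerations after class creation, because a class-body assignment would be swallowed as an extra member. A sketch of that general idea with a hypothetical enumeration (the real enums sit further down in lib/gpu_stats/directml.py and are not reproduced here):

    # Sketch only: hypothetical enum demonstrating the monkey-patch approach
    # described in the comment above; not copied from directml.py.
    import ctypes
    from enum import IntEnum

    class GpuPreference(IntEnum):
        UNSPECIFIED = 0
        MINIMUM_POWER = 1
        HIGH_PERFORMANCE = 2

    # Enum treats class-body assignments as prospective members, so the ctypes
    # type is attached after creation; the ignore absorbs the checker complaint.
    GpuPreference.ctype = ctypes.c_uint32  # type: ignore
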
@@ -314,7 +318,7 @@ class Adapters(): # pylint:disable=too-few-public-methods self._adapters = self._get_adapters() self._devices = self._process_adapters() - self._valid_adaptors: List[Device] = [] + self._valid_adaptors: list[Device] = [] self._log("debug", f"Initialized {self.__class__.__name__}") def _get_factory(self) -> ctypes._Pointer: @@ -334,12 +338,12 @@ class Adapters(): # pylint:disable=too-few-public-methods factory_func.restype = HRESULT handle = ctypes.c_void_p(0) factory_func(IDXGIFactory6._iid_, ctypes.byref(handle)) # pylint:disable=protected-access - retval = ctypes.POINTER(IDXGIFactory6)(cast(IDXGIFactory6, handle.value)) + retval = ctypes.POINTER(IDXGIFactory6)(T.cast(IDXGIFactory6, handle.value)) self._log("debug", f"factory: {retval}") return retval @property - def valid_adapters(self) -> List[Device]: + def valid_adapters(self) -> list[Device]: """ list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """ if self._valid_adaptors: return self._valid_adaptors @@ -354,7 +358,7 @@ class Adapters(): # pylint:disable=too-few-public-methods self._log("debug", f"valid_adaptors: {self._valid_adaptors}") return self._valid_adaptors - def _get_adapters(self) -> List[ctypes._Pointer]: + def _get_adapters(self) -> list[ctypes._Pointer]: """ Obtain DirectX 12 supporting hardware adapter objects and add a Device class for obtaining details @@ -376,7 +380,7 @@ class Adapters(): # pylint:disable=too-few-public-methods if success != 0: raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: " f"{hex(ctypes.c_ulong(success).value)}") - adapter = POINTER(IDXGIAdapter3)(cast(IDXGIAdapter3, handle.value)) + adapter = POINTER(IDXGIAdapter3)(T.cast(IDXGIAdapter3, handle.value)) self._log("debug", f"found adapter: {adapter}") retval.append(adapter) except COMError as err: @@ -392,7 +396,7 @@ class Adapters(): # pylint:disable=too-few-public-methods self._log("debug", f"adapters: {retval}") return retval - def _query_adapter(self, func: Callable[[Any], Any], *args: Any) -> None: + def _query_adapter(self, func: Callable[[T.Any], T.Any], *args: T.Any) -> None: """ Query an adapter function, logging if the HRESULT is not a success Parameters @@ -430,7 +434,7 @@ class Adapters(): # pylint:disable=too-few-public-methods LookupGUID.ID3D12Device) return success in (0, 1) - def _process_adapters(self) -> List[Device]: + def _process_adapters(self) -> list[Device]: """ Process the adapters to add discovered information. Returns @@ -485,21 +489,21 @@ class DirectML(_GPUStats): Default: ``True`` """ def __init__(self, log: bool = True) -> None: - self._devices: List[Device] = [] + self._devices: list[Device] = [] super().__init__(log=log) @property - def _all_vram(self) -> List[int]: + def _all_vram(self) -> list[int]: """ list: The VRAM of each GPU device that the DX API has discovered. """ return [int(device.description.DedicatedVideoMemory / (1024 * 1024)) for device in self._devices] @property - def names(self) -> List[str]: + def names(self) -> list[str]: """ list: The name of each GPU device that the DX API has discovered. """ return [device.description.Description for device in self._devices] - def _get_active_devices(self) -> List[int]: + def _get_active_devices(self) -> list[int]: """ Obtain the indices of active GPUs (those that have not been explicitly excluded by DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line arguments). 
@@ -517,7 +521,7 @@ class DirectML(_GPUStats): self._log("debug", f"Active GPU Devices: {devices}") return devices - def _get_devices(self) -> List[Device]: + def _get_devices(self) -> list[Device]: """ Obtain all detected DX API devices. Returns @@ -582,7 +586,7 @@ class DirectML(_GPUStats): self._log("debug", f"GPU Drivers: {drivers}") return drivers - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`. Returns @@ -594,7 +598,7 @@ class DirectML(_GPUStats): self._log("debug", f"GPU Devices: {names}") return names - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in :attr:`_handles`. @@ -607,7 +611,7 @@ class DirectML(_GPUStats): self._log("debug", f"GPU VRAM: {vram}") return vram - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX 12 supporting GPU. diff --git a/lib/gpu_stats/nvidia.py b/lib/gpu_stats/nvidia.py index 8f8e8cef..67038e9c 100644 --- a/lib/gpu_stats/nvidia.py +++ b/lib/gpu_stats/nvidia.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 """ Collects and returns Information on available Nvidia GPUs. """ import os -from typing import List import pynvml @@ -83,7 +82,7 @@ class NvidiaStats(_GPUStats): self._log("debug", f"GPU Device count: {retval}") return retval - def _get_active_devices(self) -> List[int]: + def _get_active_devices(self) -> list[int]: """ Obtain the indices of active GPUs (those that have not been explicitly excluded by CUDA_VISIBLE_DEVICES environment variable or explicitly excluded in the command line arguments). @@ -130,7 +129,7 @@ class NvidiaStats(_GPUStats): self._log("debug", f"GPU Driver: {driver}") return driver - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`. Returns @@ -143,7 +142,7 @@ class NvidiaStats(_GPUStats): self._log("debug", f"GPU Devices: {names}") return names - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in :attr:`_handles`. @@ -157,7 +156,7 @@ class NvidiaStats(_GPUStats): self._log("debug", f"GPU VRAM: {vram}") return vram - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia GPU. diff --git a/lib/gpu_stats/nvidia_apple.py b/lib/gpu_stats/nvidia_apple.py index ae6cb74c..acbcd93f 100644 --- a/lib/gpu_stats/nvidia_apple.py +++ b/lib/gpu_stats/nvidia_apple.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 """ Collects and returns Information on available Nvidia GPUs connected to Apple Macs. """ -from typing import List - import pynvx from lib.utils import FaceswapError @@ -92,7 +90,7 @@ class NvidiaAppleStats(_GPUStats): self._log("debug", f"GPU Driver: {driver}") return driver - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of connected Nvidia GPUs as identified in :attr:`_handles`. 
Returns @@ -105,7 +103,7 @@ class NvidiaAppleStats(_GPUStats): self._log("debug", f"GPU Devices: {names}") return names - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the VRAM in Megabytes for each connected Nvidia GPU as identified in :attr:`_handles`. @@ -120,7 +118,7 @@ class NvidiaAppleStats(_GPUStats): self._log("debug", f"GPU VRAM: {vram}") return vram - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of VRAM that is available, in Megabytes, for each connected Nvidia GPU. diff --git a/lib/gpu_stats/rocm.py b/lib/gpu_stats/rocm.py index c41e96b2..dca43b38 100644 --- a/lib/gpu_stats/rocm.py +++ b/lib/gpu_stats/rocm.py @@ -10,7 +10,6 @@ It is a good starting point but may need to be refined over time import os import re from subprocess import run -from typing import List from ._base import _GPUStats @@ -221,7 +220,7 @@ class ROCm(_GPUStats): """ def __init__(self, log: bool = True) -> None: self._vendor_id = "0x1002" # AMD VendorID - self._sysfs_paths: List[str] = [] + self._sysfs_paths: list[str] = [] super().__init__(log=log) def _from_sysfs_file(self, path: str) -> str: @@ -249,7 +248,7 @@ class ROCm(_GPUStats): val = "" return val - def _get_sysfs_paths(self) -> List[str]: + def _get_sysfs_paths(self) -> list[str]: """ Obtain a list of sysfs paths to AMD branded GPUs connected to the system Returns @@ -259,7 +258,7 @@ class ROCm(_GPUStats): """ base_dir = "/sys/class/drm/" - retval: List[str] = [] + retval: list[str] = [] if not os.path.exists(base_dir): self._log("warning", f"sysfs not found at '{base_dir}'") return retval @@ -347,7 +346,7 @@ class ROCm(_GPUStats): self._log("debug", f"GPU Drivers: {retval}") return retval - def _get_device_names(self) -> List[str]: + def _get_device_names(self) -> list[str]: """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`. Returns @@ -383,7 +382,7 @@ class ROCm(_GPUStats): self._log("debug", f"Device names: {retval}") return retval - def _get_active_devices(self) -> List[int]: + def _get_active_devices(self) -> list[int]: """ Obtain the indices of active GPUs (those that have not been explicitly excluded by HIP_VISIBLE_DEVICES environment variable or explicitly excluded in the command line arguments). @@ -401,7 +400,7 @@ class ROCm(_GPUStats): self._log("debug", f"Active GPU Devices: {devices}") return devices - def _get_vram(self) -> List[int]: + def _get_vram(self) -> list[int]: """ Obtain the VRAM in Megabytes for each connected AMD GPU as identified in :attr:`_handles`. @@ -423,7 +422,7 @@ class ROCm(_GPUStats): self._log("debug", f"GPU VRAM: {retval}") return retval - def _get_free_vram(self) -> List[int]: + def _get_free_vram(self) -> list[int]: """ Obtain the amount of VRAM that is available, in Megabytes, for each connected AMD GPU. diff --git a/lib/gui/analysis/event_reader.py b/lib/gui/analysis/event_reader.py index c2376b3b..2a6fcc27 100644 --- a/lib/gui/analysis/event_reader.py +++ b/lib/gui/analysis/event_reader.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 """ Handles the loading and collation of events from Tensorflow event log files. 
""" - +from __future__ import annotations import logging import os -import sys +import typing as T import zlib from dataclasses import dataclass, field -from typing import Any, cast, Dict, Iterator, Generator, List, Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -17,11 +16,8 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module from lib.serializer import get_serializer -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - +if T.TYPE_CHECKING: + from collections.abc import Generator, Iterator logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -38,7 +34,7 @@ class EventData: The loss values collected for A and B sides for the event step """ timestamp: float = 0.0 - loss: List[float] = field(default_factory=list) + loss: list[float] = field(default_factory=list) class _LogFiles(): @@ -56,11 +52,11 @@ class _LogFiles(): logger.debug("Initialized: %s", self.__class__.__name__) @property - def session_ids(self) -> List[int]: + def session_ids(self) -> list[int]: """ list[int]: Sorted list of `ints` of available session ids. """ return list(sorted(self._filenames)) - def _get_log_filenames(self) -> Dict[int, str]: + def _get_log_filenames(self) -> dict[int, str]: """ Get the Tensorflow event filenames for all existing sessions. Returns @@ -69,7 +65,7 @@ class _LogFiles(): The full path of each log file for each training session id that has been run """ logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder) - retval: Dict[int, str] = {} + retval: dict[int, str] = {} for dirpath, _, filenames in os.walk(self._logs_folder): if not any(filename.startswith("events.out.tfevents") for filename in filenames): continue @@ -82,7 +78,7 @@ class _LogFiles(): return retval @classmethod - def _get_session_id(cls, folder: str) -> Optional[int]: + def _get_session_id(cls, folder: str) -> int | None: """ Obtain the session id for the given folder. Parameters @@ -103,7 +99,7 @@ class _LogFiles(): return retval @classmethod - def _get_log_filename(cls, folder: str, filenames: List[str]) -> str: + def _get_log_filename(cls, folder: str, filenames: list[str]) -> str: """ Obtain the session log file for the given folder. If multiple log files exist for the given folder, then the most recent log file is used, as earlier files are assumed to be obsolete. @@ -161,10 +157,10 @@ class _CacheData(): loss: :class:`np.ndarray` The loss values collected for A and B sides for the session """ - def __init__(self, labels: List[str], timestamps: np.ndarray, loss: np.ndarray) -> None: + def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None: self.labels = labels - self._loss = zlib.compress(cast(bytes, loss)) - self._timestamps = zlib.compress(cast(bytes, timestamps)) + self._loss = zlib.compress(T.cast(bytes, loss)) + self._timestamps = zlib.compress(T.cast(bytes, timestamps)) self._timestamps_shape = timestamps.shape self._loss_shape = loss.shape @@ -192,8 +188,8 @@ class _CacheData(): timestamps: :class:`numpy.ndarray` The latest timestamps to add to the cache """ - new_buffer: List[bytes] = [] - new_shapes: List[Tuple[int, ...]] = [] + new_buffer: list[bytes] = [] + new_shapes: list[tuple[int, ...]] = [] for data, buffer, dtype, shape in zip([timestamps, loss], [self._timestamps, self._loss], ["float64", "float32"], @@ -220,9 +216,9 @@ class _Cache(): """ Holds parsed Tensorflow log event data in a compressed cache in memory. 
""" def __init__(self) -> None: logger.debug("Initializing: %s", self.__class__.__name__) - self._data: Dict[int, _CacheData] = {} - self._carry_over: Dict[int, EventData] = {} - self._loss_labels: List[str] = [] + self._data: dict[int, _CacheData] = {} + self._carry_over: dict[int, EventData] = {} + self._loss_labels: list[str] = [] logger.debug("Initialized: %s", self.__class__.__name__) def is_cached(self, session_id: int) -> bool: @@ -242,8 +238,8 @@ class _Cache(): def cache_data(self, session_id: int, - data: Dict[int, EventData], - labels: List[str], + data: dict[int, EventData], + labels: list[str], is_live: bool = False) -> None: """ Add a full session's worth of event data to :attr:`_data`. @@ -278,8 +274,8 @@ class _Cache(): self._add_latest_live(session_id, loss, timestamps) def _to_numpy(self, - data: Dict[int, EventData], - is_live: bool) -> Tuple[np.ndarray, np.ndarray]: + data: dict[int, EventData], + is_live: bool) -> tuple[np.ndarray, np.ndarray]: """ Extract each individual step data into separate numpy arrays for loss and timestamps. Timestamps are stored float64 as the extra accuracy is needed for correct timings. Arrays @@ -333,7 +329,7 @@ class _Cache(): return n_times, n_loss - def _collect_carry_over(self, data: Dict[int, EventData]) -> None: + def _collect_carry_over(self, data: dict[int, EventData]) -> None: """ For live data, collect carried over data from the previous update and merge into the current data dictionary. @@ -357,8 +353,8 @@ class _Cache(): logger.debug("Merged carry over data: %s", update) def _process_data(self, - data: Dict[int, EventData], - is_live: bool) -> Tuple[List[float], List[List[float]]]: + data: dict[int, EventData], + is_live: bool) -> tuple[list[float], list[list[float]]]: """ Process live update data. Live data requires different processing as often we will only have partial data for the @@ -383,8 +379,8 @@ class _Cache(): timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss) for idx in sorted(data)]) - l_loss: List[List[float]] = list(loss) - l_timestamps: List[float] = list(timestamps) + l_loss: list[list[float]] = list(loss) + l_timestamps: list[float] = list(timestamps) if len(l_loss[-1]) != len(self._loss_labels): logger.debug("Truncated loss found. loss count: %s", len(l_loss)) @@ -418,8 +414,8 @@ class _Cache(): self._data[session_id].add_live_data(timestamps, loss) - def get_data(self, session_id: int, metric: Literal["loss", "timestamps"] - ) -> Optional[Dict[int, Dict[str, Union[np.ndarray, List[str]]]]]: + def get_data(self, session_id: int, metric: T.Literal["loss", "timestamps"] + ) -> dict[int, dict[str, np.ndarray | list[str]]] | None: """ Retrieve the decompressed cached data from the cache for the given session id. Parameters @@ -445,10 +441,10 @@ class _Cache(): return None raw = {session_id: data} - retval: Dict[int, Dict[str, Union[np.ndarray, List[str]]]] = {} + retval: dict[int, dict[str, np.ndarray | list[str]]] = {} for idx, data in raw.items(): array = data.loss if metric == "loss" else data.timestamps - val: Dict[str, Union[np.ndarray, List[str]]] = {str(metric): array} + val: dict[str, np.ndarray | list[str]] = {str(metric): array} if metric == "loss": val["labels"] = data.labels retval[idx] = val @@ -488,7 +484,7 @@ class TensorBoardLogs(): logger.debug("Initialized: %s", self.__class__.__name__) @property - def session_ids(self) -> List[int]: + def session_ids(self) -> list[int]: """ list[int]: Sorted list of integers of available session ids. 
""" return self._log_files.session_ids @@ -539,7 +535,7 @@ class TensorBoardLogs(): parser = _EventParser(iterator, self._cache, live_data) parser.cache_events(session_id) - def _check_cache(self, session_id: Optional[int] = None) -> None: + def _check_cache(self, session_id: int | None = None) -> None: """ Check if the given session_id has been cached and if not, cache it. Parameters @@ -557,7 +553,7 @@ class TensorBoardLogs(): if not self._cache.is_cached(idx): self._cache_data(idx) - def get_loss(self, session_id: Optional[int] = None) -> Dict[int, Dict[str, np.ndarray]]: + def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]: """ Read the loss from the TensorBoard event logs Parameters @@ -573,7 +569,7 @@ class TensorBoardLogs(): and list of loss values for each step """ logger.debug("Getting loss: (session_id: %s)", session_id) - retval: Dict[int, Dict[str, np.ndarray]] = {} + retval: dict[int, dict[str, np.ndarray]] = {} for idx in [session_id] if session_id else self.session_ids: self._check_cache(idx) full_data = self._cache.get_data(idx, "loss") @@ -588,7 +584,7 @@ class TensorBoardLogs(): for key, val in retval.items()}) return retval - def get_timestamps(self, session_id: Optional[int] = None) -> Dict[int, np.ndarray]: + def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]: """ Read the timestamps from the TensorBoard logs. As loss timestamps are slightly different for each loss, we collect the timestamp from the @@ -608,7 +604,7 @@ class TensorBoardLogs(): logger.debug("Getting timestamps: (session_id: %s, is_training: %s)", session_id, self._is_training) - retval: Dict[int, np.ndarray] = {} + retval: dict[int, np.ndarray] = {} for idx in [session_id] if session_id else self.session_ids: self._check_cache(idx) data = self._cache.get_data(idx, "timestamps") @@ -640,7 +636,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods self._live_data = live_data self._cache = cache self._iterator = self._get_latest_live(iterator) if live_data else iterator - self._loss_labels: List[str] = [] + self._loss_labels: list[str] = [] logger.debug("Initialized: %s", self.__class__.__name__) @classmethod @@ -683,7 +679,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods The session id that the data is being cached for """ assert self._iterator is not None - data: Dict[int, EventData] = {} + data: dict[int, EventData] = {} try: for record in self._iterator: event = event_pb2.Event.FromString(record) # pylint:disable=no-member @@ -743,7 +739,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods logger.debug("Collated loss labels: %s", self._loss_labels) @classmethod - def _get_outputs(cls, model_config: Dict[str, Any]) -> np.ndarray: + def _get_outputs(cls, model_config: dict[str, T.Any]) -> np.ndarray: """ Obtain the output names, instance index and output index for the given model. If there is only a single output, the shape of the array is expanded to remain consistent diff --git a/lib/gui/analysis/stats.py b/lib/gui/analysis/stats.py index f3d7273e..a1874d87 100644 --- a/lib/gui/analysis/stats.py +++ b/lib/gui/analysis/stats.py @@ -5,15 +5,15 @@ Holds the globally loaded training session. This will either be a user selected the analysis tab) or the currently training session. 
""" - +from __future__ import annotations import logging -import time import os +import time +import typing as T import warnings from math import ceil from threading import Event -from typing import Any, cast, Dict, List, Optional, overload, Tuple, Union import numpy as np @@ -31,12 +31,12 @@ class GlobalSession(): """ def __init__(self) -> None: logger.debug("Initializing %s", self.__class__.__name__) - self._state: Dict[str, Any] = {} + self._state: dict[str, T.Any] = {} self._model_dir = "" self._model_name = "" - self._tb_logs: Optional[TensorBoardLogs] = None - self._summary: Optional[SessionsSummary] = None + self._tb_logs: TensorBoardLogs | None = None + self._summary: SessionsSummary | None = None self._is_training = False self._is_querying = Event() @@ -60,7 +60,7 @@ class GlobalSession(): return os.path.join(self._model_dir, self._model_name) @property - def batch_sizes(self) -> Dict[int, int]: + def batch_sizes(self) -> dict[int, int]: """ dict: The batch sizes for each session_id for the model. """ if not self._state: return {} @@ -68,7 +68,7 @@ class GlobalSession(): for sess_id, sess in self._state.get("sessions", {}).items()} @property - def full_summary(self) -> List[dict]: + def full_summary(self) -> list[dict]: """ list: List of dictionaries containing summary statistics for each session id. """ assert self._summary is not None return self._summary.get_summary_stats() @@ -83,7 +83,7 @@ class GlobalSession(): return self._state["sessions"][max_id]["no_logs"] @property - def session_ids(self) -> List[int]: + def session_ids(self) -> list[int]: """ list: The sorted list of all existing session ids in the state file """ if self._tb_logs is None: return [] @@ -164,7 +164,7 @@ class GlobalSession(): self._is_training = False - def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]: + def get_loss(self, session_id: int | None) -> dict[str, np.ndarray]: """ Obtain the loss values for the given session_id. Parameters @@ -186,11 +186,11 @@ class GlobalSession(): assert self._tb_logs is not None loss_dict = self._tb_logs.get_loss(session_id=session_id) if session_id is None: - all_loss: Dict[str, List[float]] = {} + all_loss: dict[str, list[float]] = {} for key in sorted(loss_dict): for loss_key, loss in loss_dict[key].items(): all_loss.setdefault(loss_key, []).extend(loss) - retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32") + retval: dict[str, np.ndarray] = {key: np.array(val, dtype="float32") for key, val in all_loss.items()} else: retval = loss_dict.get(session_id, {}) @@ -199,11 +199,11 @@ class GlobalSession(): self._is_querying.clear() return retval - @overload - def get_timestamps(self, session_id: None) -> Dict[int, np.ndarray]: + @T.overload + def get_timestamps(self, session_id: None) -> dict[int, np.ndarray]: ... - @overload + @T.overload def get_timestamps(self, session_id: int) -> np.ndarray: ... @@ -247,7 +247,7 @@ class GlobalSession(): continue break - def get_loss_keys(self, session_id: Optional[int]) -> List[str]: + def get_loss_keys(self, session_id: int | None) -> list[str]: """ Obtain the loss keys for the given session_id. 
Parameters @@ -268,7 +268,7 @@ class GlobalSession(): in self._tb_logs.get_loss(session_id=session_id).items()} if session_id is None: - retval: List[str] = list(set(loss_key + retval: list[str] = list(set(loss_key for session in loss_keys.values() for loss_key in session)) else: @@ -293,11 +293,11 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods self._session = session self._state = session._state - self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {} - self._per_session_stats: List[Dict[str, Any]] = [] + self._time_stats: dict[int, dict[str, float | int]] = {} + self._per_session_stats: list[dict[str, T.Any]] = [] logger.debug("Initialized %s", self.__class__.__name__) - def get_summary_stats(self) -> List[dict]: + def get_summary_stats(self) -> list[dict]: """ Compile the individual session statistics and calculate the total. Format the stats for display @@ -336,14 +336,14 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods sess_id: {"start_time": np.min(timestamps) if np.any(timestamps) else 0, "end_time": np.max(timestamps) if np.any(timestamps) else 0, "iterations": timestamps.shape[0] if np.any(timestamps) else 0} - for sess_id, timestamps in cast(Dict[int, np.ndarray], - self._session.get_timestamps(None)).items()} + for sess_id, timestamps in T.cast(dict[int, np.ndarray], + self._session.get_timestamps(None)).items()} elif _SESSION.is_training: logger.debug("Updating summary time stamps for training session") session_id = _SESSION.session_ids[-1] - latest = cast(np.ndarray, self._session.get_timestamps(session_id)) + latest = T.cast(np.ndarray, self._session.get_timestamps(session_id)) self._time_stats[session_id] = { "start_time": np.min(latest) if np.any(latest) else 0, @@ -392,7 +392,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods / stats["elapsed"] if stats["elapsed"] > 0 else 0) logger.debug("per_session_stats: %s", self._per_session_stats) - def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]: + def _collate_stats(self, session_id: int) -> dict[str, int | float]: """ Collate the session summary statistics for the given session ID. Parameters @@ -422,7 +422,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods logger.debug(retval) return retval - def _total_stats(self) -> Dict[str, Union[str, int, float]]: + def _total_stats(self) -> dict[str, str | int | float]: """ Compile the Totals stats. Totals are fully calculated each time as they will change on the basis of the training session. @@ -459,7 +459,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods logger.debug(totals) return totals - def _format_stats(self, compiled_stats: List[dict]) -> List[dict]: + def _format_stats(self, compiled_stats: list[dict]) -> list[dict]: """ Format for the incoming list of statistics for display. Parameters @@ -489,7 +489,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods return retval @classmethod - def _convert_time(cls, timestamp: float) -> Tuple[str, str, str]: + def _convert_time(cls, timestamp: float) -> tuple[str, str, str]: """ Convert time stamp to total hours, minutes and seconds. 
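`_convert_time` above returns an elapsed time as hour, minute and second strings. A rough sketch of one way such a conversion can work (illustrative only, not necessarily the exact implementation used here):

def convert_time(timestamp: float) -> tuple[str, str, str]:
    """ Split an elapsed time in seconds into zero-padded hours, minutes and seconds. """
    hrs = int(timestamp // 3600)
    mins = int((timestamp % 3600) // 60)
    secs = int(timestamp % 60)
    return f"{hrs:02d}", f"{mins:02d}", f"{secs:02d}"


assert convert_time(3725.0) == ("01", "02", "05")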
Parameters @@ -534,8 +534,8 @@ class Calculations(): """ def __init__(self, session_id, display: str = "loss", - loss_keys: Union[List[str], str] = "loss", - selections: Union[List[str], str] = "raw", + loss_keys: list[str] | str = "loss", + selections: list[str] | str = "raw", avg_samples: int = 500, smooth_amount: float = 0.90, flatten_outliers: bool = False) -> None: @@ -552,13 +552,13 @@ class Calculations(): self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys] self._selections = selections if isinstance(selections, list) else [selections] self._is_totals = session_id is None - self._args: Dict[str, Union[int, float]] = {"avg_samples": avg_samples, - "smooth_amount": smooth_amount, - "flatten_outliers": flatten_outliers} + self._args: dict[str, int | float] = {"avg_samples": avg_samples, + "smooth_amount": smooth_amount, + "flatten_outliers": flatten_outliers} self._iterations = 0 self._limit = 0 self._start_iteration = 0 - self._stats: Dict[str, np.ndarray] = {} + self._stats: dict[str, np.ndarray] = {} self.refresh() logger.debug("Initialized %s", self.__class__.__name__) @@ -573,11 +573,11 @@ class Calculations(): return self._start_iteration @property - def stats(self) -> Dict[str, np.ndarray]: + def stats(self) -> dict[str, np.ndarray]: """ dict: The final calculated statistics """ return self._stats - def refresh(self) -> Optional["Calculations"]: + def refresh(self) -> Calculations | None: """ Refresh the stats """ logger.debug("Refreshing") if not _SESSION.is_loaded: @@ -736,7 +736,8 @@ class Calculations(): """ logger.debug("Calculating rate") batch_size = _SESSION.batch_sizes[self._session_id] * 2 - retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id))) + retval = batch_size / np.diff(T.cast(np.ndarray, + _SESSION.get_timestamps(self._session_id))) logger.debug("Calculated rate: Item_count: %s", len(retval)) return retval @@ -757,7 +758,7 @@ class Calculations(): logger.debug("Calculating totals rate") batchsizes = _SESSION.batch_sizes total_timestamps = _SESSION.get_timestamps(None) - rate: List[float] = [] + rate: list[float] = [] for sess_id in sorted(total_timestamps.keys()): batchsize = batchsizes[sess_id] timestamps = total_timestamps[sess_id] @@ -797,7 +798,7 @@ class Calculations(): The moving average for the given data """ logger.debug("Calculating Average. Data points: %s", len(data)) - window = cast(int, self._args["avg_samples"]) + window = T.cast(int, self._args["avg_samples"]) pad = ceil(window / 2) datapoints = data.shape[0] @@ -953,7 +954,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods def _ewma_vectorized(self, data: np.ndarray, out: np.ndarray, - offset: Optional[float] = None) -> None: + offset: float | None = None) -> None: """ Calculates the exponential moving average over a vector. Will fail for large inputs. 
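The rate calculation in `_calc_rate` above divides an effective batch size by the time delta between consecutive iteration timestamps. A small illustration with made-up numbers (the `* 2` reflects that both face sets are processed each iteration):

import numpy as np

timestamps = np.array([100.00, 100.25, 100.50, 101.00], dtype="float64")  # per-step times (s)
batch_size = 16                                # faces per side, per iteration (example value)
effective_batch = batch_size * 2               # both the A and B sides train each iteration

rate = effective_batch / np.diff(timestamps)   # examples/second between consecutive steps
print(rate)                                    # approx. [128., 128., 64.]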
The result is processed in place into the array passed to the `out` parameter diff --git a/lib/gui/control_helper.py b/lib/gui/control_helper.py index 0436109f..39063700 100644 --- a/lib/gui/control_helper.py +++ b/lib/gui/control_helper.py @@ -5,10 +5,10 @@ import logging import re import tkinter as tk +import typing as T from tkinter import colorchooser, ttk from itertools import zip_longest from functools import partial -from typing import Any, Dict from _tkinter import Tcl_Obj, TclError @@ -24,7 +24,9 @@ _ = _LANG.gettext # We store Tooltips, ContextMenus and Commands globally when they are created # Because we need to add them back to newly cloned widgets (they are not easily accessible from # original config or are prone to getting destroyed when the original widget is destroyed) -_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={}) +_RECREATE_OBJECTS: dict[str, dict[str, T.Any]] = {"tooltips": {}, + "commands": {}, + "contextmenus": {}} def _get_tooltip(widget, text=None, text_variable=None): @@ -154,17 +156,17 @@ class ControlPanelOption(): self.dtype = dtype self.sysbrowser = sysbrowser self._command = command - self._options = dict(title=title, - subgroup=subgroup, - group=group, - default=default, - initial_value=initial_value, - choices=choices, - is_radio=is_radio, - is_multi_option=is_multi_option, - rounding=rounding, - min_max=min_max, - helptext=helptext) + self._options = {"title": title, + "subgroup": subgroup, + "group": group, + "default": default, + "initial_value": initial_value, + "choices": choices, + "is_radio": is_radio, + "is_multi_option": is_multi_option, + "rounding": rounding, + "min_max": min_max, + "helptext": helptext} self.control = self.get_control() self.tk_var = self.get_tk_var(initial_value, track_modified) logger.debug("Initialized %s", self.__class__.__name__) @@ -421,7 +423,7 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors self.group_frames = {} self._sub_group_frames = {} - canvas_kwargs = dict(bd=0, highlightthickness=0, bg=self._theme["panel_background"]) + canvas_kwargs = {"bd": 0, "highlightthickness": 0, "bg": self._theme["panel_background"]} self._canvas = tk.Canvas(self, **canvas_kwargs) self._canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) @@ -525,8 +527,8 @@ class ControlPanel(ttk.Frame): # pylint:disable=too-many-ancestors group_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5, anchor=tk.NW) - self.group_frames[group] = dict(frame=retval, - chkbtns=self.checkbuttons_frame(retval)) + self.group_frames[group] = {"frame": retval, + "chkbtns": self.checkbuttons_frame(retval)} group_frame = self.group_frames[group] return group_frame @@ -720,12 +722,12 @@ class AutoFillContainer(): """ retval = {} if widget.__class__.__name__ == "MultiOption": - retval = dict(value=widget._value, # pylint:disable=protected-access - variable=widget._master_variable) # pylint:disable=protected-access + retval = {"value": widget._value, # pylint:disable=protected-access + "variable": widget._master_variable} # pylint:disable=protected-access elif widget.__class__.__name__ == "ToggledFrame": # Toggled Frames need to have their variable tracked - retval = dict(text=widget._text, # pylint:disable=protected-access - toggle_var=widget._toggle_var) # pylint:disable=protected-access + retval = {"text": widget._text, # pylint:disable=protected-access + "toggle_var": widget._toggle_var} # pylint:disable=protected-access return retval def get_all_children_config(self, widget, child_list): @@ -988,7 
+990,7 @@ class ControlBuilder(): if self.option.control != ttk.Checkbutton: ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) if self.option.helptext is not None and not self.helpset: - tooltip_kwargs = dict(text=self.option.helptext) + tooltip_kwargs = {"text": self.option.helptext} if self.option.sysbrowser is not None: tooltip_kwargs["text_variable"] = self.option.tk_var _get_tooltip(ctl, **tooltip_kwargs) @@ -1071,7 +1073,7 @@ class ControlBuilder(): "rounding: %s, min_max: %s)", self.option.name, self.option.dtype, self.option.rounding, self.option.min_max) validate = self.slider_check_int if self.option.dtype == int else self.slider_check_float - vcmd = (self.frame.register(validate)) + vcmd = self.frame.register(validate) tbox = tk.Entry(self.frame, width=8, textvariable=self.option.tk_var, @@ -1246,15 +1248,15 @@ class FileBrowser(): @property def helptext(self): """ Dict containing tooltip text for buttons """ - retval = dict(folder=_("Select a folder..."), - load=_("Select a file..."), - load2=_("Select a file..."), - picture=_("Select a folder of images..."), - video=_("Select a video..."), - model=_("Select a model folder..."), - multi_load=_("Select one or more files..."), - context=_("Select a file or folder..."), - save_as=_("Select a save location...")) + retval = {"folder": _("Select a folder..."), + "load": _("Select a file..."), + "load2": _("Select a file..."), + "picture": _("Select a folder of images..."), + "video": _("Select a video..."), + "model": _("Select a model folder..."), + "multi_load": _("Select one or more files..."), + "context": _("Select a file or folder..."), + "save_as": _("Select a save location...")} return retval @staticmethod diff --git a/lib/gui/display_command.py b/lib/gui/display_command.py index b973d34b..ab129349 100644 --- a/lib/gui/display_command.py +++ b/lib/gui/display_command.py @@ -4,11 +4,10 @@ import datetime import gettext import logging import os -import sys import tkinter as tk +import typing as T from tkinter import ttk -from typing import Dict, Optional, Tuple from lib.training.preview_tk import PreviewTk @@ -19,11 +18,6 @@ from .analysis import Calculations, Session from .control_helper import set_slider_rounding from .utils import FileHandler, get_config, get_images, preview_trigger -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - logger = logging.getLogger(__name__) # pylint: disable=invalid-name # LOCALES @@ -92,7 +86,7 @@ class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors logger.debug("Initializing %s (args: %s, kwargs: %s)", self.__class__.__name__, args, kwargs) self._preview = get_images().preview_train - self._display: Optional[PreviewTk] = None + self._display: PreviewTk | None = None super().__init__(*args, **kwargs) logger.debug("Initialized %s", self.__class__.__name__) @@ -177,9 +171,9 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors tab_name: str, helptext: str, wait_time: int, - command: Optional[str] = None) -> None: - self._trace_vars: Dict[Literal["smoothgraph", "display_iterations"], - Tuple[tk.BooleanVar, str]] = {} + command: str | None = None) -> None: + self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"], + tuple[tk.BooleanVar, str]] = {} super().__init__(parent, tab_name, helptext, wait_time, command) def set_vars(self) -> None: @@ -446,7 +440,7 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors def 
_add_trace_variables(self) -> None: """ Add tracing for when the option sliders are updated, for updating the graph. """ - for name, action in zip(get_args(Literal["smoothgraph", "display_iterations"]), + for name, action in zip(T.get_args(T.Literal["smoothgraph", "display_iterations"]), (self._smooth_amount_callback, self._iteration_limit_callback)): var = self.vars[name] if name not in self._trace_vars: diff --git a/lib/gui/display_graph.py b/lib/gui/display_graph.py index 8a503e5c..3dd1511f 100755 --- a/lib/gui/display_graph.py +++ b/lib/gui/display_graph.py @@ -1,12 +1,13 @@ #!/usr/bin python3 """ Graph functions for Display Frame area of the Faceswap GUI """ +from __future__ import annotations import datetime import logging import os import tkinter as tk +import typing as T from tkinter import ttk -from typing import cast, Union, List, Optional, Tuple, TYPE_CHECKING from math import ceil, floor import numpy as np @@ -20,7 +21,7 @@ from matplotlib.backend_bases import NavigationToolbar2 from .custom_widgets import Tooltip from .utils import get_config, get_images, LongRunningTask -if TYPE_CHECKING: +if T.TYPE_CHECKING: from matplotlib.lines import Line2D matplotlib.use("TkAgg") @@ -49,8 +50,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors self._ylabel = ylabel self._colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys", "copper", "summer", "bone", "hot", "cool", "pink", "Wistia", "spring", "winter"] - self._lines: List["Line2D"] = [] - self._toolbar: Optional["NavigationToolbar"] = None + self._lines: list[Line2D] = [] + self._toolbar: "NavigationToolbar" | None = None self._fig = Figure(figsize=(4, 4), dpi=75) self._ax1 = self._fig.add_subplot(1, 1, 1) @@ -129,7 +130,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors self._ax1.set_ylim(0.00, 100.0) self._ax1.set_xlim(0, 1) - def _axes_limits_set(self, data: List[float]) -> None: + def _axes_limits_set(self, data: list[float]) -> None: """ Set the axes limits. Parameters @@ -154,7 +155,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors self._axes_limits_set_default() @staticmethod - def _axes_data_get_min_max(data: List[float]) -> Tuple[float, float]: + def _axes_data_get_min_max(data: list[float]) -> tuple[float, float]: """ Obtain the minimum and maximum values for the y-axis from the given data points. Parameters @@ -188,7 +189,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors logger.debug("yscale: '%s'", scale) self._ax1.set_yscale(scale) - def _lines_sort(self, keys: List[str]) -> List[List[Union[str, int, Tuple[float]]]]: + def _lines_sort(self, keys: list[str]) -> list[list[str | int | tuple[float]]]: """ Sort the data keys into consistent order and set line color map and line width. 
Parameters @@ -202,8 +203,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors A list of loss keys with their corresponding line formatting and color information """ logger.trace("Sorting lines") # type:ignore[attr-defined] - raw_lines: List[List[str]] = [] - sorted_lines: List[List[str]] = [] + raw_lines: list[list[str]] = [] + sorted_lines: list[list[str]] = [] for key in sorted(keys): title = key.replace("_", " ").title() if key.startswith("raw"): @@ -217,7 +218,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors return lines @staticmethod - def _lines_groupsize(raw_lines: List[List[str]], sorted_lines: List[List[str]]) -> int: + def _lines_groupsize(raw_lines: list[list[str]], sorted_lines: list[list[str]]) -> int: """ Get the number of items in each group. If raw data isn't selected, then check the length of remaining groups until something is @@ -246,8 +247,8 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors return groupsize def _lines_style(self, - lines: List[List[str]], - groupsize: int) -> List[List[Union[str, int, Tuple[float]]]]: + lines: list[list[str]], + groupsize: int) -> list[list[str | int | tuple[float]]]: """ Obtain the color map and line width for each group. Parameters @@ -266,13 +267,13 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors groups = int(len(lines) / groupsize) colours = self._lines_create_colors(groupsize, groups) widths = list(range(1, groups + 1)) - retval = cast(List[List[Union[str, int, Tuple[float]]]], lines) + retval = T.cast(list[list[str | int | tuple[float]]], lines) for idx, item in enumerate(retval): linewidth = widths[idx // groupsize] item.extend((linewidth, colours[idx])) return retval - def _lines_create_colors(self, groupsize: int, groups: int) -> List[Tuple[float]]: + def _lines_create_colors(self, groupsize: int, groups: int) -> list[tuple[float]]: """ Create the color maps. Parameters @@ -336,8 +337,8 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None: super().__init__(parent, data, ylabel) - self._thread: Optional[LongRunningTask] = None # Thread for LongRunningTask - self._displayed_keys: List[str] = [] + self._thread: LongRunningTask | None = None # Thread for LongRunningTask + self._displayed_keys: list[str] = [] self._add_callback() def _add_callback(self) -> None: @@ -352,7 +353,7 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors def refresh(self, *args) -> None: # pylint: disable=unused-argument """ Read the latest loss data and apply to current graph """ - refresh_var = cast(tk.BooleanVar, get_config().tk_vars.refresh_graph) + refresh_var = T.cast(tk.BooleanVar, get_config().tk_vars.refresh_graph) if not refresh_var.get() and self._thread is None: return @@ -533,7 +534,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances text: str, image_file: str, toggle: bool, - command) -> Union[ttk.Button, ttk.Checkbutton]: + command) -> ttk.Button | ttk.Checkbutton: """ Override the default button method to use our icons and ttk widgets for consistent GUI layout. 
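A rough standalone illustration of the grouping performed by the `_lines_sort`, `_lines_groupsize` and `_lines_style` helpers above: raw series are ordered first, the group size is inferred from them, and heavier line widths go to later groups. The key names are hypothetical and the colour-map assignment is omitted:

keys = ["avg_face_a", "avg_face_b", "raw_face_a", "raw_face_b"]   # hypothetical loss keys

raw = sorted(key for key in keys if key.startswith("raw"))
others = sorted(key for key in keys if not key.startswith("raw"))
ordered = raw + others                            # raw series are always drawn first

group_size = len(raw) if raw else len(others)     # series per group (one per side/output)
groups = len(ordered) // group_size
widths = list(range(1, groups + 1))               # later groups get heavier lines

for idx, key in enumerate(ordered):
    print(key, "-> line width", widths[idx // group_size])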
@@ -563,10 +564,10 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances img = get_images().icons[icon] if not toggle: - btn: Union[ttk.Button, ttk.Checkbutton] = ttk.Button(frame, - text=text, - image=img, - command=command) + btn: ttk.Button | ttk.Checkbutton = ttk.Button(frame, + text=text, + image=img, + command=command) else: var = tk.IntVar(master=frame) btn = ttk.Checkbutton(frame, text=text, image=img, command=command, variable=var) diff --git a/lib/gui/menu.py b/lib/gui/menu.py index 0677c31f..694655b0 100644 --- a/lib/gui/menu.py +++ b/lib/gui/menu.py @@ -1,6 +1,6 @@ #!/usr/bin python3 """ The Menu Bars for faceswap GUI """ - +from __future__ import annotations import gettext import locale import logging @@ -33,7 +33,7 @@ _ = _LANG.gettext _WORKING_DIR = os.path.dirname(os.path.realpath(sys.argv[0])) -_RESOURCES: T.List[T.Tuple[str, str]] = [ +_RESOURCES: list[tuple[str, str]] = [ (_("faceswap.dev - Guides and Forum"), "https://www.faceswap.dev"), (_("Patreon - Support this project"), "https://www.patreon.com/faceswap"), (_("Discord - The FaceSwap Discord server"), "https://discord.gg/VasFUAy"), @@ -48,7 +48,7 @@ class MainMenuBar(tk.Menu): # pylint:disable=too-many-ancestors master: :class:`tkinter.Tk` The root tkinter object """ - def __init__(self, master: "FaceswapGui") -> None: + def __init__(self, master: FaceswapGui) -> None: logger.debug("Initializing %s", self.__class__.__name__) super().__init__(master) self.root = master @@ -431,7 +431,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors return True @classmethod - def _get_branches(cls) -> T.Optional[str]: + def _get_branches(cls) -> str | None: """ Get the available github branches Returns @@ -453,7 +453,7 @@ class HelpMenu(tk.Menu): # pylint:disable=too-many-ancestors return stdout.decode(locale.getpreferredencoding(), errors="replace") @classmethod - def _filter_branches(cls, stdout: str) -> T.List[str]: + def _filter_branches(cls, stdout: str) -> list[str]: """ Filter the branches, remove duplicates and the current branch and return a sorted list. @@ -548,7 +548,7 @@ class TaskBar(ttk.Frame): # pylint: disable=too-many-ancestors self._section_separator() @classmethod - def _loader_and_kwargs(cls, btntype: str) -> T.Tuple[str, T.Dict[str, bool]]: + def _loader_and_kwargs(cls, btntype: str) -> tuple[str, dict[str, bool]]: """ Get the loader name and key word arguments for the given button type Parameters diff --git a/lib/gui/popup_configure.py b/lib/gui/popup_configure.py index e525b639..ea361a7d 100644 --- a/lib/gui/popup_configure.py +++ b/lib/gui/popup_configure.py @@ -1,6 +1,6 @@ #!/usr/bin python3 """ The pop-up window of the Faceswap GUI for the setting of configuration options. 
""" - +from __future__ import annotations from collections import OrderedDict from configparser import ConfigParser import gettext @@ -9,7 +9,8 @@ import os import sys import tkinter as tk from tkinter import ttk -from typing import Dict, TYPE_CHECKING +import typing as T + from importlib import import_module from lib.serializer import get_serializer @@ -18,7 +19,7 @@ from .control_helper import ControlPanel, ControlPanelOption from .custom_widgets import Tooltip from .utils import FileHandler, get_config, get_images, PATHCACHE -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.config import FaceswapConfig logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -124,7 +125,7 @@ class _ConfigurePlugins(tk.Toplevel): super().__init__() self._root = get_config().root self._set_geometry() - self._tk_vars = dict(header=tk.StringVar()) + self._tk_vars = {"header": tk.StringVar()} theme = {**get_config().user_theme["group_panel"], **get_config().user_theme["group_settings"]} @@ -402,7 +403,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors """ def __init__(self, top_level, parent, configurations, tree, theme): super().__init__(parent) - self._configs: Dict[str, "FaceswapConfig"] = configurations + self._configs: dict[str, FaceswapConfig] = configurations self._theme = theme self._tree = tree self._vars = {} @@ -443,7 +444,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors sect = section.split(".")[-1] # Elevate global to root key = plugin if sect == "global" else f"{plugin}|{category}|{sect}" - retval[key] = dict(helptext=None, options=OrderedDict()) + retval[key] = {"helptext": None, "options": OrderedDict()} retval[key]["helptext"] = conf.defaults[section].helptext for option, params in conf.defaults[section].items.items(): @@ -632,7 +633,7 @@ class DisplayArea(ttk.Frame): # pylint:disable=too-many-ancestors def _get_new_config(self, page_only: bool, - config: "FaceswapConfig", + config: FaceswapConfig, category: str, lookup: str) -> ConfigParser: """ Obtain a new configuration file for saving @@ -812,9 +813,9 @@ class _Presets(): return None args = ("save_filename", "json") if action == "save" else ("filename", "json") - kwargs = dict(title=f"{action.title()} Preset...", - initial_folder=self._preset_path, - parent=self._parent) + kwargs = {"title": f"{action.title()} Preset...", + "initial_folder": self._preset_path, + "parent": self._parent} if action == "save": kwargs["initial_file"] = self._get_initial_filename() diff --git a/lib/gui/popup_session.py b/lib/gui/popup_session.py index d1511e59..2d0162f0 100644 --- a/lib/gui/popup_session.py +++ b/lib/gui/popup_session.py @@ -8,7 +8,6 @@ import tkinter as tk from dataclasses import dataclass, field from tkinter import ttk -from typing import Dict, List, Optional, Tuple, Type, Union from .control_helper import ControlBuilder, ControlPanelOption from .custom_widgets import Tooltip @@ -66,7 +65,7 @@ class SessionTKVars: outliers: tk.BooleanVar avgiterations: tk.IntVar smoothamount: tk.DoubleVar - loss_keys: Dict[str, tk.BooleanVar] = field(default_factory=dict) + loss_keys: dict[str, tk.BooleanVar] = field(default_factory=dict) class SessionPopUp(tk.Toplevel): @@ -82,13 +81,13 @@ class SessionPopUp(tk.Toplevel): logger.debug("Initializing: %s: (session_id: %s, data_points: %s)", self.__class__.__name__, session_id, data_points) super().__init__() - self._thread: Optional[LongRunningTask] = None # Thread for loading data in background + self._thread: LongRunningTask | None = None # Thread for 
loading data in background self._default_view = "avg" if data_points > 1000 else "smoothed" self._session_id = None if session_id == "Total" else int(session_id) self._graph_frame = ttk.Frame(self) - self._graph: Optional[SessionGraph] = None - self._display_data: Optional[Calculations] = None + self._graph: SessionGraph | None = None + self._display_data: Calculations | None = None self._vars = self._set_vars() @@ -172,7 +171,7 @@ class SessionPopUp(tk.Toplevel): The frame that the options reside in """ logger.debug("Building Combo boxes") - choices = dict(Display=("Loss", "Rate"), Scale=("Linear", "Log")) + choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")} for item in ["Display", "Scale"]: var: tk.StringVar = getattr(self._vars, item.lower()) @@ -273,11 +272,11 @@ class SessionPopUp(tk.Toplevel): logger.debug("Building Slider Controls") for item in ("avgiterations", "smoothamount"): if item == "avgiterations": - dtype: Union[Type[int], Type[float]] = int + dtype: type[int] | type[float] = int text = "Iterations to Average:" - default: Union[int, float] = 500 + default: int | float = 500 rounding = 25 - min_max: Tuple[int, Union[int, float]] = (25, 2500) + min_max: tuple[int, int | float] = (25, 2500) elif item == "smoothamount": dtype = float text = "Smoothing Amount:" @@ -404,20 +403,20 @@ class SessionPopUp(tk.Toplevel): str The help text for the given action """ - lookup = dict( - reload=_("Refresh graph"), - save=_("Save display data to csv"), - avgiterations=_("Number of data points to sample for rolling average"), - smoothamount=_("Set the smoothing amount. 0 is no smoothing, 0.99 is maximum " - "smoothing"), - outliers=_("Flatten data points that fall more than 1 standard deviation from the " - "mean to the mean value."), - avg=_("Display rolling average of the data"), - smoothed=_("Smooth the data"), - raw=_("Display raw data"), - trend=_("Display polynormal data trend"), - display=_("Set the data to display"), - scale=_("Change y-axis scale")) + lookup = { + "reload": _("Refresh graph"), + "save": _("Save display data to csv"), + "avgiterations": _("Number of data points to sample for rolling average"), + "smoothamount": _("Set the smoothing amount. 
0 is no smoothing, 0.99 is maximum " + "smoothing"), + "outliers": _("Flatten data points that fall more than 1 standard deviation from the " + "mean to the mean value."), + "avg": _("Display rolling average of the data"), + "smoothed": _("Smooth the data"), + "raw": _("Display raw data"), + "trend": _("Display polynormal data trend"), + "display": _("Set the data to display"), + "scale": _("Change y-axis scale")} return lookup.get(action.lower(), "") def _compile_display_data(self) -> bool: @@ -446,13 +445,13 @@ class SessionPopUp(tk.Toplevel): self._lbl_loading.pack(fill=tk.BOTH, expand=True) self.update_idletasks() - kwargs = dict(session_id=self._session_id, - display=self._vars.display.get(), - loss_keys=loss_keys, - selections=selections, - avg_samples=self._vars.avgiterations.get(), - smooth_amount=self._vars.smoothamount.get(), - flatten_outliers=self._vars.outliers.get()) + kwargs = {"session_id": self._session_id, + "display": self._vars.display.get(), + "loss_keys": loss_keys, + "selections": selections, + "avg_samples": self._vars.avgiterations.get(), + "smooth_amount": self._vars.smoothamount.get(), + "flatten_outliers": self._vars.outliers.get()} self._thread = LongRunningTask(target=self._get_display_data, kwargs=kwargs, widget=self) @@ -491,7 +490,7 @@ class SessionPopUp(tk.Toplevel): """ return Calculations(**kwargs) - def _check_valid_selection(self, loss_keys: List[str], selections: List[str]) -> bool: + def _check_valid_selection(self, loss_keys: list[str], selections: list[str]) -> bool: """ Check that there will be data to display. Parameters @@ -530,7 +529,7 @@ class SessionPopUp(tk.Toplevel): return False return True - def _selections_to_list(self) -> List[str]: + def _selections_to_list(self) -> list[str]: """ Compile checkbox selections to a list. Returns diff --git a/lib/gui/utils/config.py b/lib/gui/utils/config.py index 3d4096a3..58e8152c 100644 --- a/lib/gui/utils/config.py +++ b/lib/gui/utils/config.py @@ -1,19 +1,20 @@ #!/usr/bin python3 """ Global configuration optiopns for the Faceswap GUI """ +from __future__ import annotations import logging import os import sys import tkinter as tk +import typing as T from dataclasses import dataclass, field -from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING from lib.gui._config import Config as UserConfig from lib.gui.project import Project, Tasks from lib.gui.theme import Style from .file_handler import FileHandler -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.gui.options import CliOptions from lib.gui.custom_widgets import StatusBar from lib.gui.command import CommandNotebook @@ -22,12 +23,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # pylint: disable=invalid-name PATHCACHE = os.path.join(os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache") -_CONFIG: Optional["Config"] = None +_CONFIG: Config | None = None def initialize_config(root: tk.Tk, - cli_opts: Optional["CliOptions"], - statusbar: Optional["StatusBar"]) -> Optional["Config"]: + cli_opts: CliOptions | None, + statusbar: StatusBar | None) -> Config | None: """ Initialize the GUI Master :class:`Config` and add to global constant. This should only be called once on first GUI startup. 
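A minimal sketch of the module-level singleton pattern used for the GUI `Config` (initialize once on startup, retrieve everywhere else through a getter); the `Config` body here is a stand-in for the real class:

from __future__ import annotations


class Config:
    """ Stand-in for the real GUI Config object. """
    def __init__(self, root: object) -> None:
        self.root = root


_CONFIG: Config | None = None


def initialize_config(root: object) -> Config | None:
    """ Create the global Config exactly once; subsequent calls are no-ops. """
    global _CONFIG  # pylint:disable=global-statement
    if _CONFIG is not None:
        return None
    _CONFIG = Config(root)
    return _CONFIG


def get_config() -> Config:
    """ Return the previously initialized global Config. """
    assert _CONFIG is not None, "initialize_config() has not been called"
    return _CONFIG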
Future access to :class:`Config` @@ -145,13 +146,13 @@ class GlobalVariables(): @dataclass class _GuiObjects: """ Data class for commonly accessed GUI Objects """ - cli_opts: Optional["CliOptions"] + cli_opts: CliOptions | None tk_vars: GlobalVariables project: Project tasks: Tasks - status_bar: Optional["StatusBar"] - default_options: Dict[str, Dict[str, Any]] = field(default_factory=dict) - command_notebook: Optional["CommandNotebook"] = None + status_bar: StatusBar | None + default_options: dict[str, dict[str, T.Any]] = field(default_factory=dict) + command_notebook: CommandNotebook | None = None class Config(): @@ -172,15 +173,15 @@ class Config(): """ def __init__(self, root: tk.Tk, - cli_opts: Optional["CliOptions"], - statusbar: Optional["StatusBar"]) -> None: + cli_opts: CliOptions | None, + statusbar: StatusBar | None) -> None: logger.debug("Initializing %s: (root %s, cli_opts: %s, statusbar: %s)", self.__class__.__name__, root, cli_opts, statusbar) - self._default_font = cast(dict, tk.font.nametofont("TkDefaultFont").configure())["family"] - self._constants = dict( - root=root, - scaling_factor=self._get_scaling(root), - default_font=self._default_font) + self._default_font = T.cast(dict, + tk.font.nametofont("TkDefaultFont").configure())["family"] + self._constants = {"root": root, + "scaling_factor": self._get_scaling(root), + "default_font": self._default_font} self._gui_objects = _GuiObjects( cli_opts=cli_opts, tk_vars=GlobalVariables(), @@ -211,7 +212,7 @@ class Config(): # GUI Objects @property - def cli_opts(self) -> "CliOptions": + def cli_opts(self) -> CliOptions: """ :class:`lib.gui.options.CliOptions`: The command line options for this GUI Session. """ # This should only be None when a separate tool (not main GUI) is used, at which point # cli_opts do not exist @@ -234,12 +235,12 @@ class Config(): return self._gui_objects.tasks @property - def default_options(self) -> Dict[str, Dict[str, Any]]: + def default_options(self) -> dict[str, dict[str, T.Any]]: """ dict: The default options for all tabs """ return self._gui_objects.default_options @property - def statusbar(self) -> "StatusBar": + def statusbar(self) -> StatusBar: """ :class:`lib.gui.custom_widgets.StatusBar`: The GUI StatusBar :class:`tkinter.ttk.Frame`. """ # This should only be None when a separate tool (not main GUI) is used, at which point @@ -248,31 +249,31 @@ class Config(): return self._gui_objects.status_bar @property - def command_notebook(self) -> Optional["CommandNotebook"]: + def command_notebook(self) -> CommandNotebook | None: """ :class:`lib.gui.command.CommandNotebook`: The main Faceswap Command Notebook. """ return self._gui_objects.command_notebook # Convenience GUI Objects @property - def tools_notebook(self) -> "ToolsNotebook": + def tools_notebook(self) -> ToolsNotebook: """ :class:`lib.gui.command.ToolsNotebook`: The Faceswap Tools sub-Notebook. """ assert self.command_notebook is not None return self.command_notebook.tools_notebook @property - def modified_vars(self) -> Dict[str, "tk.BooleanVar"]: + def modified_vars(self) -> dict[str, tk.BooleanVar]: """ dict: The command notebook modified tkinter variables. """ assert self.command_notebook is not None return self.command_notebook.modified_vars @property - def _command_tabs(self) -> Dict[str, int]: + def _command_tabs(self) -> dict[str, int]: """ dict: Command tab titles with their IDs. 
""" assert self.command_notebook is not None return self.command_notebook.tab_names @property - def _tools_tabs(self) -> Dict[str, int]: + def _tools_tabs(self) -> dict[str, int]: """ dict: Tools command tab titles with their IDs. """ assert self.command_notebook is not None return self.command_notebook.tools_tab_names @@ -284,17 +285,17 @@ class Config(): return self._user_config @property - def user_config_dict(self) -> Dict[str, Any]: # TODO Dataclass + def user_config_dict(self) -> dict[str, T.Any]: # TODO Dataclass """ dict: The GUI config in dict form. """ return self._user_config.config_dict @property - def user_theme(self) -> Dict[str, Any]: # TODO Dataclass + def user_theme(self) -> dict[str, T.Any]: # TODO Dataclass """ dict: The GUI theme selection options. """ return self._user_theme @property - def default_font(self) -> Tuple[str, int]: + def default_font(self) -> tuple[str, int]: """ tuple: The selected font as configured in user settings. First item is the font (`str`) second item the font size (`int`). """ font = self.user_config_dict["font"] @@ -328,7 +329,7 @@ class Config(): self._gui_objects.default_options = default self.project.set_default_options() - def set_command_notebook(self, notebook: "CommandNotebook") -> None: + def set_command_notebook(self, notebook: CommandNotebook) -> None: """ Set the command notebook to the :attr:`command_notebook` attribute and enable the modified callback for :attr:`project`. @@ -385,7 +386,7 @@ class Config(): """ Reload the user config from file. """ self._user_config = UserConfig(None) - def set_cursor_busy(self, widget: Optional[tk.Widget] = None) -> None: + def set_cursor_busy(self, widget: tk.Widget | None = None) -> None: """ Set the root or widget cursor to busy. Parameters @@ -399,7 +400,7 @@ class Config(): component.config(cursor="watch") # type: ignore component.update_idletasks() - def set_cursor_default(self, widget: Optional[tk.Widget] = None) -> None: + def set_cursor_default(self, widget: tk.Widget | None = None) -> None: """ Set the root or widget cursor to default. Parameters @@ -413,7 +414,7 @@ class Config(): component.config(cursor="") # type: ignore component.update_idletasks() - def set_root_title(self, text: Optional[str] = None) -> None: + def set_root_title(self, text: str | None = None) -> None: """ Set the main title text for Faceswap. The title will always begin with 'Faceswap.py'. Additional text can be appended. diff --git a/lib/gui/utils/file_handler.py b/lib/gui/utils/file_handler.py index 59485c5c..45ff8c5f 100644 --- a/lib/gui/utils/file_handler.py +++ b/lib/gui/utils/file_handler.py @@ -2,24 +2,15 @@ """ File browser utility functions for the Faceswap GUI. 
""" import logging import platform -import sys import tkinter as tk from tkinter import filedialog - -from typing import cast, Dict, IO, List, Optional, Tuple, Union - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - +import typing as T logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -_FILETYPE = Literal["default", "alignments", "config_project", "config_task", - "config_all", "csv", "image", "ini", "state", "log", "video"] -_HANDLETYPE = Literal["open", "save", "filename", "filename_multi", "save_filename", - "context", "dir"] +_FILETYPE = T.Literal["default", "alignments", "config_project", "config_task", + "config_all", "csv", "image", "ini", "state", "log", "video"] +_HANDLETYPE = T.Literal["open", "save", "filename", "filename_multi", "save_filename", + "context", "dir"] class FileHandler(): # pylint:disable=too-few-public-methods @@ -72,14 +63,14 @@ class FileHandler(): # pylint:disable=too-few-public-methods def __init__(self, handle_type: _HANDLETYPE, - file_type: Optional[_FILETYPE], - title: Optional[str] = None, - initial_folder: Optional[str] = None, - initial_file: Optional[str] = None, - command: Optional[str] = None, - action: Optional[str] = None, - variable: Optional[str] = None, - parent: Optional[tk.Frame] = None) -> None: + file_type: _FILETYPE | None, + title: str | None = None, + initial_folder: str | None = None, + initial_file: str | None = None, + command: str | None = None, + action: str | None = None, + variable: str | None = None, + parent: tk.Frame | None = None) -> None: logger.debug("Initializing %s: (handle_type: '%s', file_type: '%s', title: '%s', " "initial_folder: '%s', initial_file: '%s', command: '%s', action: '%s', " "variable: %s, parent: %s)", self.__class__.__name__, handle_type, file_type, @@ -101,35 +92,35 @@ class FileHandler(): # pylint:disable=too-few-public-methods logger.debug("Initialized %s", self.__class__.__name__) @property - def _filetypes(self) -> Dict[str, List[Tuple[str, str]]]: + def _filetypes(self) -> dict[str, list[tuple[str, str]]]: """ dict: The accepted extensions for each file type for opening/saving """ all_files = ("All files", "*.*") - filetypes = dict( - default=[all_files], - alignments=[("Faceswap Alignments", "*.fsa"), all_files], - config_project=[("Faceswap Project files", "*.fsw"), all_files], - config_task=[("Faceswap Task files", "*.fst"), all_files], - config_all=[("Faceswap Project and Task files", "*.fst *.fsw"), all_files], - csv=[("Comma separated values", "*.csv"), all_files], - image=[("Bitmap", "*.bmp"), - ("JPG", "*.jpeg *.jpg"), - ("PNG", "*.png"), - ("TIFF", "*.tif *.tiff"), - all_files], - ini=[("Faceswap config files", "*.ini"), all_files], - json=[("JSON file", "*.json"), all_files], - model=[("Keras model files", "*.h5"), all_files], - state=[("State files", "*.json"), all_files], - log=[("Log files", "*.log"), all_files], - video=[("Audio Video Interleave", "*.avi"), - ("Flash Video", "*.flv"), - ("Matroska", "*.mkv"), - ("MOV", "*.mov"), - ("MP4", "*.mp4"), - ("MPEG", "*.mpeg *.mpg *.ts *.vob"), - ("WebM", "*.webm"), - ("Windows Media Video", "*.wmv"), - all_files]) + filetypes = { + "default": [all_files], + "alignments": [("Faceswap Alignments", "*.fsa"), all_files], + "config_project": [("Faceswap Project files", "*.fsw"), all_files], + "config_task": [("Faceswap Task files", "*.fst"), all_files], + "config_all": [("Faceswap Project and Task files", "*.fst *.fsw"), all_files], + "csv": [("Comma separated values", 
"*.csv"), all_files], + "image": [("Bitmap", "*.bmp"), + ("JPG", "*.jpeg *.jpg"), + ("PNG", "*.png"), + ("TIFF", "*.tif *.tiff"), + all_files], + "ini": [("Faceswap config files", "*.ini"), all_files], + "json": [("JSON file", "*.json"), all_files], + "model": [("Keras model files", "*.h5"), all_files], + "state": [("State files", "*.json"), all_files], + "log": [("Log files", "*.log"), all_files], + "video": [("Audio Video Interleave", "*.avi"), + ("Flash Video", "*.flv"), + ("Matroska", "*.mkv"), + ("MOV", "*.mov"), + ("MP4", "*.mp4"), + ("MPEG", "*.mpeg *.mpg *.ts *.vob"), + ("WebM", "*.webm"), + ("Windows Media Video", "*.wmv"), + all_files]} # Add in multi-select options and upper case extensions for Linux for key in filetypes: @@ -142,32 +133,32 @@ class FileHandler(): # pylint:disable=too-few-public-methods multi = [f"{key.title()} Files"] multi.append(" ".join([ftype[1] for ftype in filetypes[key] if ftype[0] != "All files"])) - filetypes[key].insert(0, cast(Tuple[str, str], tuple(multi))) + filetypes[key].insert(0, T.cast(tuple[str, str], tuple(multi))) return filetypes @property - def _contexts(self) -> Dict[str, Dict[str, Union[str, Dict[str, str]]]]: + def _contexts(self) -> dict[str, dict[str, str | dict[str, str]]]: """dict: Mapping of commands, actions and their corresponding file dialog for context handle types. """ - return dict(effmpeg=dict(input={"extract": "filename", - "gen-vid": "dir", - "get-fps": "filename", - "get-info": "filename", - "mux-audio": "filename", - "rescale": "filename", - "rotate": "filename", - "slice": "filename"}, - output={"extract": "dir", - "gen-vid": "save_filename", - "get-fps": "nothing", - "get-info": "nothing", - "mux-audio": "save_filename", - "rescale": "save_filename", - "rotate": "save_filename", - "slice": "save_filename"})) + return {"effmpeg": {"input": {"extract": "filename", + "gen-vid": "dir", + "get-fps": "filename", + "get-info": "filename", + "mux-audio": "filename", + "rescale": "filename", + "rotate": "filename", + "slice": "filename"}, + "output": {"extract": "dir", + "gen-vid": "save_filename", + "get-fps": "nothing", + "get-info": "nothing", + "mux-audio": "save_filename", + "rescale": "save_filename", + "rotate": "save_filename", + "slice": "save_filename"}}} @classmethod - def _set_dummy_master(cls) -> Optional[tk.Frame]: + def _set_dummy_master(cls) -> tk.Frame | None: """ Add an option to force black font on Linux file dialogs KDE issue that displays light font on white background). @@ -183,7 +174,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods if platform.system().lower() == "linux": frame = tk.Frame() frame.option_add("*foreground", "black") - retval: Optional[tk.Frame] = frame + retval: tk.Frame | None = frame else: retval = None return retval @@ -196,7 +187,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods del self._dummy_master self._dummy_master = None - def _set_defaults(self) -> Dict[str, Optional[str]]: + def _set_defaults(self) -> dict[str, str | None]: """ Set the default file type for the file dialog. Generally the first found file type will be used, but this is overridden if it is not appropriate. 
@@ -205,7 +196,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods dict: The default file extension for each file type """ - defaults: Dict[str, Optional[str]] = { + defaults: dict[str, str | None] = { key: next(ext for ext in val[0][1].split(" ")).replace("*", "") for key, val in self._filetypes.items()} defaults["default"] = None @@ -215,15 +206,15 @@ class FileHandler(): # pylint:disable=too-few-public-methods return defaults def _set_kwargs(self, - title: Optional[str], - initial_folder: Optional[str], - initial_file: Optional[str], - file_type: Optional[_FILETYPE], - command: Optional[str], - action: Optional[str], - variable: Optional[str], - parent: Optional[tk.Frame] - ) -> Dict[str, Union[None, tk.Frame, str, List[Tuple[str, str]]]]: + title: str | None, + initial_folder: str | None, + initial_file: str | None, + file_type: _FILETYPE | None, + command: str | None, + action: str | None, + variable: str | None, + parent: tk.Frame | None + ) -> dict[str, None | tk.Frame | str | list[tuple[str, str]]]: """ Generate the required kwargs for the requested file dialog browser. Parameters @@ -259,8 +250,8 @@ class FileHandler(): # pylint:disable=too-few-public-methods title, initial_folder, initial_file, file_type, command, action, variable, parent) - kwargs: Dict[str, Union[None, tk.Frame, str, - List[Tuple[str, str]]]] = dict(master=self._dummy_master) + kwargs: dict[str, None | tk.Frame | str | list[tuple[str, str]]] = { + "master": self._dummy_master} if self._handletype.lower() == "context": assert command is not None and action is not None and variable is not None @@ -304,20 +295,20 @@ class FileHandler(): # pylint:disable=too-few-public-methods The variable associated with this file dialog """ if self._contexts[command].get(variable, None) is not None: - handletype = cast(Dict[str, Dict[str, Dict[str, str]]], - self._contexts)[command][variable][action] + handletype = T.cast(dict[str, dict[str, dict[str, str]]], + self._contexts)[command][variable][action] else: - handletype = cast(Dict[str, Dict[str, str]], - self._contexts)[command][action] + handletype = T.cast(dict[str, dict[str, str]], + self._contexts)[command][action] logger.debug(handletype) - self._handletype = cast(_HANDLETYPE, handletype) + self._handletype = T.cast(_HANDLETYPE, handletype) - def _open(self) -> Optional[IO]: + def _open(self) -> T.IO | None: """ Open a file. """ logger.debug("Popping Open browser") return filedialog.askopenfile(**self._kwargs) # type: ignore - def _save(self) -> Optional[IO]: + def _save(self) -> T.IO | None: """ Save a file. """ logger.debug("Popping Save browser") return filedialog.asksaveasfile(**self._kwargs) # type: ignore @@ -337,7 +328,7 @@ class FileHandler(): # pylint:disable=too-few-public-methods logger.debug("Popping Filename browser") return filedialog.askopenfilename(**self._kwargs) # type: ignore - def _filename_multi(self) -> Tuple[str, ...]: + def _filename_multi(self) -> tuple[str, ...]: """ Get multiple existing file locations. 
""" logger.debug("Popping Filename browser") return filedialog.askopenfilenames(**self._kwargs) # type: ignore diff --git a/lib/gui/utils/image.py b/lib/gui/utils/image.py index 31637744..37eb0528 100644 --- a/lib/gui/utils/image.py +++ b/lib/gui/utils/image.py @@ -1,10 +1,9 @@ #!/usr/bin python3 """ Utilities for handling images in the Faceswap GUI """ - +from __future__ import annotations import logging import os -import sys -from typing import cast, Dict, List, Optional, Sequence, Tuple +import typing as T import cv2 import numpy as np @@ -14,15 +13,12 @@ from lib.training.preview_cv import PreviewBuffer from .config import get_config, PATHCACHE -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal +if T.TYPE_CHECKING: + from collections.abc import Sequence logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -_IMAGES: Optional["Images"] = None -_PREVIEW_TRIGGER: Optional["PreviewTrigger"] = None +_IMAGES: "Images" | None = None +_PREVIEW_TRIGGER: "PreviewTrigger" | None = None TRAININGPREVIEW = ".gui_training_preview.png" @@ -51,7 +47,7 @@ def get_images() -> "Images": return _IMAGES -def _get_previews(image_path: str) -> List[str]: +def _get_previews(image_path: str) -> list[str]: """ Get the images stored within the given directory. Parameters @@ -164,12 +160,12 @@ class PreviewExtract(): self._output_path = "" self._modified: float = 0.0 - self._filenames: List[str] = [] - self._images: Optional[np.ndarray] = None - self._placeholder: Optional[np.ndarray] = None + self._filenames: list[str] = [] + self._images: np.ndarray | None = None + self._placeholder: np.ndarray | None = None - self._preview_image: Optional[Image.Image] = None - self._preview_image_tk: Optional[ImageTk.PhotoImage] = None + self._preview_image: Image.Image | None = None + self._preview_image_tk: ImageTk.PhotoImage | None = None logger.debug("Initialized %s", self.__class__.__name__) @@ -228,7 +224,7 @@ class PreviewExtract(): logger.debug("sorted folders: %s, return value: %s", folders, retval) return retval - def _get_newest_filenames(self, image_files: List[str]) -> List[str]: + def _get_newest_filenames(self, image_files: list[str]) -> list[str]: """ Return image filenames that have been modified since the last check. Parameters @@ -281,8 +277,8 @@ class PreviewExtract(): return retval def _process_samples(self, - samples: List[np.ndarray], - filenames: List[str], + samples: list[np.ndarray], + filenames: list[str], num_images: int) -> bool: """ Process the latest sample images into a displayable image. @@ -321,8 +317,8 @@ class PreviewExtract(): return True def _load_images_to_cache(self, - image_files: List[str], - frame_dims: Tuple[int, int], + image_files: list[str], + frame_dims: tuple[int, int], thumbnail_size: int) -> bool: """ Load preview images to the image cache. @@ -349,7 +345,7 @@ class PreviewExtract(): logger.debug("num_images: %s", num_images) if num_images == 0: return False - samples: List[np.ndarray] = [] + samples: list[np.ndarray] = [] start_idx = len(image_files) - num_images if len(image_files) > num_images else 0 show_files = sorted(image_files, key=os.path.getctime)[start_idx:] dropped_files = [] @@ -405,7 +401,7 @@ class PreviewExtract(): self._placeholder = placeholder logger.debug("Created placeholder. 
shape: %s", placeholder.shape) - def _place_previews(self, frame_dims: Tuple[int, int]) -> Image.Image: + def _place_previews(self, frame_dims: tuple[int, int]) -> Image.Image: """ Format the preview thumbnails stored in the cache into a grid fitting the display panel. @@ -441,12 +437,12 @@ class PreviewExtract(): placeholder = np.concatenate([np.expand_dims(self._placeholder, 0)] * remainder) samples = np.concatenate((samples, placeholder)) - display = np.vstack([np.hstack(cast(Sequence, samples[row * cols: (row + 1) * cols])) + display = np.vstack([np.hstack(T.cast("Sequence", samples[row * cols: (row + 1) * cols])) for row in range(rows)]) logger.debug("display shape: %s", display.shape) return Image.fromarray(display) - def load_latest_preview(self, thumbnail_size: int, frame_dims: Tuple[int, int]) -> bool: + def load_latest_preview(self, thumbnail_size: int, frame_dims: tuple[int, int]) -> bool: """ Load the latest preview image for extract and convert. Retrieves the latest preview images from the faceswap output folder, resizes to thumbnails @@ -524,7 +520,7 @@ class Images(): def __init__(self) -> None: logger.debug("Initializing %s", self.__class__.__name__) self._pathpreview = os.path.join(PATHCACHE, "preview") - self._pathoutput: Optional[str] = None + self._pathoutput: str | None = None self._batch_mode = False self._preview_train = PreviewTrain(self._pathpreview) self._preview_extract = PreviewExtract(self._pathpreview) @@ -542,7 +538,7 @@ class Images(): return self._preview_extract @property - def icons(self) -> Dict[str, ImageTk.PhotoImage]: + def icons(self) -> dict[str, ImageTk.PhotoImage]: """ dict: The faceswap icons for all parts of the GUI. The dictionary key is the icon name (`str`) the value is the icon sized and formatted for display (:class:`PIL.ImageTK.PhotoImage`). @@ -557,7 +553,7 @@ class Images(): return self._icons @staticmethod - def _load_icons() -> Dict[str, ImageTk.PhotoImage]: + def _load_icons() -> dict[str, ImageTk.PhotoImage]: """ Scan the icons cache folder and load the icons into :attr:`icons` for retrieval throughout the GUI. @@ -569,7 +565,7 @@ class Images(): """ size = get_config().user_config_dict.get("icon_size", 16) size = int(round(size * get_config().scaling_factor)) - icons: Dict[str, ImageTk.PhotoImage] = {} + icons: dict[str, ImageTk.PhotoImage] = {} pathicons = os.path.join(PATHCACHE, "icons") for fname in os.listdir(pathicons): name, ext = os.path.splitext(fname) @@ -609,12 +605,12 @@ class PreviewTrigger(): """ def __init__(self) -> None: logger.debug("Initializing: %s", self.__class__.__name__) - self._trigger_files = dict(update=os.path.join(PATHCACHE, ".preview_trigger"), - mask_toggle=os.path.join(PATHCACHE, ".preview_mask_toggle")) + self._trigger_files = {"update": os.path.join(PATHCACHE, ".preview_trigger"), + "mask_toggle": os.path.join(PATHCACHE, ".preview_mask_toggle")} logger.debug("Initialized: %s (trigger_files: %s)", self.__class__.__name__, self._trigger_files) - def set(self, trigger_type: Literal["update", "mask_toggle"]): + def set(self, trigger_type: T.Literal["update", "mask_toggle"]): """ Place the trigger file into the cache folder Parameters @@ -629,7 +625,7 @@ class PreviewTrigger(): pass logger.debug("Set preview trigger: %s", trigger) - def clear(self, trigger_type: Optional[Literal["update", "mask_toggle"]] = None) -> None: + def clear(self, trigger_type: T.Literal["update", "mask_toggle"] | None = None) -> None: """ Remove the trigger file from the cache folder. 
Parameters diff --git a/lib/gui/utils/misc.py b/lib/gui/utils/misc.py index 52a6d4e8..2506af37 100644 --- a/lib/gui/utils/misc.py +++ b/lib/gui/utils/misc.py @@ -1,15 +1,17 @@ #!/usr/bin/env python3 """ Miscellaneous Utility functions for the GUI. Includes LongRunningTask object """ +from __future__ import annotations import logging import sys +import typing as T from threading import Event, Thread -from typing import (Any, Callable, cast, Dict, Optional, Tuple, Type, TYPE_CHECKING) from queue import Queue from .config import get_config -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Callable from types import TracebackType from lib.multithreading import _ErrorType @@ -31,15 +33,15 @@ class LongRunningTask(Thread): cursor in the correct location. Default: ``None``. """ _target: Callable - _args: Tuple - _kwargs: Dict[str, Any] + _args: tuple + _kwargs: dict[str, T.Any] _name: str def __init__(self, - target: Optional[Callable] = None, - name: Optional[str] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None, + target: Callable | None = None, + name: str | None = None, + args: tuple = (), + kwargs: dict[str, T.Any] | None = None, *, daemon: bool = True, widget=None): @@ -48,7 +50,7 @@ class LongRunningTask(Thread): daemon) super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) - self.err: "_ErrorType" = None + self.err: _ErrorType = None self._widget = widget self._config = get_config() self._config.set_cursor_busy(widget=self._widget) @@ -70,8 +72,8 @@ class LongRunningTask(Thread): retval = self._target(*self._args, **self._kwargs) self._queue.put(retval) except Exception: # pylint: disable=broad-except - self.err = cast(Tuple[Type[BaseException], BaseException, "TracebackType"], - sys.exc_info()) + self.err = T.cast(tuple[type[BaseException], BaseException, "TracebackType"], + sys.exc_info()) assert self.err is not None logger.debug("Error in thread (%s): %s", self._name, self.err[1].with_traceback(self.err[2])) @@ -81,7 +83,7 @@ class LongRunningTask(Thread): # an argument that has a member that points to the thread. del self._target, self._args, self._kwargs - def get_result(self) -> Any: + def get_result(self) -> T.Any: """ Return the result from the given task. Returns diff --git a/lib/image.py b/lib/image.py index 06c13c5d..8713335e 100644 --- a/lib/image.py +++ b/lib/image.py @@ -1,17 +1,17 @@ #!/usr/bin python3 """ Utilities for working with images and videos """ - +from __future__ import annotations import logging import re import subprocess import os import struct import sys +import typing as T from ast import literal_eval from bisect import bisect from concurrent import futures -from typing import Optional, TYPE_CHECKING, Union from zlib import crc32 import cv2 @@ -24,7 +24,7 @@ from lib.multithreading import MultiThread from lib.queue_manager import queue_manager, QueueEmpty from lib.utils import convert_to_secs, FaceswapError, _video_extensions, get_image_paths -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.alignments import PNGHeaderDict logger = logging.getLogger(__name__) # pylint:disable=invalid-name @@ -558,7 +558,7 @@ def update_existing_metadata(filename, metadata): def encode_image(image: np.ndarray, extension: str, - metadata: Optional["PNGHeaderDict"] = None) -> bytes: + metadata: PNGHeaderDict | None = None) -> bytes: """ Encode an image. 
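`LongRunningTask` above runs a callable on a background thread, pushes its result onto a queue and stores `sys.exc_info()` so the GUI thread can surface errors. A cut-down sketch of that pattern with the cursor and widget handling omitted:

import sys
import typing as T
from queue import Queue
from threading import Thread


class BackgroundTask(Thread):
    """ Cut-down LongRunningTask: run a callable on a thread and collect its result. """

    def __init__(self, target: T.Callable, *args) -> None:
        super().__init__(daemon=True)
        self._target_func = target
        self._args = args
        self._queue: Queue = Queue()
        self.err: tuple | None = None        # holds sys.exc_info() if the target raised

    def run(self) -> None:
        try:
            self._queue.put(self._target_func(*self._args))
        except Exception:                    # pylint:disable=broad-except
            self.err = sys.exc_info()

    def get_result(self) -> T.Any:
        """ Wait for the target to finish, re-raising any captured error. """
        self.join()
        if self.err is not None:
            raise self.err[1].with_traceback(self.err[2])
        return self._queue.get()


task = BackgroundTask(sum, [1, 2, 3])
task.start()
print(task.get_result())                     # 6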
Parameters @@ -1433,8 +1433,8 @@ class ImagesSaver(ImageIO): def _save(self, filename: str, - image: Union[bytes, np.ndarray], - sub_folder: Optional[str]) -> None: + image: bytes | np.ndarray, + sub_folder: str | None) -> None: """ Save a single image inside a ThreadPoolExecutor Parameters @@ -1468,8 +1468,8 @@ class ImagesSaver(ImageIO): def save(self, filename: str, - image: Union[bytes, np.ndarray], - sub_folder: Optional[str] = None) -> None: + image: bytes | np.ndarray, + sub_folder: str | None = None) -> None: """ Save the given image in the background thread Ensure that :func:`close` is called once all save operations are complete. diff --git a/lib/keras_utils.py b/lib/keras_utils.py index 0f001962..9f278986 100644 --- a/lib/keras_utils.py +++ b/lib/keras_utils.py @@ -117,7 +117,7 @@ class ColorSpaceConvert(): # pylint:disable=too-few-public-methods self._xyz_multipliers = K.constant([116, 500, 200], dtype="float32") @classmethod - def _get_rgb_xyz_map(cls) -> T.Tuple[Tensor, Tensor]: + def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]: """ Obtain the mapping and inverse mapping for rgb to xyz color space conversion. Returns diff --git a/lib/logger.py b/lib/logger.py index c7b34f5d..623af04c 100644 --- a/lib/logger.py +++ b/lib/logger.py @@ -11,7 +11,17 @@ import time import traceback from datetime import datetime -from typing import Union + + +# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib +def _patched_format(self, record): + """ Autograph tf-2.10 has a bug with the 3.10 version of logging.PercentStyle._format(). It is + non-critical but spits out warnings. This is the Python 3.9 version of the function and should + be removed once fixed """ + return self._fmt % record.__dict__ # pylint:disable=protected-access + + +setattr(logging.PercentStyle, "_format", _patched_format) class FaceswapLogger(logging.Logger): @@ -76,11 +86,11 @@ class ColoredFormatter(logging.Formatter): def __init__(self, fmt: str, pad_newlines: bool = False, **kwargs) -> None: super().__init__(fmt, **kwargs) self._use_color = self._get_color_compatibility() - self._level_colors = dict(CRITICAL="\033[31m", # red - ERROR="\033[31m", # red - WARNING="\033[33m", # yellow - INFO="\033[32m", # green - VERBOSE="\033[34m") # blue + self._level_colors = {"CRITICAL": "\033[31m", # red + "ERROR": "\033[31m", # red + "WARNING": "\033[33m", # yellow + "INFO": "\033[32m", # green + "VERBOSE": "\033[34m"} # blue self._default_color = "\033[0m" self._newline_padding = self._get_newline_padding(pad_newlines, fmt) @@ -412,7 +422,7 @@ def _file_handler(loglevel, return handler -def _stream_handler(loglevel: int, is_gui: bool) -> Union[logging.StreamHandler, TqdmHandler]: +def _stream_handler(loglevel: int, is_gui: bool) -> logging.StreamHandler | TqdmHandler: """ Add a stream handler for the current Faceswap session. The stream handler will only ever output at a maximum of VERBOSE level to avoid spamming the console. diff --git a/lib/model/autoclip.py b/lib/model/autoclip.py index e1750ae2..a9ccfe88 100644 --- a/lib/model/autoclip.py +++ b/lib/model/autoclip.py @@ -1,8 +1,6 @@ """ Auto clipper for clipping gradients. 
""" -from typing import List - +import numpy as np import tensorflow as tf -import tensorflow_probability as tfp class AutoClipper(): # pylint:disable=too-few-public-methods @@ -22,12 +20,56 @@ class AutoClipper(): # pylint:disable=too-few-public-methods original paper: https://arxiv.org/abs/2007.14469 """ def __init__(self, clip_percentile: int, history_size: int = 10000): - self._clip_percentile = clip_percentile + self._clip_percentile = tf.cast(clip_percentile, tf.float64) self._grad_history = tf.Variable(tf.zeros(history_size), trainable=False) self._index = tf.Variable(0, trainable=False) self._history_size = history_size - def __call__(self, grads_and_vars: List[tf.Tensor]) -> List[tf.Tensor]: + def _percentile(self, grad_history: tf.Tensor) -> tf.Tensor: + """ Compute the clip percentile of the gradient history + + Parameters + ---------- + grad_history: :class:`tensorflow.Tensor` + Tge gradient history to calculate the clip percentile for + + Returns + ------- + :class:`tensorflow.Tensor` + A rank(:attr:`clip_percentile`) `Tensor` + + Notes + ----- + Adapted from + https://github.com/tensorflow/probability/blob/r0.14/tensorflow_probability/python/stats/quantiles.py + to remove reliance on full tensorflow_probability libraray + """ + with tf.name_scope("percentile"): + frac_at_q_or_below = self._clip_percentile / 100. + sorted_hist = tf.sort(grad_history, axis=-1, direction="ASCENDING") + + num = tf.cast(tf.shape(grad_history)[-1], tf.float64) + + # get indices + indices = tf.round((num - 1) * frac_at_q_or_below) + indices = tf.clip_by_value(tf.cast(indices, tf.int32), + 0, + tf.shape(grad_history)[-1] - 1) + gathered_hist = tf.gather(sorted_hist, indices, axis=-1) + + # Propagate NaNs. Apparently tf.is_nan doesn't like other dtypes + nan_batch_members = tf.reduce_any(tf.math.is_nan(grad_history), axis=None) + right_rank_matched_shape = tf.pad(tf.shape(nan_batch_members), + paddings=[[0, tf.rank(self._clip_percentile)]], + constant_values=1) + nan_batch_members = tf.reshape(nan_batch_members, shape=right_rank_matched_shape) + + nan = np.array(np.nan, gathered_hist.dtype.as_numpy_dtype) + gathered_hist = tf.where(nan_batch_members, nan, gathered_hist) + + return gathered_hist + + def __call__(self, grads_and_vars: list[tf.Tensor]) -> list[tf.Tensor]: """ Call the AutoClip function. 
Parameters @@ -40,8 +82,7 @@ class AutoClipper(): # pylint:disable=too-few-public-methods assign_idx = tf.math.mod(self._index, self._history_size) self._grad_history = self._grad_history[assign_idx].assign(total_norm) self._index = self._index.assign_add(1) - clip_value = tfp.stats.percentile(self._grad_history[: self._index], - q=self._clip_percentile) + clip_value = self._percentile(self._grad_history[: self._index]) return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars] @classmethod diff --git a/lib/model/losses/feature_loss.py b/lib/model/losses/feature_loss.py index 80d206b7..9898ae1b 100644 --- a/lib/model/losses/feature_loss.py +++ b/lib/model/losses/feature_loss.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ Custom Feature Map Loss Functions for faceswap.py """ +from __future__ import annotations from dataclasses import dataclass, field import logging - -from typing import Any, Callable, Dict, Optional, List, Tuple +import typing as T # Ignore linting errors from Tensorflow's thoroughly broken import system import tensorflow as tf @@ -17,6 +17,9 @@ import numpy as np from lib.model.nets import AlexNet, SqueezeNet from lib.utils import GetModel +if T.TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) @@ -39,10 +42,10 @@ class NetInfo: """ model_id: int = 0 model_name: str = "" - net: Optional[Callable] = None - init_kwargs: Dict[str, Any] = field(default_factory=dict) + net: Callable | None = None + init_kwargs: dict[str, T.Any] = field(default_factory=dict) needs_init: bool = True - outputs: List[Layer] = field(default_factory=list) + outputs: list[Layer] = field(default_factory=list) class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods @@ -67,7 +70,7 @@ class _LPIPSTrunkNet(): # pylint:disable=too-few-public-methods logger.debug("Initialized: %s ", self.__class__.__name__) @property - def _nets(self) -> Dict[str, NetInfo]: + def _nets(self) -> dict[str, NetInfo]: """ :class:`NetInfo`: The Information about the requested net.""" return { "alex": NetInfo(model_id=15, @@ -176,7 +179,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods logger.debug("Initialized: %s", self.__class__.__name__) @property - def _nets(self) -> Dict[str, NetInfo]: + def _nets(self) -> dict[str, NetInfo]: """ :class:`NetInfo`: The Information about the requested net.""" return { "alex": NetInfo(model_id=18, @@ -186,7 +189,7 @@ class _LPIPSLinearNet(_LPIPSTrunkNet): # pylint:disable=too-few-public-methods "vgg16": NetInfo(model_id=20, model_name="vgg16_lpips_v1.h5")} - def _linear_block(self, net_output_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + def _linear_block(self, net_output_layer: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]: """ Build a linear block for a trunk network output. Parameters @@ -319,7 +322,7 @@ class LPIPSLoss(): # pylint:disable=too-few-public-methods tf.keras.mixed_precision.set_global_policy("mixed_float16") logger.debug("Initialized: %s", self.__class__.__name__) - def _process_diffs(self, inputs: List[tf.Tensor]) -> List[tf.Tensor]: + def _process_diffs(self, inputs: list[tf.Tensor]) -> list[tf.Tensor]: """ Perform processing on the Trunk Network outputs. 
If :attr:`use_ldip` is enabled, process the diff values through the linear network, diff --git a/lib/model/losses/loss.py b/lib/model/losses/loss.py index 7e182814..ab03ff53 100644 --- a/lib/model/losses/loss.py +++ b/lib/model/losses/loss.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """ Custom Loss Functions for faceswap.py """ -from __future__ import absolute_import - +from __future__ import annotations import logging -from typing import Callable, List, Tuple +import typing as T import numpy as np import tensorflow as tf @@ -13,6 +12,9 @@ import tensorflow as tf from tensorflow.python.keras.engine import compile_utils # pylint:disable=no-name-in-module from tensorflow.keras import backend as K # pylint:disable=import-error +if T.TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) @@ -61,7 +63,7 @@ class FocalFrequencyLoss(): # pylint:disable=too-few-public-methods self._ave_spectrum = ave_spectrum self._log_matrix = log_matrix self._batch_matrix = batch_matrix - self._dims: Tuple[int, int] = (0, 0) + self._dims: tuple[int, int] = (0, 0) def _get_patches(self, inputs: tf.Tensor) -> tf.Tensor: """ Crop the incoming batch of images into patches as defined by :attr:`_patch_factor. @@ -470,7 +472,7 @@ class LaplacianPyramidLoss(): # pylint:disable=too-few-public-methods retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid") return retval - def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> List[tf.Tensor]: + def _get_laplacian_pyramid(self, inputs: tf.Tensor) -> list[tf.Tensor]: """ Obtain the Laplacian Pyramid. Parameters @@ -564,9 +566,9 @@ class LossWrapper(tf.keras.losses.Loss): def __init__(self) -> None: logger.debug("Initializing: %s", self.__class__.__name__) super().__init__(name="LossWrapper") - self._loss_functions: List[compile_utils.LossesContainer] = [] - self._loss_weights: List[float] = [] - self._mask_channels: List[int] = [] + self._loss_functions: list[compile_utils.LossesContainer] = [] + self._loss_weights: list[float] = [] + self._mask_channels: list[int] = [] logger.debug("Initialized: %s", self.__class__.__name__) def add_loss(self, @@ -628,7 +630,7 @@ class LossWrapper(tf.keras.losses.Loss): y_true: tf.Tensor, y_pred: tf.Tensor, mask_channel: int, - mask_prop: float = 1.0) -> Tuple[tf.Tensor, tf.Tensor]: + mask_prop: float = 1.0) -> tuple[tf.Tensor, tf.Tensor]: """ Apply the mask to the input y_true and y_pred. If a mask is not required then return the unmasked inputs. 
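The autoclip.py change earlier in this patch drops the tensorflow_probability dependency by re-implementing the percentile lookup directly in TensorFlow ops. The standalone sketch below is not part of the patch; the sample history values and the 10th-percentile setting are invented for illustration. It mirrors the same "round to the nearest index of the sorted history" scheme used by the new _percentile method, which makes it easy to sanity-check against NumPy:

import numpy as np
import tensorflow as tf

history = tf.constant([0.5, 2.0, 1.5, 3.0, 0.75], dtype=tf.float32)  # fake gradient norms
clip_percentile = tf.cast(10, tf.float64)                            # clip at the 10th percentile

# Sort the history, convert the percentile to a fractional position and round
# it to the nearest valid index, then gather that element as the clip value
frac_at_q_or_below = clip_percentile / 100.
sorted_hist = tf.sort(history, axis=-1, direction="ASCENDING")
num = tf.cast(tf.shape(history)[-1], tf.float64)
index = tf.clip_by_value(tf.cast(tf.round((num - 1) * frac_at_q_or_below), tf.int32),
                         0, tf.shape(history)[-1] - 1)
clip_value = tf.gather(sorted_hist, index, axis=-1)

print(float(clip_value))                                     # 0.5
print(np.percentile(history.numpy(), 10, method="nearest"))  # 0.5 ("method" needs numpy>=1.22)

Because the index is rounded rather than interpolated, the result can differ slightly from NumPy's default linear interpolation; for a gradient-clipping threshold that difference is immaterial.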
diff --git a/lib/model/losses/perceptual_loss.py b/lib/model/losses/perceptual_loss.py index ccb2cfa4..0fc09b81 100644 --- a/lib/model/losses/perceptual_loss.py +++ b/lib/model/losses/perceptual_loss.py @@ -2,9 +2,7 @@ """ TF Keras implementation of Perceptual Loss Functions for faceswap.py """ import logging -import sys - -from typing import Dict, Optional, Tuple +import typing as T import numpy as np import tensorflow as tf @@ -14,11 +12,6 @@ from tensorflow.keras import backend as K # pylint:disable=import-error from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - logger = logging.getLogger(__name__) @@ -101,7 +94,7 @@ class DSSIMObjective(): # pylint:disable=too-few-public-methods """ return K.depthwise_conv2d(image, kernel, strides=(1, 1), padding="valid") - def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + def _get_ssim(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]: """ Obtain the structural similarity between a batch of true and predicted images. Parameters @@ -330,8 +323,8 @@ class LDRFLIPLoss(): # pylint:disable=too-few-public-methods lower_threshold_exponent: float = 0.4, upper_threshold_exponent: float = 0.95, epsilon: float = 1e-15, - pixels_per_degree: Optional[float] = None, - color_order: Literal["bgr", "rgb"] = "bgr") -> None: + pixels_per_degree: float | None = None, + color_order: T.Literal["bgr", "rgb"] = "bgr") -> None: logger.debug("Initializing: %s (computed_distance_exponent '%s', feature_exponent: %s, " "lower_threshold_exponent: %s, upper_threshold_exponent: %s, epsilon: %s, " "pixels_per_degree: %s, color_order: %s)", self.__class__.__name__, @@ -525,7 +518,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods self._spatial_filters, self._radius = self._generate_spatial_filters() self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb") - def _generate_spatial_filters(self) -> Tuple[tf.Tensor, int]: + def _generate_spatial_filters(self) -> tuple[tf.Tensor, int]: """ Generates spatial contrast sensitivity filters with width depending on the number of pixels per degree of visual angle of the observer for channels "A", "RG" and "BY" @@ -559,7 +552,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods b1_rg: float, b2_rg: float, b1_by: float, - b2_by: float) -> Tuple[np.ndarray, int]: + b2_by: float) -> tuple[np.ndarray, int]: """ TODO docstring """ max_scale_parameter = max([b1_a, b2_a, b1_rg, b2_rg, b1_by, b2_by]) delta_x = 1.0 / self._pixels_per_degree @@ -570,7 +563,7 @@ class _SpatialFilters(): # pylint:disable=too-few-public-methods return domain, radius @classmethod - def _generate_weights(cls, channel: Dict[str, float], domain: np.ndarray) -> tf.Tensor: + def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> tf.Tensor: """ TODO docstring """ a_1, b_1, a_2, b_2 = channel["a1"], channel["b1"], channel["a2"], channel["b2"] grad = (a_1 * np.sqrt(np.pi / b_1) * np.exp(-np.pi ** 2 * domain / b_1) + @@ -694,7 +687,7 @@ class MSSIMLoss(): # pylint:disable=too-few-public-methods filter_size: int = 11, filter_sigma: float = 1.5, max_value: float = 1.0, - power_factors: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) + power_factors: tuple[float, ...] 
= (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) ) -> None: self.filter_size = filter_size self.filter_sigma = filter_sigma diff --git a/lib/model/nets.py b/lib/model/nets.py index 87e740b7..4fa8294d 100644 --- a/lib/model/nets.py +++ b/lib/model/nets.py @@ -31,7 +31,7 @@ class _net(): # pylint:disable=too-few-public-methods The input shape for the model. Default: ``None`` """ def __init__(self, - input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None: + input_shape: tuple[int, int, int] | None = None) -> None: logger.debug("Initializing: %s (input_shape: %s)", self.__class__.__name__, input_shape) self._input_shape = (None, None, 3) if input_shape is None else input_shape assert len(self._input_shape) == 3 and self._input_shape[-1] == 3, ( @@ -56,7 +56,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods input_shape, Tuple, optional The input shape for the model. Default: ``None`` """ - def __init__(self, input_shape: T.Optional[T.Tuple[int, int, int]] = None) -> None: + def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None: super().__init__(input_shape) self._feature_indices = [0, 3, 6, 8, 10] # For naming equivalent to PyTorch self._filters = [64, 192, 384, 256, 256] # Filters at each block @@ -108,7 +108,7 @@ class AlexNet(_net): # pylint:disable=too-few-public-methods name=name)(var_x) return var_x - def __call__(self) -> Model: + def __call__(self) -> tf.keras.models.Model: """ Create the AlexNet Model Returns @@ -189,7 +189,7 @@ class SqueezeNet(_net): # pylint:disable=too-few-public-methods name=f"{name}.expand3x3")(squeezed) return layers.Concatenate(axis=-1, name=name)([expand1, expand3]) - def __call__(self) -> Model: + def __call__(self) -> tf.keras.models.Model: """ Create the SqueezeNet Model Returns diff --git a/lib/model/nn_blocks.py b/lib/model/nn_blocks.py index 96cc0e8f..9ffa1702 100644 --- a/lib/model/nn_blocks.py +++ b/lib/model/nn_blocks.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name _CONFIG: dict = {} -_NAMES: T.Dict[str, int] = {} +_NAMES: dict[str, int] = {} def set_config(configuration: dict) -> None: @@ -189,7 +189,7 @@ class Conv2DOutput(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int]], + kernel_size: int | tuple[int], activation: str = "sigmoid", padding: str = "same", **kwargs) -> None: self._name = kwargs.pop("name") if "name" in kwargs else _get_name( @@ -265,11 +265,11 @@ class Conv2DBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 5, - strides: T.Union[int, T.Tuple[int, int]] = 2, + kernel_size: int | tuple[int, int] = 5, + strides: int | tuple[int, int] = 2, padding: str = "same", - normalization: T.Optional[str] = None, - activation: T.Optional[str] = "leakyrelu", + normalization: str | None = None, + activation: str | None = "leakyrelu", use_depthwise: bool = False, relu_alpha: float = 0.1, **kwargs) -> None: @@ -362,8 +362,8 @@ class SeparableConv2DBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 5, - strides: T.Union[int, T.Tuple[int, int]] = 2, **kwargs) -> None: + kernel_size: int | tuple[int, int] = 5, + strides: int | tuple[int, int] = 2, **kwargs) -> None: self._name = _get_name(f"separableconv2d_{filters}") logger.debug("name: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s)", self._name, filters, kernel_size, 
strides, kwargs) @@ -434,11 +434,11 @@ class UpscaleBlock(): # pylint:disable=too-few-public-methods def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 3, + kernel_size: int | tuple[int, int] = 3, padding: str = "same", scale_factor: int = 2, - normalization: T.Optional[str] = None, - activation: T.Optional[str] = "leakyrelu", + normalization: str | None = None, + activation: str | None = "leakyrelu", **kwargs) -> None: self._name = _get_name(f"upscale_{filters}") logger.debug("name: %s. filters: %s, kernel_size: %s, padding: %s, scale_factor: %s, " @@ -521,9 +521,9 @@ class Upscale2xBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 3, + kernel_size: int | tuple[int, int] = 3, padding: str = "same", - activation: T.Optional[str] = "leakyrelu", + activation: str | None = "leakyrelu", interpolation: str = "bilinear", sr_ratio: float = 0.5, scale_factor: int = 2, @@ -615,9 +615,9 @@ class UpscaleResizeImagesBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 3, + kernel_size: int | tuple[int, int] = 3, padding: str = "same", - activation: T.Optional[str] = "leakyrelu", + activation: str | None = "leakyrelu", scale_factor: int = 2, interpolation: str = "bilinear") -> None: self._name = _get_name(f"upscale_ri_{filters}") @@ -700,9 +700,9 @@ class UpscaleDNYBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 3, + kernel_size: int | tuple[int, int] = 3, padding: str = "same", - activation: T.Optional[str] = "leakyrelu", + activation: str | None = "leakyrelu", size: int = 2, interpolation: str = "bilinear", **kwargs) -> None: @@ -757,7 +757,7 @@ class ResidualBlock(): # pylint:disable=too-few-public-methods """ def __init__(self, filters: int, - kernel_size: T.Union[int, T.Tuple[int, int]] = 3, + kernel_size: int | tuple[int, int] = 3, padding: str = "same", **kwargs) -> None: self._name = _get_name(f"residual_{filters}") diff --git a/lib/model/session.py b/lib/model/session.py index a7450d74..9cc1de13 100644 --- a/lib/model/session.py +++ b/lib/model/session.py @@ -1,9 +1,9 @@ #!/usr/bin python3 """ Settings manager for Keras Backend """ - +from __future__ import annotations from contextlib import nullcontext import logging -from typing import Callable, ContextManager, List, Optional, Union +import typing as T import numpy as np import tensorflow as tf @@ -14,6 +14,9 @@ from tensorflow.keras.models import load_model as k_load_model, Model # noqa:E5 from lib.utils import get_backend +if T.TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) # pylint:disable=invalid-name @@ -52,9 +55,9 @@ class KSession(): def __init__(self, name: str, model_path: str, - model_kwargs: Optional[dict] = None, + model_kwargs: dict | None = None, allow_growth: bool = False, - exclude_gpus: Optional[List[int]] = None, + exclude_gpus: list[int] | None = None, cpu_mode: bool = False) -> None: logger.trace("Initializing: %s (name: %s, model_path: %s, " # type:ignore "model_kwargs: %s, allow_growth: %s, exclude_gpus: %s, cpu_mode: %s)", @@ -67,12 +70,12 @@ class KSession(): cpu_mode) self._model_path = model_path self._model_kwargs = {} if not model_kwargs else model_kwargs - self._model: Optional[Model] = None + self._model: Model | None = None logger.trace("Initialized: %s", self.__class__.__name__,) # 
type:ignore def predict(self, - feed: Union[List[np.ndarray], np.ndarray], - batch_size: Optional[int] = None) -> Union[List[np.ndarray], np.ndarray]: + feed: list[np.ndarray] | np.ndarray, + batch_size: int | None = None) -> list[np.ndarray] | np.ndarray: """ Get predictions from the model. This method is a wrapper for :func:`keras.predict()` function. For Tensorflow backends @@ -98,7 +101,7 @@ class KSession(): def _set_session(self, allow_growth: bool, exclude_gpus: list, - cpu_mode: bool) -> ContextManager: + cpu_mode: bool) -> T.ContextManager: """ Sets the backend session options. For CPU backends, this hides any GPUs from Tensorflow. diff --git a/lib/multithreading.py b/lib/multithreading.py index e85685a8..88599aae 100644 --- a/lib/multithreading.py +++ b/lib/multithreading.py @@ -1,19 +1,22 @@ #!/usr/bin/env python3 """ Multithreading/processing utils for faceswap """ - +from __future__ import annotations import logging +import typing as T from multiprocessing import cpu_count import queue as Queue import sys import threading from types import TracebackType -from typing import Any, Callable, Dict, Generator, List, Tuple, Type, Optional, Set, Union + +if T.TYPE_CHECKING: + from collections.abc import Callable, Generator logger = logging.getLogger(__name__) # pylint: disable=invalid-name -_ErrorType = Optional[Union[Tuple[Type[BaseException], BaseException, TracebackType], - Tuple[Any, Any, Any]]] -_THREAD_NAMES: Set[str] = set() +_ErrorType: T.TypeAlias = (tuple[type[BaseException], BaseException, TracebackType] | + tuple[T.Any, T.Any, T.Any] | None) +_THREAD_NAMES: set[str] = set() def total_cpus(): @@ -62,17 +65,17 @@ class FSThread(threading.Thread): keyword arguments for the target invocation. Default: {}. """ _target: Callable - _args: Tuple - _kwargs: Dict[str, Any] + _args: tuple + _kwargs: dict[str, T.Any] _name: str def __init__(self, - target: Optional[Callable] = None, - name: Optional[str] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None, + target: Callable | None = None, + name: str | None = None, + args: tuple = (), + kwargs: dict[str, T.Any] | None = None, *, - daemon: Optional[bool] = None) -> None: + daemon: bool | None = None) -> None: super().__init__(target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) self.err: _ErrorType = None @@ -124,7 +127,7 @@ class MultiThread(): target: Callable, *args, thread_count: int = 1, - name: Optional[str] = None, + name: str | None = None, **kwargs) -> None: self._name = _get_name(name if name else target.__name__) logger.debug("Initializing %s: (target: '%s', thread_count: %s)", @@ -132,7 +135,7 @@ class MultiThread(): logger.trace("args: %s, kwargs: %s", args, kwargs) # type:ignore self.daemon = True self._thread_count = thread_count - self._threads: List[FSThread] = [] + self._threads: list[FSThread] = [] self._target = target self._args = args self._kwargs = kwargs @@ -144,7 +147,7 @@ class MultiThread(): return any(thread.err for thread in self._threads) @property - def errors(self) -> List[_ErrorType]: + def errors(self) -> list[_ErrorType]: """ list: List of thread error values """ return [thread.err for thread in self._threads if thread.err] @@ -253,9 +256,9 @@ class BackgroundGenerator(MultiThread): def __init__(self, generator: Callable, prefetch: int = 1, - name: Optional[str] = None, - args: Optional[Tuple] = None, - kwargs: Optional[Dict[str, Any]] = None) -> None: + name: str | None = None, + args: tuple | None = None, + kwargs: dict[str, T.Any] | None = None) -> None: 
super().__init__(name=name, target=self._run) self.queue: Queue.Queue = Queue.Queue(prefetch) self.generator = generator diff --git a/lib/queue_manager.py b/lib/queue_manager.py index d34dd5fa..7eeacc14 100644 --- a/lib/queue_manager.py +++ b/lib/queue_manager.py @@ -6,7 +6,6 @@ import logging import threading -from typing import Dict from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa from time import sleep @@ -45,7 +44,7 @@ class _QueueManager(): logger.debug("Initializing %s", self.__class__.__name__) self.shutdown = threading.Event() - self.queues: Dict[str, EventQueue] = {} + self.queues: dict[str, EventQueue] = {} logger.debug("Initialized %s", self.__class__.__name__) def add_queue(self, name: str, maxsize: int = 0, create_new: bool = False) -> str: diff --git a/lib/sysinfo.py b/lib/sysinfo.py index 4dbfde26..6d40d178 100644 --- a/lib/sysinfo.py +++ b/lib/sysinfo.py @@ -6,8 +6,8 @@ import locale import os import platform import sys + from subprocess import PIPE, Popen -from typing import List, Optional import psutil @@ -21,14 +21,14 @@ class _SysInfo(): # pylint:disable=too-few-public-methods def __init__(self) -> None: self._state_file = _State().state_file self._configs = _Configs().configs - self._system = dict(platform=platform.platform(), - system=platform.system().lower(), - machine=platform.machine(), - release=platform.release(), - processor=platform.processor(), - cpu_count=os.cpu_count()) - self._python = dict(implementation=platform.python_implementation(), - version=platform.python_version()) + self._system = {"platform": platform.platform(), + "system": platform.system().lower(), + "machine": platform.machine(), + "release": platform.release(), + "processor": platform.processor(), + "cpu_count": os.cpu_count()} + self._python = {"implementation": platform.python_implementation(), + "version": platform.python_version()} self._gpu = self._get_gpu_info() self._cuda_check = CudaCheck() @@ -66,7 +66,7 @@ class _SysInfo(): # pylint:disable=too-few-public-methods (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)) else: prefix = os.path.dirname(sys.prefix) - retval = (os.path.basename(prefix) == "envs") + retval = os.path.basename(prefix) == "envs" return retval @property @@ -295,7 +295,7 @@ class _Configs(): # pylint:disable=too-few-public-methods except FileNotFoundError: return "" - def _parse_configs(self, config_files: List[str]) -> str: + def _parse_configs(self, config_files: list[str]) -> str: """ Parse the given list of config files into a human readable format. Parameters @@ -399,7 +399,7 @@ class _State(): # pylint:disable=too-few-public-methods return len(sys.argv) > 1 and sys.argv[1].lower() == "train" @staticmethod - def _get_arg(*args: str) -> Optional[str]: + def _get_arg(*args: str) -> str | None: """ Obtain the value for a given command line option from sys.argv. Returns diff --git a/lib/training/__init__.py b/lib/training/__init__.py index 29905793..e35d3e19 100644 --- a/lib/training/__init__.py +++ b/lib/training/__init__.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 """ Package for handling alignments files, detected faces and aligned faces along with their associated objects. 
""" - -from typing import Type, TYPE_CHECKING +from __future__ import annotations +import typing as T from .augmentation import ImageAugmentation from .generator import PreviewDataGenerator, TrainingDataGenerator from .preview_cv import PreviewBuffer, TriggerType -if TYPE_CHECKING: +if T.TYPE_CHECKING: from .preview_cv import PreviewBase - Preview: Type[PreviewBase] + Preview: type[PreviewBase] try: from .preview_tk import PreviewTk as Preview diff --git a/lib/training/augmentation.py b/lib/training/augmentation.py index 6fff6e08..c559a621 100644 --- a/lib/training/augmentation.py +++ b/lib/training/augmentation.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 """ Processes the augmentation of images for feeding into a Faceswap model. """ +from __future__ import annotations from dataclasses import dataclass import logging -from typing import Dict, Tuple, TYPE_CHECKING +import typing as T import cv2 import numexpr as ne @@ -11,7 +12,7 @@ from scipy.interpolate import griddata from lib.image import batch_convert_color -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.config import ConfigValueType logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -56,7 +57,7 @@ class AugConstants: transform_zoom: float transform_shift: float warp_maps: np.ndarray - warp_pad: Tuple[int, int] + warp_pad: tuple[int, int] warp_slices: slice warp_lm_edge_anchors: np.ndarray warp_lm_grids: np.ndarray @@ -79,7 +80,7 @@ class ImageAugmentation(): def __init__(self, batchsize: int, processing_size: int, - config: Dict[str, "ConfigValueType"]) -> None: + config: dict[str, ConfigValueType]) -> None: logger.debug("Initializing %s: (batchsize: %s, processing_size: %s, " "config: %s)", self.__class__.__name__, batchsize, processing_size, config) @@ -332,7 +333,7 @@ class ImageAugmentation(): slices = self._constants.warp_slices rands = np.random.normal(size=(self._batchsize, 2, 5, 5), scale=self._warp_scale).astype("float32") - batch_maps = ne.evaluate("m + r", local_dict=dict(m=self._constants.warp_maps, r=rands)) + batch_maps = ne.evaluate("m + r", local_dict={"m": self._constants.warp_maps, "r": rands}) batch_interp = np.array([[cv2.resize(map_, self._constants.warp_pad)[slices, slices] for map_ in maps] for maps in batch_maps]) diff --git a/lib/training/cache.py b/lib/training/cache.py index 8bbd5431..3b391a02 100644 --- a/lib/training/cache.py +++ b/lib/training/cache.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 """ Holds the data cache for training data generators """ +from __future__ import annotations import logging import os -import sys +import typing as T from threading import Lock -from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING import cv2 import numpy as np @@ -16,25 +16,19 @@ from lib.align.aligned_face import CenteringType from lib.image import read_image_batch, read_image_meta_batch from lib.utils import FaceswapError -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict from lib.config import ConfigValueType logger = logging.getLogger(__name__) - -_FACE_CACHES: Dict[str, "_Cache"] = {} +_FACE_CACHES: dict[str, "_Cache"] = {} -def get_cache(side: Literal["a", "b"], - filenames: Optional[List[str]] = None, - config: Optional[Dict[str, "ConfigValueType"]] = None, - size: Optional[int] = None, - coverage_ratio: Optional[float] = None) -> "_Cache": +def get_cache(side: T.Literal["a", "b"], + 
filenames: list[str] | None = None, + config: dict[str, ConfigValueType] | None = None, + size: int | None = None, + coverage_ratio: float | None = None) -> "_Cache": """ Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then create it. @@ -120,24 +114,24 @@ class _Cache(): The coverage ratio that the model is using. """ def __init__(self, - filenames: List[str], - config: Dict[str, "ConfigValueType"], + filenames: list[str], + config: dict[str, ConfigValueType], size: int, coverage_ratio: float) -> None: logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)", self.__class__.__name__, len(filenames), size, coverage_ratio) self._lock = Lock() - self._cache_info = dict(cache_full=False, has_reset=False) - self._partially_loaded: List[str] = [] + self._cache_info = {"cache_full": False, "has_reset": False} + self._partially_loaded: list[str] = [] self._image_count = len(filenames) - self._cache: Dict[str, DetectedFace] = {} - self._aligned_landmarks: Dict[str, np.ndarray] = {} + self._cache: dict[str, DetectedFace] = {} + self._aligned_landmarks: dict[str, np.ndarray] = {} self._extract_version = 0.0 self._size = size - assert config["centering"] in get_args(CenteringType) - self._centering: CenteringType = cast(CenteringType, config["centering"]) + assert config["centering"] in T.get_args(CenteringType) + self._centering: CenteringType = T.cast(CenteringType, config["centering"]) self._config = config self._coverage_ratio = coverage_ratio @@ -153,7 +147,7 @@ class _Cache(): return self._cache_info["cache_full"] @property - def aligned_landmarks(self) -> Dict[str, np.ndarray]: + def aligned_landmarks(self) -> dict[str, np.ndarray]: """ dict: The filename as key, aligned landmarks as value. """ # Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate # all of the aligned landmarks for the entire cache. @@ -185,7 +179,7 @@ class _Cache(): self._cache_info["has_reset"] = False return retval - def get_items(self, filenames: List[str]) -> List[DetectedFace]: + def get_items(self, filenames: list[str]) -> list[DetectedFace]: """ Obtain the cached items for a list of filenames. The returned list is in the same order as the provided filenames. @@ -202,7 +196,7 @@ class _Cache(): """ return [self._cache[os.path.basename(filename)] for filename in filenames] - def cache_metadata(self, filenames: List[str]) -> np.ndarray: + def cache_metadata(self, filenames: list[str]) -> np.ndarray: """ Obtain the batch with metadata for items that need caching and cache DetectedFace objects to :attr:`_cache`. @@ -267,7 +261,7 @@ class _Cache(): return batch - def pre_fill(self, filenames: List[str], side: Literal["a", "b"]) -> None: + def pre_fill(self, filenames: list[str], side: T.Literal["a", "b"]) -> None: """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs access to the other side's alignments. @@ -294,7 +288,7 @@ class _Cache(): self._cache[key] = detected_face self._partially_loaded.append(key) - def _validate_version(self, png_meta: "PNGHeaderDict", filename: str) -> None: + def _validate_version(self, png_meta: PNGHeaderDict, filename: str) -> None: """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces. 
Parameters @@ -350,7 +344,7 @@ class _Cache(): def _load_detected_face(self, filename: str, - alignments: "PNGHeaderAlignmentsDict") -> DetectedFace: + alignments: PNGHeaderAlignmentsDict) -> DetectedFace: """ Load a :class:`DetectedFace` object and load its associated `aligned` property. Parameters @@ -387,13 +381,13 @@ class _Cache(): The detected face object that holds the masks """ masks = [(self._get_face_mask(filename, detected_face))] - for area in get_args(Literal["eye", "mouth"]): + for area in T.get_args(T.Literal["eye", "mouth"]): masks.append(self._get_localized_mask(filename, detected_face, area)) detected_face.store_training_masks(masks, delete_masks=True) logger.trace("Stored masks for filename: %s)", filename) # type: ignore - def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> Optional[np.ndarray]: + def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> np.ndarray | None: """ Obtain the training sized face mask from the :class:`DetectedFace` for the requested mask type. @@ -448,7 +442,7 @@ class _Cache(): def _get_localized_mask(self, filename: str, detected_face: DetectedFace, - area: Literal["eye", "mouth"]) -> Optional[np.ndarray]: + area: T.Literal["eye", "mouth"]) -> np.ndarray | None: """ Obtain a localized mask for the given area if it is required for training. Parameters @@ -486,7 +480,7 @@ class RingBuffer(): # pylint: disable=too-few-public-methods """ def __init__(self, batch_size: int, - image_shape: Tuple[int, int, int], + image_shape: tuple[int, int, int], buffer_size: int = 2, dtype: str = "uint8") -> None: logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, " diff --git a/lib/training/generator.py b/lib/training/generator.py index cdc40134..8507a1cd 100644 --- a/lib/training/generator.py +++ b/lib/training/generator.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 """ Handles Data Augmentation for feeding Faceswap Models """ - +from __future__ import annotations import logging import os -import sys -from concurrent import futures +import typing as T +from concurrent import futures from random import shuffle, choice -from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING import cv2 import numpy as np @@ -21,18 +20,14 @@ from lib.utils import FaceswapError from . import ImageAugmentation from .cache import get_cache, RingBuffer -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from lib.config import ConfigValueType from plugins.train.model._base import ModelBase from .cache import _Cache logger = logging.getLogger(__name__) -BatchType = Tuple[np.ndarray, List[np.ndarray]] +BatchType = tuple[np.ndarray, list[np.ndarray]] class DataGenerator(): @@ -57,10 +52,10 @@ class DataGenerator(): objects of this size from the iterator. 
""" def __init__(self, - config: Dict[str, "ConfigValueType"], - model: "ModelBase", - side: Literal["a", "b"], - images: List[str], + config: dict[str, ConfigValueType], + model: ModelBase, + side: T.Literal["a", "b"], + images: list[str], batch_size: int) -> None: logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore "batch_size: %s, config: %s)", self.__class__.__name__, model.name, side, @@ -83,11 +78,11 @@ class DataGenerator(): self._buffer = RingBuffer(batch_size, (self._process_size, self._process_size, self._total_channels), dtype="uint8") - self._face_cache: "_Cache" = get_cache(side, - filenames=images, - config=self._config, - size=self._process_size, - coverage_ratio=self._coverage_ratio) + self._face_cache: _Cache = get_cache(side, + filenames=images, + config=self._config, + size=self._process_size, + coverage_ratio=self._coverage_ratio) logger.debug("Initialized %s", self.__class__.__name__) @property @@ -100,12 +95,12 @@ class DataGenerator(): channels += 1 mults = [area for area in ["eye", "mouth"] - if cast(int, self._config[f"{area}_multiplier"]) > 1] + if T.cast(int, self._config[f"{area}_multiplier"]) > 1] if self._config["penalized_mask_loss"] and mults: channels += len(mults) return channels - def _get_output_sizes(self, model: "ModelBase") -> List[int]: + def _get_output_sizes(self, model: ModelBase) -> list[int]: """ Obtain the size of each output tensor for the model. Parameters @@ -222,7 +217,7 @@ class DataGenerator(): retval = self._process_batch(img_paths) yield retval - def _get_images_with_meta(self, filenames: List[str]) -> Tuple[np.ndarray, List[DetectedFace]]: + def _get_images_with_meta(self, filenames: list[str]) -> tuple[np.ndarray, list[DetectedFace]]: """ Obtain the raw face images with associated :class:`DetectedFace` objects for this batch. @@ -253,9 +248,9 @@ class DataGenerator(): return raw_faces, detected_faces def _crop_to_coverage(self, - filenames: List[str], + filenames: list[str], images: np.ndarray, - detected_faces: List[DetectedFace], + detected_faces: list[DetectedFace], batch: np.ndarray) -> None: """ Crops the training image out of the full extract image based on the centering and coveage used in the user's configuration settings. @@ -286,7 +281,7 @@ class DataGenerator(): for future in futures.as_completed(proc): batch[proc[future], ..., :3] = future.result() - def _apply_mask(self, detected_faces: List[DetectedFace], batch: np.ndarray) -> None: + def _apply_mask(self, detected_faces: list[DetectedFace], batch: np.ndarray) -> None: """ Applies the masks to the 4th channel of the batch. If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1 @@ -312,7 +307,7 @@ class DataGenerator(): logger.trace("side: %s, masks: %s, batch: %s", # type: ignore self._side, masks.shape, batch.shape) - def _process_batch(self, filenames: List[str]) -> BatchType: + def _process_batch(self, filenames: list[str]) -> BatchType: """ Prepares data for feeding through subclassed methods. If this is the first time a face has been loaded, then it's meta data is extracted from the @@ -345,9 +340,9 @@ class DataGenerator(): return feed, targets def process_batch(self, - filenames: List[str], + filenames: list[str], images: np.ndarray, - detected_faces: List[DetectedFace], + detected_faces: list[DetectedFace], batch: np.ndarray) -> BatchType: """ Override for processing the batch for the current generator. 
@@ -391,7 +386,7 @@ class DataGenerator(): The input uint8 array """ return ne.evaluate("x / c", - local_dict=dict(x=in_array, c=np.float32(255)), + local_dict={"x": in_array, "c": np.float32(255)}, casting="unsafe") @@ -417,10 +412,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met objects of this size from the iterator. """ def __init__(self, - config: Dict[str, "ConfigValueType"], - model: "ModelBase", - side: Literal["a", "b"], - images: List[str], + config: dict[str, ConfigValueType], + model: ModelBase, + side: T.Literal["a", "b"], + images: list[str], batch_size: int) -> None: super().__init__(config, model, side, images, batch_size) self._augment_color = not model.command_line_arguments.no_augment_color @@ -434,10 +429,10 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met self._processing = ImageAugmentation(batch_size, self._process_size, self._config) - self._nearest_landmarks: Dict[str, Tuple[str, ...]] = {} + self._nearest_landmarks: dict[str, tuple[str, ...]] = {} logger.debug("Initialized %s", self.__class__.__name__) - def _create_targets(self, batch: np.ndarray) -> List[np.ndarray]: + def _create_targets(self, batch: np.ndarray) -> list[np.ndarray]: """ Compile target images, with masks, for the model output sizes. Parameters @@ -467,9 +462,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met return retval def process_batch(self, - filenames: List[str], + filenames: list[str], images: np.ndarray, - detected_faces: List[DetectedFace], + detected_faces: list[DetectedFace], batch: np.ndarray) -> BatchType: """ Performs the augmentation and compiles target images and samples. @@ -525,7 +520,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met if self._warp_to_landmarks: landmarks = np.array([face.aligned.landmarks for face in detected_faces]) batch_dst_pts = self._get_closest_match(filenames, landmarks) - warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts) + warp_kwargs = {"batch_src_points": landmarks, "batch_dst_points": batch_dst_pts} else: warp_kwargs = {} @@ -545,7 +540,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met return feed, targets - def _get_closest_match(self, filenames: List[str], batch_src_points: np.ndarray) -> np.ndarray: + def _get_closest_match(self, filenames: list[str], batch_src_points: np.ndarray) -> np.ndarray: """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest matched 68 point landmarks from the opposite training set. 
@@ -563,7 +558,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met """ logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore "src_points: '%s')", filenames, batch_src_points) - lm_side: Literal["a", "b"] = "a" if self._side == "b" else "b" + lm_side: T.Literal["a", "b"] = "a" if self._side == "b" else "b" other_cache = get_cache(lm_side) landmarks = other_cache.aligned_landmarks @@ -584,9 +579,9 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met return batch_dst_points def _cache_closest_matches(self, - filenames: List[str], + filenames: list[str], batch_src_points: np.ndarray, - landmarks: Dict[str, np.ndarray]) -> List[Tuple[str, ...]]: + landmarks: dict[str, np.ndarray]) -> list[tuple[str, ...]]: """ Cache the nearest landmarks for this batch Parameters @@ -602,7 +597,7 @@ class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-met logger.trace("Caching closest matches") # type:ignore dst_landmarks = list(landmarks.items()) dst_points = np.array([lm[1] for lm in dst_landmarks]) - batch_closest_matches: List[Tuple[str, ...]] = [] + batch_closest_matches: list[tuple[str, ...]] = [] for filename, src_points in zip(filenames, batch_src_points): closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] @@ -637,7 +632,7 @@ class PreviewDataGenerator(DataGenerator): """ def _create_samples(self, images: np.ndarray, - detected_faces: List[DetectedFace]) -> List[np.ndarray]: + detected_faces: list[DetectedFace]) -> list[np.ndarray]: """ Compile the 'sample' images. These are the 100% coverage images which hold the model output in the preview window. @@ -658,24 +653,25 @@ class PreviewDataGenerator(DataGenerator): output_size = self._output_sizes[-1] full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2)) - assert self._config["centering"] in get_args(CenteringType) + assert self._config["centering"] in T.get_args(CenteringType) retval = np.empty((full_size, full_size, 3), dtype="float32") - retval = self._to_float32(np.array([AlignedFace(face.landmarks_xy, - image=images[idx], - centering=cast(CenteringType, - self._config["centering"]), - size=full_size, - dtype="uint8", - is_aligned=True).face - for idx, face in enumerate(detected_faces)])) + retval = self._to_float32(np.array([ + AlignedFace(face.landmarks_xy, + image=images[idx], + centering=T.cast(CenteringType, + self._config["centering"]), + size=full_size, + dtype="uint8", + is_aligned=True).face + for idx, face in enumerate(detected_faces)])) logger.trace("Processed samples: %s", retval.shape) # type: ignore return [retval] def process_batch(self, - filenames: List[str], + filenames: list[str], images: np.ndarray, - detected_faces: List[DetectedFace], + detected_faces: list[DetectedFace], batch: np.ndarray) -> BatchType: """ Creates the full size preview images and the sub-cropped images for feeding the model's predict function. 
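The generator changes above follow the typing conventions applied throughout this commit: postponed annotation evaluation via "from __future__ import annotations", a single "import typing as T" namespace, PEP 604 unions (str | None), builtin generics (list[str], dict[str, ...]), and collections.abc or other heavyweight imports kept behind T.TYPE_CHECKING. The following minimal sketch shows the pattern in isolation; the module contents and function names are invented for illustration and are not taken from faceswap:

from __future__ import annotations
import typing as T

if T.TYPE_CHECKING:
    # Only needed by the type checker; never imported at runtime
    from collections.abc import Callable
    import numpy as np

Side = T.Literal["a", "b"]


def run_batches(side: Side,
                filenames: list[str],
                loader: Callable[[list[str]], np.ndarray],
                batch_size: int = 16,
                limit: int | None = None) -> dict[str, np.ndarray]:
    """ Feed filenames to the loader in batches, keying each result by the side and the
    first filename of the batch. Purely illustrative. """
    stop = len(filenames) if limit is None else min(limit, len(filenames))
    results: dict[str, np.ndarray] = {}
    for start in range(0, stop, batch_size):
        batch = filenames[start: start + batch_size]
        results[f"{side}_{batch[0]}"] = loader(batch)
    return results

Because the future import turns annotations into strings, names imported under T.TYPE_CHECKING still satisfy mypy/pylint without pulling numpy (or Tensorflow elsewhere in this patch) into module import time, which is presumably why the commit moves such imports behind the guard.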
diff --git a/lib/training/preview_cv.py b/lib/training/preview_cv.py index 948edfff..c0a6458a 100644 --- a/lib/training/preview_cv.py +++ b/lib/training/preview_cv.py @@ -4,36 +4,30 @@ If Tkinter is installed, then this will be used to manage the preview image, otherwise we fallback to opencv's imshow """ +from __future__ import annotations import logging -import sys +import typing as T from threading import Event, Lock from time import sleep -from typing import Dict, Generator, List, Optional, Tuple, TYPE_CHECKING - import cv2 -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator import numpy as np logger = logging.getLogger(__name__) -TriggerType = Dict[Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event] -TriggerKeysType = Literal["m", "r", "s", "enter"] -TriggerNamesType = Literal["toggle_mask", "refresh", "save", "quit"] +TriggerType = dict[T.Literal["toggle_mask", "refresh", "save", "quit", "shutdown"], Event] +TriggerKeysType = T.Literal["m", "r", "s", "enter"] +TriggerNamesType = T.Literal["toggle_mask", "refresh", "save", "quit"] class PreviewBuffer(): """ A thread safe class for holding preview images """ def __init__(self) -> None: logger.debug("Initializing: %s", self.__class__.__name__) - self._images: Dict[str, "np.ndarray"] = {} + self._images: dict[str, np.ndarray] = {} self._lock = Lock() self._updated = Event() logger.debug("Initialized: %s", self.__class__.__name__) @@ -43,7 +37,7 @@ class PreviewBuffer(): """ bool: ``True`` when new images have been loaded into the preview buffer """ return self._updated.is_set() - def add_image(self, name: str, image: "np.ndarray") -> None: + def add_image(self, name: str, image: np.ndarray) -> None: """ Add an image to the preview buffer in a thread safe way """ logger.debug("Adding image: (name: '%s', shape: %s)", name, image.shape) with self._lock: @@ -51,7 +45,7 @@ class PreviewBuffer(): logger.debug("Added images: %s", list(self._images)) self._updated.set() - def get_images(self) -> Generator[Tuple[str, "np.ndarray"], None, None]: + def get_images(self) -> Generator[tuple[str, np.ndarray], None, None]: """ Get the latest images from the preview buffer. When iterator is exhausted clears the :attr:`updated` event. @@ -86,15 +80,15 @@ class PreviewBase(): # pylint:disable=too-few-public-methods """ def __init__(self, preview_buffer: PreviewBuffer, - triggers: Optional[TriggerType] = None) -> None: + triggers: TriggerType | None = None) -> None: logger.debug("Initializing %s parent (triggers: %s)", self.__class__.__name__, triggers) self._triggers = triggers self._buffer = preview_buffer - self._keymaps: Dict[TriggerKeysType, TriggerNamesType] = dict(m="toggle_mask", - r="refresh", - s="save", - enter="quit") + self._keymaps: dict[TriggerKeysType, TriggerNamesType] = {"m": "toggle_mask", + "r": "refresh", + "s": "save", + "enter": "quit"} self._title = "" logger.debug("Initialized %s parent", self.__class__.__name__) @@ -141,7 +135,7 @@ class PreviewCV(PreviewBase): # pylint:disable=too-few-public-methods logger.debug("Unable to import Tkinter. 
Falling back to OpenCV") super().__init__(preview_buffer, triggers=triggers) self._triggers: TriggerType = self._triggers - self._windows: List[str] = [] + self._windows: list[str] = [] self._lookup = {ord(key): val for key, val in self._keymaps.items() if key != "enter"} diff --git a/lib/training/preview_tk.py b/lib/training/preview_tk.py index 3b567ba1..a22fd3c0 100644 --- a/lib/training/preview_tk.py +++ b/lib/training/preview_tk.py @@ -4,24 +4,25 @@ If Tkinter is installed, then this will be used to manage the preview image, otherwise we fallback to opencv's imshow """ +from __future__ import annotations import logging import os import sys import tkinter as tk +import typing as T from datetime import datetime from platform import system from tkinter import ttk from math import ceil, floor -from typing import cast, List, Optional, Tuple, TYPE_CHECKING from PIL import Image, ImageTk import cv2 from .preview_cv import PreviewBase, TriggerKeysType -if TYPE_CHECKING: +if T.TYPE_CHECKING: import numpy as np from .preview_cv import PreviewBuffer, TriggerType @@ -38,18 +39,18 @@ class _Taskbar(): taskbar: :class:`tkinter.ttk.Frame` or ``None`` None if preview is a pop-up window otherwise ttk.Frame if taskbar is managed by the GUI """ - def __init__(self, parent: tk.Frame, taskbar: Optional[ttk.Frame]) -> None: + def __init__(self, parent: tk.Frame, taskbar: ttk.Frame | None) -> None: logger.debug("Initializing %s (parent: '%s', taskbar: %s)", self.__class__.__name__, parent, taskbar) self._is_standalone = taskbar is None - self._gui_mapped: List[tk.Widget] = [] + self._gui_mapped: list[tk.Widget] = [] self._frame = tk.Frame(parent) if taskbar is None else taskbar self._min_max_scales = (20, 400) - self._vars = dict(save=tk.BooleanVar(), - scale=tk.StringVar(), - slider=tk.IntVar(), - interpolator=tk.IntVar()) + self._vars = {"save": tk.BooleanVar(), + "scale": tk.StringVar(), + "slider": tk.IntVar(), + "interpolator": tk.IntVar()} self._interpolators = [("nearest_neighbour", cv2.INTER_NEAREST), ("bicubic", cv2.INTER_CUBIC)] self._scale = self._add_scale_combo() @@ -261,7 +262,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors def __init__(self, parent: tk.Frame, scale_var: tk.StringVar, - screen_dimensions: Tuple[int, int], + screen_dimensions: tuple[int, int], is_standalone: bool) -> None: logger.debug("Initializing %s (parent: '%s', scale_var: %s, screen_dimensions: %s)", self.__class__.__name__, parent, scale_var, screen_dimensions) @@ -272,7 +273,7 @@ class _PreviewCanvas(tk.Canvas): # pylint:disable=too-many-ancestors self._screen_dimensions = screen_dimensions self._var_scale = scale_var self._configure_scrollbars(frame) - self._image: Optional[ImageTk.PhotoImage] = None + self._image: ImageTk.PhotoImage | None = None self._image_id = self.create_image(self.width / 2, self.height / 2, anchor=tk.CENTER, @@ -400,8 +401,8 @@ class _Image(): logger.debug("Initializing %s: (save_variable: %s, is_standalone: %s)", self.__class__.__name__, save_variable, is_standalone) self._is_standalone = is_standalone - self._source: Optional["np.ndarray"] = None - self._display: Optional[ImageTk.PhotoImage] = None + self._source: np.ndarray | None = None + self._display: ImageTk.PhotoImage | None = None self._scale = 1.0 self._interpolation = cv2.INTER_NEAREST @@ -416,7 +417,7 @@ class _Image(): return self._display @property - def source(self) -> "np.ndarray": + def source(self) -> np.ndarray: """ :class:`PIL.Image.Image`: The current source preview image """ assert self._source is 
not None return self._source @@ -426,7 +427,7 @@ class _Image(): """int: The current display scale as a percentage of original image size """ return int(self._scale * 100) - def set_source_image(self, name: str, image: "np.ndarray") -> None: + def set_source_image(self, name: str, image: np.ndarray) -> None: """ Set the source image to :attr:`source` Parameters @@ -542,7 +543,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods self._taskbar = taskbar self._image = image - self._drag_data: List[float] = [0., 0.] + self._drag_data: list[float] = [0., 0.] self._set_mouse_bindings() self._set_key_bindings(is_standalone) logger.debug("Initialized %s", self.__class__.__name__,) @@ -604,7 +605,7 @@ class _Bindings(): # pylint: disable=too-few-public-methods The key press event """ move_axis = self._canvas.xview if event.keysym in ("Left", "Right") else self._canvas.yview - visible = (move_axis()[1] - move_axis()[0]) + visible = move_axis()[1] - move_axis()[0] amount = -visible / 25 if event.keysym in ("Up", "Left") else visible / 25 logger.trace("Key move event: (event: %s, move_axis: %s, visible: %s, " # type: ignore "amount: %s)", move_axis, visible, amount) @@ -671,10 +672,10 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods Default: `None` """ def __init__(self, - preview_buffer: "PreviewBuffer", - parent: Optional[tk.Widget] = None, - taskbar: Optional[ttk.Frame] = None, - triggers: Optional["TriggerType"] = None) -> None: + preview_buffer: PreviewBuffer, + parent: tk.Widget | None = None, + taskbar: ttk.Frame | None = None, + triggers: TriggerType | None = None) -> None: logger.debug("Initializing %s (parent: '%s')", self.__class__.__name__, parent) super().__init__(preview_buffer, triggers=triggers) self._is_standalone = parent is None @@ -745,7 +746,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods logger.info(" Save Preview: Ctrl+s") logger.info("---------------------------------------------------") - def _get_geometry(self) -> Tuple[int, int]: + def _get_geometry(self) -> tuple[int, int]: """ Obtain the geometry of the current screen (standalone) or the dimensions of the widget holding the preview window (GUI). 
@@ -780,7 +781,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods half_screen = tuple(x // 2 for x in self._screen_dimensions) min_scales = (half_screen[0] / self._image.source.shape[1], half_screen[1] / self._image.source.shape[0]) - min_scale = min(1.0, min(min_scales)) + min_scale = min(1.0, *min_scales) min_scale = (ceil(min_scale * 10)) * 10 eight_screen = tuple(x * 8 for x in self._screen_dimensions) @@ -884,7 +885,7 @@ class PreviewTk(PreviewBase): # pylint:disable=too-few-public-methods if self._triggers is None: # Don't need triggers for GUI return keypress = "enter" if event.keysym == "Return" else event.keysym - key = cast(TriggerKeysType, keypress) + key = T.cast(TriggerKeysType, keypress) logger.debug("Processing keypress '%s'", key) if key == "r": print("") # Let log print on different line from loss output diff --git a/lib/utils.py b/lib/utils.py index ad232a6b..a81000e6 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,11 +1,12 @@ #!/usr/bin python3 """ Utilities available across all scripts """ - +from __future__ import annotations import json import logging import os import sys import tkinter as tk +import typing as T import warnings import zipfile @@ -14,18 +15,12 @@ from re import finditer from socket import timeout as socket_timeout, error as socket_error from threading import get_ident from time import time -from typing import cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING from urllib import request, error as urlliberror import numpy as np from tqdm import tqdm -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from http.client import HTTPResponse # Global variables @@ -34,8 +29,8 @@ _image_extensions = [ # pylint:disable=invalid-name _video_extensions = [ # pylint:disable=invalid-name ".avi", ".flv", ".mkv", ".mov", ".mp4", ".mpeg", ".mpg", ".webm", ".wmv", ".ts", ".vob"] -_TF_VERS: Optional[Tuple[int, int]] = None -ValidBackends = Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"] +_TF_VERS: tuple[int, int] | None = None +ValidBackends = T.Literal["nvidia", "cpu", "apple_silicon", "directml", "rocm"] class _Backend(): # pylint:disable=too-few-public-methods @@ -44,7 +39,7 @@ class _Backend(): # pylint:disable=too-few-public-methods If file doesn't exist and a variable hasn't been set, create the config file. 
""" def __init__(self) -> None: - self._backends: Dict[str, ValidBackends] = {"1": "cpu", + self._backends: dict[str, ValidBackends] = {"1": "cpu", "2": "directml", "3": "nvidia", "4": "apple_silicon", @@ -78,9 +73,9 @@ class _Backend(): # pylint:disable=too-few-public-methods """ # Check if environment variable is set, if so use that if "FACESWAP_BACKEND" in os.environ: - fs_backend = cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower()) - assert fs_backend in get_args(ValidBackends), ( - f"Faceswap backend must be one of {get_args(ValidBackends)}") + fs_backend = T.cast(ValidBackends, os.environ["FACESWAP_BACKEND"].lower()) + assert fs_backend in T.get_args(ValidBackends), ( + f"Faceswap backend must be one of {T.get_args(ValidBackends)}") print(f"Setting Faceswap backend from environment variable to {fs_backend.upper()}") return fs_backend # Intercept for sphinx docs build @@ -163,11 +158,11 @@ def set_backend(backend: str) -> None: >>> set_backend("nvidia") """ global _FS_BACKEND # pylint:disable=global-statement - backend = cast(ValidBackends, backend.lower()) + backend = T.cast(ValidBackends, backend.lower()) _FS_BACKEND = backend -def get_tf_version() -> Tuple[int, int]: +def get_tf_version() -> tuple[int, int]: """ Obtain the major. minor version of currently installed Tensorflow. Returns @@ -179,7 +174,7 @@ def get_tf_version() -> Tuple[int, int]: ------- >>> from lib.utils import get_tf_version >>> get_tf_version() - (2, 9) + (2, 10) """ global _TF_VERS # pylint:disable=global-statement if _TF_VERS is None: @@ -225,7 +220,7 @@ def get_folder(path: str, make_folder: bool = True) -> str: return path -def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str]: +def get_image_paths(directory: str, extension: str | None = None) -> list[str]: """ Gets the image paths from a given directory. The function searches for files with the specified extension(s) in the given directory, and @@ -274,7 +269,7 @@ def get_image_paths(directory: str, extension: Optional[str] = None) -> List[str return dir_contents -def get_dpi() -> Optional[float]: +def get_dpi() -> float | None: """ Gets the DPI (dots per inch) of the display screen. Returns @@ -338,7 +333,7 @@ def convert_to_secs(*args: int) -> int: return retval -def full_path_split(path: str) -> List[str]: +def full_path_split(path: str) -> list[str]: """ Split a file path into all of its parts. Parameters @@ -360,7 +355,7 @@ def full_path_split(path: str) -> List[str]: ['relative', 'path', 'to', 'file.txt']] """ logger = logging.getLogger(__name__) - allparts: List[str] = [] + allparts: list[str] = [] while True: parts = os.path.split(path) if parts[0] == path: # sentinel for absolute paths @@ -410,7 +405,7 @@ def set_system_verbosity(log_level: str): warnings.simplefilter(action='ignore', category=warncat) -def deprecation_warning(function: str, additional_info: Optional[str] = None) -> None: +def deprecation_warning(function: str, additional_info: str | None = None) -> None: """ Log a deprecation warning message. 
This function logs a warning message to indicate that the specified function has been @@ -436,7 +431,7 @@ def deprecation_warning(function: str, additional_info: Optional[str] = None) -> logger.warning(msg) -def camel_case_split(identifier: str) -> List[str]: +def camel_case_split(identifier: str) -> list[str]: """ Split a camelCase string into a list of its individual parts Parameters @@ -541,7 +536,7 @@ class GetModel(): # pylint:disable=too-few-public-methods >>> model_downloader = GetModel("s3fd_keras_v2.h5", 11) """ - def __init__(self, model_filename: Union[str, List[str]], git_model_id: int) -> None: + def __init__(self, model_filename: str | list[str], git_model_id: int) -> None: self.logger = logging.getLogger(__name__) if not isinstance(model_filename, list): model_filename = [model_filename] @@ -576,7 +571,7 @@ class GetModel(): # pylint:disable=too-few-public-methods return retval @property - def model_path(self) -> Union[str, List[str]]: + def model_path(self) -> str | list[str]: """ str or list[str]: The model path(s) in the cache folder. Example @@ -587,7 +582,7 @@ class GetModel(): # pylint:disable=too-few-public-methods '/path/to/s3fd_keras_v2.h5' """ paths = [os.path.join(self._cache_dir, fname) for fname in self._model_filename] - retval: Union[str, List[str]] = paths[0] if len(paths) == 1 else paths + retval: str | list[str] = paths[0] if len(paths) == 1 else paths self.logger.trace(retval) # type:ignore[attr-defined] return retval @@ -662,7 +657,7 @@ class GetModel(): # pylint:disable=too-few-public-methods self._url_download, self._cache_dir) sys.exit(1) - def _write_zipfile(self, response: "HTTPResponse", downloaded_size: int) -> None: + def _write_zipfile(self, response: HTTPResponse, downloaded_size: int) -> None: """ Write the model zip file to disk. Parameters @@ -762,8 +757,8 @@ class DebugTimes(): """ def __init__(self, show_min: bool = True, show_mean: bool = True, show_max: bool = True) -> None: - self._times: Dict[str, List[float]] = {} - self._steps: Dict[str, float] = {} + self._times: dict[str, list[float]] = {} + self._steps: dict[str, float] = {} self._interval = 1 self._display = {"min": show_min, "mean": show_mean, "max": show_max} diff --git a/locales/plugins.train._config.pot b/locales/plugins.train._config.pot index 90ff2866..26558b6e 100644 --- a/locales/plugins.train._config.pot +++ b/locales/plugins.train._config.pot @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-06-11 23:20+0100\n" +"POT-Creation-Date: 2023-06-25 13:39+0100\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -227,7 +227,8 @@ msgid "" msgstr "" #: plugins/train/_config.py:198 plugins/train/_config.py:223 -#: plugins/train/_config.py:238 plugins/train/_config.py:265 +#: plugins/train/_config.py:238 plugins/train/_config.py:256 +#: plugins/train/_config.py:290 msgid "optimizer" msgstr "" @@ -274,21 +275,40 @@ msgid "" "epsilon to 0.001 (1e-3)." msgstr "" -#: plugins/train/_config.py:258 +#: plugins/train/_config.py:262 msgid "" -"Apply AutoClipping to the gradients. AutoClip analyzes the " -"gradient weights and adjusts the normalization value dynamically to fit the " -"data. Can help prevent NaNs and improve model optimization at the expense of " -"VRAM. Ref: AutoClip: Adaptive Gradient Clipping for Source Separation " -"Networks https://arxiv.org/abs/2007.14469" +"When to save the Optimizer Weights. 
Saving the optimizer weights is not " +"necessary and will increase the model file size 3x (and by extension the " +"amount of time it takes to save the model). However, it can be useful to " +"save these weights if you want to guarantee that a resumed model carries off " +"exactly from where it left off, rather than spending a few hundred " +"iterations catching up.\n" +"\t never - Don't save optimizer weights.\n" +"\t always - Save the optimizer weights at every save iteration. Model saving " +"will take longer, due to the increased file size, but you will always have " +"the last saved optimizer state in your model file.\n" +"\t exit - Only save the optimizer weights when explicitly terminating a " +"model. This can be when the model is actively stopped or when the target " +"iterations are met. Note: If the training session ends because of another " +"reason (e.g. power outage, Out of Memory Error, NaN detected) then the " +"optimizer weights will NOT be saved." msgstr "" -#: plugins/train/_config.py:271 plugins/train/_config.py:283 -#: plugins/train/_config.py:297 plugins/train/_config.py:314 +#: plugins/train/_config.py:283 +msgid "" +"Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights " +"and adjusts the normalization value dynamically to fit the data. Can help " +"prevent NaNs and improve model optimization at the expense of VRAM. Ref: " +"AutoClip: Adaptive Gradient Clipping for Source Separation Networks https://" +"arxiv.org/abs/2007.14469" +msgstr "" + +#: plugins/train/_config.py:296 plugins/train/_config.py:308 +#: plugins/train/_config.py:322 plugins/train/_config.py:339 msgid "network" msgstr "" -#: plugins/train/_config.py:273 +#: plugins/train/_config.py:298 msgid "" "Use reflection padding rather than zero padding with convolutions. Each " "convolution must pad the image boundaries to maintain the proper sizing. " @@ -297,21 +317,21 @@ msgid "" "\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt" msgstr "" -#: plugins/train/_config.py:286 +#: plugins/train/_config.py:311 msgid "" -"Enable the Tensorflow GPU 'allow_growth' configuration " -"option. This option prevents Tensorflow from allocating all of the GPU VRAM " -"at launch but can lead to higher VRAM fragmentation and slower performance. " -"Should only be enabled if you are receiving errors regarding 'cuDNN fails to " -"initialize' when commencing training." +"Enable the Tensorflow GPU 'allow_growth' configuration option. This option " +"prevents Tensorflow from allocating all of the GPU VRAM at launch but can " +"lead to higher VRAM fragmentation and slower performance. Should only be " +"enabled if you are receiving errors regarding 'cuDNN fails to initialize' " +"when commencing training." msgstr "" -#: plugins/train/_config.py:299 +#: plugins/train/_config.py:324 msgid "" -"NVIDIA GPUs can run operations in float16 faster than in " -"float32. Mixed precision allows you to use a mix of float16 with float32, to " -"get the performance benefits from float16 and the numeric stability benefits " -"from float32.\n" +"NVIDIA GPUs can run operations in float16 faster than in float32. Mixed " +"precision allows you to use a mix of float16 with float32, to get the " +"performance benefits from float16 and the numeric stability benefits from " +"float32.\n" "\n" "This is untested on DirectML backend, but will run on most Nvidia models. it " "will only speed up training on more recent GPUs. Those with compute " @@ -322,7 +342,7 @@ msgid "" "the most benefit." 
msgstr "" -#: plugins/train/_config.py:316 +#: plugins/train/_config.py:341 msgid "" "If a 'NaN' is generated in the model, this means that the model has " "corrupted and the model is likely to start deteriorating from this point on. " @@ -331,11 +351,11 @@ msgid "" "rescue your model." msgstr "" -#: plugins/train/_config.py:329 +#: plugins/train/_config.py:354 msgid "convert" msgstr "" -#: plugins/train/_config.py:331 +#: plugins/train/_config.py:356 msgid "" "[GPU Only]. The number of faces to feed through the model at once when " "running the Convert process.\n" @@ -345,27 +365,27 @@ msgid "" "size." msgstr "" -#: plugins/train/_config.py:350 +#: plugins/train/_config.py:375 msgid "" "Loss configuration options\n" "Loss is the mechanism by which a Neural Network judges how well it thinks " "that it is recreating a face." msgstr "" -#: plugins/train/_config.py:357 plugins/train/_config.py:369 -#: plugins/train/_config.py:382 plugins/train/_config.py:402 -#: plugins/train/_config.py:414 plugins/train/_config.py:434 -#: plugins/train/_config.py:446 plugins/train/_config.py:466 -#: plugins/train/_config.py:482 plugins/train/_config.py:498 -#: plugins/train/_config.py:515 +#: plugins/train/_config.py:382 plugins/train/_config.py:394 +#: plugins/train/_config.py:407 plugins/train/_config.py:427 +#: plugins/train/_config.py:439 plugins/train/_config.py:459 +#: plugins/train/_config.py:471 plugins/train/_config.py:491 +#: plugins/train/_config.py:507 plugins/train/_config.py:523 +#: plugins/train/_config.py:540 msgid "loss" msgstr "" -#: plugins/train/_config.py:361 +#: plugins/train/_config.py:386 msgid "The loss function to use." msgstr "" -#: plugins/train/_config.py:373 +#: plugins/train/_config.py:398 msgid "" "The second loss function to use. If using a structural based loss (such as " "SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 " @@ -373,7 +393,7 @@ msgid "" "function with the loss_weight_2 option." msgstr "" -#: plugins/train/_config.py:388 +#: plugins/train/_config.py:413 msgid "" "The amount of weight to apply to the second loss function.\n" "\n" @@ -391,13 +411,13 @@ msgid "" "\t 0 - Disables the second loss function altogether." msgstr "" -#: plugins/train/_config.py:406 +#: plugins/train/_config.py:431 msgid "" "The third loss function to use. You can adjust the weighting of this loss " "function with the loss_weight_3 option." msgstr "" -#: plugins/train/_config.py:420 +#: plugins/train/_config.py:445 msgid "" "The amount of weight to apply to the third loss function.\n" "\n" @@ -415,13 +435,13 @@ msgid "" "\t 0 - Disables the third loss function altogether." msgstr "" -#: plugins/train/_config.py:438 +#: plugins/train/_config.py:463 msgid "" "The fourth loss function to use. You can adjust the weighting of this loss " "function with the loss_weight_3 option." msgstr "" -#: plugins/train/_config.py:452 +#: plugins/train/_config.py:477 msgid "" "The amount of weight to apply to the fourth loss function.\n" "\n" @@ -439,7 +459,7 @@ msgid "" "\t 0 - Disables the fourth loss function altogether." msgstr "" -#: plugins/train/_config.py:471 +#: plugins/train/_config.py:496 msgid "" "The loss function to use when learning a mask.\n" "\t MAE - Mean absolute error will guide reconstructions of each pixel " @@ -451,7 +471,7 @@ msgid "" "susceptible to outliers and typically produces slightly blurrier results." 
msgstr "" -#: plugins/train/_config.py:488 +#: plugins/train/_config.py:513 msgid "" "The amount of priority to give to the eyes.\n" "\n" @@ -464,7 +484,7 @@ msgid "" "NB: Penalized Mask Loss must be enable to use this option." msgstr "" -#: plugins/train/_config.py:504 +#: plugins/train/_config.py:529 msgid "" "The amount of priority to give to the mouth.\n" "\n" @@ -477,7 +497,7 @@ msgid "" "NB: Penalized Mask Loss must be enable to use this option." msgstr "" -#: plugins/train/_config.py:517 +#: plugins/train/_config.py:542 msgid "" "Image loss function is weighted by mask presence. For areas of the image " "without the facial mask, reconstruction errors will be ignored while the " @@ -485,12 +505,12 @@ msgid "" "attention on the core face area." msgstr "" -#: plugins/train/_config.py:528 plugins/train/_config.py:570 -#: plugins/train/_config.py:584 plugins/train/_config.py:593 +#: plugins/train/_config.py:553 plugins/train/_config.py:595 +#: plugins/train/_config.py:609 plugins/train/_config.py:618 msgid "mask" msgstr "" -#: plugins/train/_config.py:531 +#: plugins/train/_config.py:556 msgid "" "The mask to be used for training. If you have selected 'Learn Mask' or " "'Penalized Mask Loss' you must select a value other than 'none'. The " @@ -528,7 +548,7 @@ msgid "" "performance." msgstr "" -#: plugins/train/_config.py:572 +#: plugins/train/_config.py:597 msgid "" "Apply gaussian blur to the mask input. This has the effect of smoothing the " "edges of the mask, which can help with poorly calculated masks and give less " @@ -538,13 +558,13 @@ msgid "" "number." msgstr "" -#: plugins/train/_config.py:586 +#: plugins/train/_config.py:611 msgid "" "Sets pixels that are near white to white and near black to black. Set to 0 " "for off." msgstr "" -#: plugins/train/_config.py:595 +#: plugins/train/_config.py:620 msgid "" "Dedicate a portion of the model to learning how to duplicate the input mask. " "Increases VRAM usage in exchange for learning a quick ability to try to " diff --git a/locales/ru/LC_MESSAGES/plugins.train._config.mo b/locales/ru/LC_MESSAGES/plugins.train._config.mo index 7cbd6379b30ee71b819b9fb10f9eb9530932a00e..a709331fd7fae5284c882b07c5763d67ace196dc 100644 GIT binary patch delta 3822 zcmai#e~eUD6~_<1T7HNKg%+!oTPSqF-6?dzLjMq2Xen)3tVJc0}JdAO&UwIw(mt$sEvunRExi2QWG`S_&N8@4!eQG zm)!Zjckj99e9!lsd*_v-6Z+0h2!Atc;z8kA$@nPa?QtTTz#|j*;ki6f#09T_Ygt=8 zNu(R>xk;p-`I{z-+zk$c4>11+@D*_JRFO}Dm#2xWA1`v-bdeYF`Sc8t70kasOXNQy zp)}4Gxs!>db431z;iqpE*~j;T9~SvM7~M`l0@TCqGI#=9&3x}2A|<{reoW+L_I+WV z$T8+8&lfSw_bm`P3|<0zLGLb+C&7P$?|_dl6xlp36lq)}@>3>yNuvn|o9jfL0Kd0b zWEYPAu_VFg8$@>UJ#MMUpTX{XMSOLrjFCGH1QWJPdAmMC2SB8aIec;QKF+i{MB8zEK4AHfO0vH@r2gBO_+^(!ijAv!M8>l*L%|+m!?SzX1fB%{2%gw0@^kReJ_h@8 z`w2*ZNnaB=z<1x*MV9lu_8TILz#ZQdc@*N`IRrWI>X8JO9i_L3x8n!s|0WJE9TRz% zz~6sa~pSf zwz%2WOi)OfC!Oute5+mcYO5W|s>v2i-V03LZFUO<$L};wK5aU(xtz)7n|;?QByQWh zw3{<6*_>+%csGqZOdXb?+16>?9f6xKWIc#Lz-jY}`M`KB1|qJ>1}1QxbqixiT5pR} z<`J*M-R}DJYSrxI%{JE*3vNp>H`*EoHL@+H(<_<|hX5qqS_~bZcUL(&fpLuQ7K&|d z+9Y(F9pBFq+G}Znq|+SaI!%l3wV95L>sx9$w?(P8)EnP{ysv)prqJ%@(<+%`THFqk zDdy8YPP2iF6?qmQ8#HII2u-m)HFXN|k%nn7>%9EkK~jwCq*gm8*Wq*)Fr+4lu-BOx z5>aoRiJMx`l-X!&K&6$MEm066t;u=$RuwT_#HDA8FjBkpwUX7FZN(swPpeyHk{RKV za|#LZw7Ev95I6xk<#n516;*pcQ*n1tw6MW!%A;zn{jTjDsTo@5F>cRgo3mN~L~P6E z(b4K_3zagPGxTA#1MCLvq@;)}P`T`OSHlEedpnizd@UjU7Gf4Seybadu8u=-+uR^! 
zn!Lch-)x{2)y{Xad5v12)3sYEa)Fj>cIY|nSBHb3J)m9Ba-ExMO__FD)Hhx+a9Z7Z zvk3{Y+~~G>6zCy;jqf3Rlhb60KJv{2=nHf`2MSIN?;zNs4?9)I%+ zv#jy{<@fe&dF)WvqUdlm6!k@?qq2#+qM>+KyfZ3CgYkj5JKk%eQ&ENKa_NbIxs%E% z6CEv`%*>vW%nn7TbUrMbXfUdn_yE=y_^q(DD?VUnaegx1#qa5;KN^VlL<1)3H*t?T zjxI#~@e7P@V}qWJ2H3mP#9xKMIaYBzV7FRR$e6|v%d)9cKQNBMp}5=f)Ts6iI)l*= z_PDJO^+5d%xUjQgpatVGnX_XMPS76}o7*FL3FASSombboc1ANbQEB7uIpNS)ALN93 z@jm$U*F=Pcib>W=ZI4l1)PMn#zNlQ4UZ-~A{}eO0sW_5Mv7aN(cS_Y zgoCb%=vA>(Ww2qOH|kU&;Boa)TC;n0NNUX=v3)5Ws)Qdz$ErMr}3TNlZ*kyka#5{ZOe_Fh+wY2`WhJfvAWV2t`Uf+Ipv@ z^@xI?A3QKBktdo65|0lp-coI=P${JZwFW^1MOyT?dCyLm_xsQ6?#%A&Y~tr5rw>Q! z9`%auK`B!Qs`pZ0hTIwPM`?<}7xIjp;#gKE*d!NrhGzYb(x6c8ONUSNx60Epr7L_a zn_^I&3ePj(pfSrbp;Z6aUh^1_@Dy_2A6R(;CogF_q z8CE-fYKk#kpFI`+Ge7OqHJtEIIYea4`n#UY3QsMwWOq4GUY5g?{`Mqa$T0tv>>}UF zH?qsjM8%?&UX2O=kPk!RD)`JiZ{}G48YeEYE+-r~C#&7@;%X?;Zk`Xz+`(T9;jH&P ztJV_gPhA4tWc5;5DL-HCsqBk;E8M*6H?6gQ^w+P0YsMe&CA65ZO@{S{*2AaTFTe8s zN4c=ZGNiij+pl4o3$|^sH1f85EN^;&zsPgjpuvUPx0}#Fc{|}d?Ys5fzxLwqp+N5a z0V>@1#oaJnrteR%>YzPkzP*QyZM+vpP5i{fSB`(^b diff --git a/locales/ru/LC_MESSAGES/plugins.train._config.po b/locales/ru/LC_MESSAGES/plugins.train._config.po index fba3a9d5..02fc12a2 100644 --- a/locales/ru/LC_MESSAGES/plugins.train._config.po +++ b/locales/ru/LC_MESSAGES/plugins.train._config.po @@ -7,8 +7,8 @@ msgid "" msgstr "" "Project-Id-Version: \n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2023-06-11 23:20+0100\n" -"PO-Revision-Date: 2023-06-20 17:06+0100\n" +"POT-Creation-Date: 2023-06-25 13:39+0100\n" +"PO-Revision-Date: 2023-06-25 13:42+0100\n" "Last-Translator: \n" "Language-Team: \n" "Language: ru_RU\n" @@ -354,7 +354,8 @@ msgstr "" "повлияет только на запуск новой модели." #: plugins/train/_config.py:198 plugins/train/_config.py:223 -#: plugins/train/_config.py:238 plugins/train/_config.py:265 +#: plugins/train/_config.py:238 plugins/train/_config.py:256 +#: plugins/train/_config.py:290 msgid "optimizer" msgstr "оптимизатор" @@ -435,7 +436,41 @@ msgstr "" "Например, при выборе значения '-7' эпсилон будет равен 1e-7. При выборе " "значения \"-3\" эпсилон будет равен 0,001 (1e-3)." -#: plugins/train/_config.py:258 +#: plugins/train/_config.py:262 +msgid "" +"When to save the Optimizer Weights. Saving the optimizer weights is not " +"necessary and will increase the model file size 3x (and by extension the " +"amount of time it takes to save the model). However, it can be useful to " +"save these weights if you want to guarantee that a resumed model carries off " +"exactly from where it left off, rather than spending a few hundred " +"iterations catching up.\n" +"\t never - Don't save optimizer weights.\n" +"\t always - Save the optimizer weights at every save iteration. Model saving " +"will take longer, due to the increased file size, but you will always have " +"the last saved optimizer state in your model file.\n" +"\t exit - Only save the optimizer weights when explicitly terminating a " +"model. This can be when the model is actively stopped or when the target " +"iterations are met. Note: If the training session ends because of another " +"reason (e.g. power outage, Out of Memory Error, NaN detected) then the " +"optimizer weights will NOT be saved." +msgstr "" +"Когда сохранять веса оптимизатора. Сохранение весов оптимизатора не является " +"необходимым и увеличит размер файла модели в 3 раза (и соответственно время, " +"необходимое для сохранения модели). 
Однако может быть полезно сохранить эти " +"веса, если вы хотите гарантировать, что возобновленная модель продолжит " +"работу именно с того места, где она остановилась, а не тратит несколько " +"сотен итераций на догонялки.\n" +"\t never - не сохранять веса оптимизатора.\n" +"\t always - сохранять веса оптимизатора при каждой итерации сохранения. " +"Сохранение модели займет больше времени из-за увеличенного размера файла, но " +"в файле модели всегда будет последнее сохраненное состояние оптимизатора.\n" +"\t exit - сохранять веса оптимизатора только при явном завершении модели. " +"Это может быть, когда модель активно останавливается или когда выполняются " +"целевые итерации. Примечание. Если сеанс обучения завершается по другой " +"причине (например, отключение питания, ошибка нехватки памяти, обнаружение " +"NaN), веса оптимизатора НЕ будут сохранены." + +#: plugins/train/_config.py:283 msgid "" "Apply AutoClipping to the gradients. AutoClip analyzes the gradient weights " "and adjusts the normalization value dynamically to fit the data. Can help " @@ -449,12 +484,12 @@ msgstr "" "ценой видеопамяти. Ссылка: AutoClip: Adaptive Gradient Clipping for Source " "Separation Networks [ТОЛЬКО на английском] https://arxiv.org/abs/2007.14469" -#: plugins/train/_config.py:271 plugins/train/_config.py:283 -#: plugins/train/_config.py:297 plugins/train/_config.py:314 +#: plugins/train/_config.py:296 plugins/train/_config.py:308 +#: plugins/train/_config.py:322 plugins/train/_config.py:339 msgid "network" msgstr "сеть" -#: plugins/train/_config.py:273 +#: plugins/train/_config.py:298 msgid "" "Use reflection padding rather than zero padding with convolutions. Each " "convolution must pad the image boundaries to maintain the proper sizing. " @@ -468,7 +503,7 @@ msgstr "" "изображения.\n" "\t http://www-cs.engr.ccny.cuny.edu/~wolberg/cs470/hw/hw2_pad.txt" -#: plugins/train/_config.py:286 +#: plugins/train/_config.py:311 msgid "" "Enable the Tensorflow GPU 'allow_growth' configuration option. This option " "prevents Tensorflow from allocating all of the GPU VRAM at launch but can " @@ -483,7 +518,7 @@ msgstr "" "случае, если у вас появляются ошибки, рода 'cuDNN fails to initialize'(cuDNN " "не может инициализироваться) при начале тренировки." -#: plugins/train/_config.py:299 +#: plugins/train/_config.py:324 msgid "" "NVIDIA GPUs can run operations in float16 faster than in float32. Mixed " "precision allows you to use a mix of float16 with float32, to get the " @@ -512,7 +547,7 @@ msgstr "" "ускорение. В основном RTX видеокарты и позже предлагают самое большое " "ускорение." -#: plugins/train/_config.py:316 +#: plugins/train/_config.py:341 msgid "" "If a 'NaN' is generated in the model, this means that the model has " "corrupted and the model is likely to start deteriorating from this point on. " @@ -526,11 +561,11 @@ msgstr "" "NaN. Последнее сохранение не будет содержать в себе NaN, так что у вас будет " "возможность спасти вашу модель." -#: plugins/train/_config.py:329 +#: plugins/train/_config.py:354 msgid "convert" msgstr "конвертирование" -#: plugins/train/_config.py:331 +#: plugins/train/_config.py:356 msgid "" "[GPU Only]. The number of faces to feed through the model at once when " "running the Convert process.\n" @@ -546,7 +581,7 @@ msgstr "" "конвертирования, однако, если у вас появляются ошибки 'Out of Memory', тогда " "стоит снизить размер пачки." 
-#: plugins/train/_config.py:350 +#: plugins/train/_config.py:375 msgid "" "Loss configuration options\n" "Loss is the mechanism by which a Neural Network judges how well it thinks " @@ -556,20 +591,20 @@ msgstr "" "Потеря - механизм, по которому Нейронная Сеть судит, насколько хорошо она " "воспроизводит лицо." -#: plugins/train/_config.py:357 plugins/train/_config.py:369 -#: plugins/train/_config.py:382 plugins/train/_config.py:402 -#: plugins/train/_config.py:414 plugins/train/_config.py:434 -#: plugins/train/_config.py:446 plugins/train/_config.py:466 -#: plugins/train/_config.py:482 plugins/train/_config.py:498 -#: plugins/train/_config.py:515 +#: plugins/train/_config.py:382 plugins/train/_config.py:394 +#: plugins/train/_config.py:407 plugins/train/_config.py:427 +#: plugins/train/_config.py:439 plugins/train/_config.py:459 +#: plugins/train/_config.py:471 plugins/train/_config.py:491 +#: plugins/train/_config.py:507 plugins/train/_config.py:523 +#: plugins/train/_config.py:540 msgid "loss" msgstr "потери" -#: plugins/train/_config.py:361 +#: plugins/train/_config.py:386 msgid "The loss function to use." msgstr "Какую функцию потерь стоит использовать." -#: plugins/train/_config.py:373 +#: plugins/train/_config.py:398 msgid "" "The second loss function to use. If using a structural based loss (such as " "SSIM, MS-SSIM or GMSD) it is common to add an L1 regularization(MAE) or L2 " @@ -581,7 +616,7 @@ msgstr "" "регуляризации L1 (MAE) или регуляризации L2 (MSE). Вы можете настроить вес " "этой функции потерь с помощью параметра loss_weight_2." -#: plugins/train/_config.py:388 +#: plugins/train/_config.py:413 msgid "" "The amount of weight to apply to the second loss function.\n" "\n" @@ -612,7 +647,7 @@ msgstr "" "4 раза перед добавлением к общей оценке потерь. \n" "\t 0 - Полностью отключает четвертую функцию потерь." -#: plugins/train/_config.py:406 +#: plugins/train/_config.py:431 msgid "" "The third loss function to use. You can adjust the weighting of this loss " "function with the loss_weight_3 option." @@ -620,7 +655,7 @@ msgstr "" "Третья используемая функция потерь. Вы можете настроить вес этой функции " "потерь с помощью параметра loss_weight_3." -#: plugins/train/_config.py:420 +#: plugins/train/_config.py:445 msgid "" "The amount of weight to apply to the third loss function.\n" "\n" @@ -651,7 +686,7 @@ msgstr "" "4 раза перед добавлением к общей оценке потерь. \n" "\t 0 - Полностью отключает четвертую функцию потерь." -#: plugins/train/_config.py:438 +#: plugins/train/_config.py:463 msgid "" "The fourth loss function to use. You can adjust the weighting of this loss " "function with the loss_weight_3 option." @@ -659,7 +694,7 @@ msgstr "" "Четвертая используемая функция потерь. Вы можете настроить вес этой функции " "потерь с помощью параметра 'loss_weight_4'." -#: plugins/train/_config.py:452 +#: plugins/train/_config.py:477 msgid "" "The amount of weight to apply to the fourth loss function.\n" "\n" @@ -690,7 +725,7 @@ msgstr "" "4 раза перед добавлением к общей оценке потерь. \n" "\t 0 - Полностью отключает четвертую функцию потерь." -#: plugins/train/_config.py:471 +#: plugins/train/_config.py:496 msgid "" "The loss function to use when learning a mask.\n" "\t MAE - Mean absolute error will guide reconstructions of each pixel " @@ -711,7 +746,7 @@ msgstr "" "данных. Как среднее значение, оно чувствительно к выбросам и обычно дает " "немного более размытые результаты." 
-#: plugins/train/_config.py:488 +#: plugins/train/_config.py:513 msgid "" "The amount of priority to give to the eyes.\n" "\n" @@ -731,7 +766,7 @@ msgstr "" "\n" "NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." -#: plugins/train/_config.py:504 +#: plugins/train/_config.py:529 msgid "" "The amount of priority to give to the mouth.\n" "\n" @@ -751,7 +786,7 @@ msgstr "" "\n" "NB: Penalized Mask Loss должен быть включен, чтобы использовать эту опцию." -#: plugins/train/_config.py:517 +#: plugins/train/_config.py:542 msgid "" "Image loss function is weighted by mask presence. For areas of the image " "without the facial mask, reconstruction errors will be ignored while the " @@ -763,12 +798,12 @@ msgstr "" "время как область лица с маской является приоритетной. Может повысить общее " "качество за счет концентрации внимания на основной области лица." -#: plugins/train/_config.py:528 plugins/train/_config.py:570 -#: plugins/train/_config.py:584 plugins/train/_config.py:593 +#: plugins/train/_config.py:553 plugins/train/_config.py:595 +#: plugins/train/_config.py:609 plugins/train/_config.py:618 msgid "mask" msgstr "маска" -#: plugins/train/_config.py:531 +#: plugins/train/_config.py:556 msgid "" "The mask to be used for training. If you have selected 'Learn Mask' or " "'Penalized Mask Loss' you must select a value other than 'none'. The " @@ -840,7 +875,7 @@ msgstr "" "сообщества и для дальнейшего описания нуждается в тестировании. Профильные " "лица могут иметь низкую производительность." -#: plugins/train/_config.py:572 +#: plugins/train/_config.py:597 msgid "" "Apply gaussian blur to the mask input. This has the effect of smoothing the " "edges of the mask, which can help with poorly calculated masks and give less " @@ -856,7 +891,7 @@ msgstr "" "должно быть нечетным, если передано четное число, то оно будет округлено до " "следующего нечетного числа." -#: plugins/train/_config.py:586 +#: plugins/train/_config.py:611 msgid "" "Sets pixels that are near white to white and near black to black. Set to 0 " "for off." @@ -864,7 +899,7 @@ msgstr "" "Устанавливает пиксели, которые почти белые - в белые и которые почти черные " "- в черные. Установите 0, чтобы выключить." -#: plugins/train/_config.py:595 +#: plugins/train/_config.py:620 msgid "" "Dedicate a portion of the model to learning how to duplicate the input mask. " "Increases VRAM usage in exchange for learning a quick ability to try to " diff --git a/plugins/convert/mask/mask_blend.py b/plugins/convert/mask/mask_blend.py index bcc1ac13..37533404 100644 --- a/plugins/convert/mask/mask_blend.py +++ b/plugins/convert/mask/mask_blend.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 """ Plugin to blend the edges of the face between the swap and the original face. 
""" import logging -import sys -from typing import List, Optional, Tuple +import typing as T import cv2 import numpy as np @@ -11,12 +10,6 @@ from lib.align import BlurMask, DetectedFace from lib.config import FaceswapConfig from plugins.convert._config import Config -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - - logger = logging.getLogger(__name__) @@ -44,8 +37,8 @@ class Mask(): # pylint:disable=too-few-public-methods mask_type: str, output_size: int, coverage_ratio: float, - configfile: Optional[str] = None, - config: Optional[FaceswapConfig] = None) -> None: + configfile: str | None = None, + config: FaceswapConfig | None = None) -> None: logger.debug("Initializing %s: (mask_type: '%s', output_size: %s, coverage_ratio: %s, " "configfile: %s, config: %s)", self.__class__.__name__, mask_type, coverage_ratio, output_size, configfile, config) @@ -61,8 +54,8 @@ class Mask(): # pylint:disable=too-few-public-methods self._do_erode = any(amount != 0 for amount in self._erodes) def _set_config(self, - configfile: Optional[str], - config: Optional[FaceswapConfig]) -> dict: + configfile: str | None, + config: FaceswapConfig | None) -> dict: """ Set the correct configuration for the plugin based on whether a config file or pre-loaded config has been passed in. @@ -123,8 +116,8 @@ class Mask(): # pylint:disable=too-few-public-methods detected_face: DetectedFace, source_offset: np.ndarray, target_offset: np.ndarray, - centering: Literal["legacy", "face", "head"], - predicted_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]: + centering: T.Literal["legacy", "face", "head"], + predicted_mask: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]: """ Obtain the requested mask type and perform any defined mask manipulations. Parameters @@ -171,8 +164,8 @@ class Mask(): # pylint:disable=too-few-public-methods def _get_mask(self, detected_face: DetectedFace, - predicted_mask: Optional[np.ndarray], - centering: Literal["legacy", "face", "head"], + predicted_mask: np.ndarray | None, + centering: T.Literal["legacy", "face", "head"], source_offset: np.ndarray, target_offset: np.ndarray) -> np.ndarray: """ Return the requested mask with any requested blurring applied. @@ -229,7 +222,7 @@ class Mask(): # pylint:disable=too-few-public-methods def _get_stored_mask(self, detected_face: DetectedFace, - centering: Literal["legacy", "face", "head"], + centering: T.Literal["legacy", "face", "head"], source_offset: np.ndarray, target_offset: np.ndarray) -> np.ndarray: """ get the requested stored mask from the detected face object. @@ -303,7 +296,7 @@ class Mask(): # pylint:disable=too-few-public-methods return eroded[..., None] - def _get_erosion_kernels(self, mask: np.ndarray) -> List[np.ndarray]: + def _get_erosion_kernels(self, mask: np.ndarray) -> list[np.ndarray]: """ Get the erosion kernels for each of the center, left, top right and bottom erosions. 
An approximation is made based on the number of positive pixels within the mask to create diff --git a/plugins/convert/writer/_base.py b/plugins/convert/writer/_base.py index c68fae93..d283e100 100644 --- a/plugins/convert/writer/_base.py +++ b/plugins/convert/writer/_base.py @@ -4,8 +4,7 @@ import logging import os import re - -from typing import Any, List, Optional +import typing as T import numpy as np @@ -14,7 +13,7 @@ from plugins.convert._config import Config logger = logging.getLogger(__name__) # pylint: disable=invalid-name -def get_config(plugin_name: str, configfile: Optional[str] = None) -> dict: +def get_config(plugin_name: str, configfile: str | None = None) -> dict: """ Obtain the configuration settings for the writer plugin. Parameters @@ -44,7 +43,7 @@ class Output(): The full path to a custom configuration ini file. If ``None`` is passed then the file is loaded from the default location. Default: ``None``. """ - def __init__(self, output_folder: str, configfile: Optional[str] = None) -> None: + def __init__(self, output_folder: str, configfile: str | None = None) -> None: logger.debug("Initializing %s: (output_folder: '%s')", self.__class__.__name__, output_folder) self.config: dict = get_config(".".join(self.__module__.split(".")[-2:]), @@ -69,7 +68,7 @@ class Output(): retval = hasattr(self, "frame_order") return retval - def output_filename(self, filename: str, separate_mask: bool = False) -> List[str]: + def output_filename(self, filename: str, separate_mask: bool = False) -> list[str]: """ Obtain the full path for the output file, including the correct extension, for the given input filename. @@ -124,7 +123,7 @@ class Output(): logger.trace("Added to cache. Frame no: %s", frame_no) # type: ignore logger.trace("Current cache: %s", sorted(self.cache.keys())) # type:ignore - def write(self, filename: str, image: Any) -> None: + def write(self, filename: str, image: T.Any) -> None: """ Override for specific frame writing method. Parameters @@ -137,7 +136,7 @@ class Output(): """ raise NotImplementedError - def pre_encode(self, image: np.ndarray) -> Any: # pylint: disable=unused-argument + def pre_encode(self, image: np.ndarray) -> T.Any: # pylint: disable=unused-argument """ Some writer plugins support the pre-encoding of images prior to saving out. As patching is done in multiple threads, but writing is done in a single thread, it can speed up the process to do any pre-encoding as part of the converter process. diff --git a/plugins/convert/writer/ffmpeg.py b/plugins/convert/writer/ffmpeg.py index 283a581f..8e3e4b16 100644 --- a/plugins/convert/writer/ffmpeg.py +++ b/plugins/convert/writer/ffmpeg.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 """ Video output writer for faceswap.py converter """ +from __future__ import annotations import os +import typing as T + from math import ceil from subprocess import CalledProcessError, check_output, STDOUT -from typing import cast, Generator, List, Optional, Tuple import imageio import imageio_ffmpeg as im_ffm @@ -11,6 +13,9 @@ import numpy as np from ._base import Output, logger +if T.TYPE_CHECKING: + from collections.abc import Generator + class Writer(Output): """ Video output writer using imageio-ffmpeg. 
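As a side note on the writer base class above: patching runs in multiple converter threads while writing runs in a single thread, so plugins that implement pre_encode can move the expensive encoding work out of the writer thread. A simplified, hypothetical sketch of that split (not the actual plugin code):

import cv2
import numpy as np

def pre_encode(image: np.ndarray) -> list[bytes]:
    # Runs in one of many converter threads: do the expensive encode here
    return [cv2.imencode(".png", image)[1].tobytes()]

def write(filename: str, encoded: list[bytes]) -> None:
    # Runs in the single writer thread: just dump the already-encoded bytes
    with open(filename, "wb") as outfile:
        outfile.write(encoded[0])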
@@ -32,7 +37,7 @@ class Writer(Output): def __init__(self, output_folder: str, total_count: int, - frame_ranges: Optional[List[Tuple[int, int]]], + frame_ranges: list[tuple[int, int]] | None, source_video: str, **kwargs) -> None: super().__init__(output_folder, **kwargs) @@ -40,11 +45,11 @@ class Writer(Output): total_count, frame_ranges, source_video) self._source_video: str = source_video self._output_filename: str = self._get_output_filename() - self._frame_ranges: Optional[List[Tuple[int, int]]] = frame_ranges - self.frame_order: List[int] = self._set_frame_order(total_count) - self._output_dimensions: Optional[str] = None # Fix dims on 1st received frame + self._frame_ranges: list[tuple[int, int]] | None = frame_ranges + self.frame_order: list[int] = self._set_frame_order(total_count) + self._output_dimensions: str | None = None # Fix dims on 1st received frame # Need to know dimensions of first frame, so set writer then - self._writer: Optional[Generator[None, np.ndarray, None]] = None + self._writer: Generator[None, np.ndarray, None] | None = None @property def _valid_tunes(self) -> dict: @@ -63,7 +68,7 @@ class Writer(Output): return retval @property - def _output_params(self) -> List[str]: + def _output_params(self) -> list[str]: """ list: The FFMPEG Output parameters """ codec = self.config["codec"] tune = self.config["tune"] @@ -86,11 +91,11 @@ class Writer(Output): return output_args @property - def _audio_codec(self) -> Optional[str]: + def _audio_codec(self) -> str | None: """ str or ``None``: The audio codec to use. This will either be ``"copy"`` (the default) or ``None`` if skip muxing has been selected in configuration options, or if frame ranges have been passed in the command line arguments. """ - retval: Optional[str] = "copy" + retval: str | None = "copy" if self.config["skip_mux"]: logger.info("Skipping audio muxing due to configuration settings.") retval = None @@ -169,7 +174,7 @@ class Writer(Output): logger.info("Outputting to: '%s'", retval) return retval - def _set_frame_order(self, total_count: int) -> List[int]: + def _set_frame_order(self, total_count: int) -> list[int]: """ Obtain the full list of frames to be converted in order. Parameters @@ -191,7 +196,7 @@ class Writer(Output): logger.debug("frame_order: %s", retval) return retval - def _get_writer(self, frame_dims: Tuple[int, int]) -> Generator[None, np.ndarray, None]: + def _get_writer(self, frame_dims: tuple[int, int]) -> Generator[None, np.ndarray, None]: """ Add the requested encoding options and return the writer. Parameters @@ -238,13 +243,13 @@ class Writer(Output): logger.trace("Received frame: (filename: '%s', shape: %s", # type:ignore[attr-defined] filename, image.shape) if not self._output_dimensions: - input_dims = cast(Tuple[int, int], image.shape[:2]) + input_dims = T.cast(tuple[int, int], image.shape[:2]) self._set_dimensions(input_dims) self._writer = self._get_writer(input_dims) self.cache_frame(filename, image) self._save_from_cache() - def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None: + def _set_dimensions(self, frame_dims: tuple[int, int]) -> None: """ Set the attribute :attr:`_output_dimensions` based on the first frame received. This protects against different sized images coming in and ensures all images are written to ffmpeg at the same size. Dimensions are mapped to a macro block size 8. 
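The _set_dimensions docstring above notes that output dimensions are mapped to a macro block size of 8 before the ffmpeg writer is created. A rough sketch of that kind of rounding (an approximation for illustration, not the plugin's actual implementation):

from math import ceil

def to_macro_block(dimension: int, block_size: int = 8) -> int:
    # Round a frame dimension up to the nearest multiple of the macro block size
    return int(ceil(dimension / block_size) * block_size)

print(to_macro_block(1919))  # 1920
print(to_macro_block(1080))  # 1080 (already a multiple of 8)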
diff --git a/plugins/convert/writer/gif.py b/plugins/convert/writer/gif.py index 3727ee7c..bfa81320 100644 --- a/plugins/convert/writer/gif.py +++ b/plugins/convert/writer/gif.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 """ Animated GIF writer for faceswap.py converter """ +from __future__ import annotations import os -from typing import Optional, List, Tuple, TYPE_CHECKING +import typing as T import cv2 import imageio from ._base import Output, logger -if TYPE_CHECKING: +if T.TYPE_CHECKING: from imageio.core import format as im_format # noqa:F401 @@ -31,15 +32,16 @@ class Writer(Output): def __init__(self, output_folder: str, total_count: int, - frame_ranges: Optional[List[Tuple[int, int]]], + frame_ranges: list[tuple[int, int]] | None, **kwargs) -> None: logger.debug("total_count: %s, frame_ranges: %s", total_count, frame_ranges) super().__init__(output_folder, **kwargs) - self.frame_order: List[int] = self._set_frame_order(total_count, frame_ranges) - self._output_dimensions: Optional[Tuple[int, int]] = None # Fix dims on 1st received frame + self.frame_order: list[int] = self._set_frame_order(total_count, frame_ranges) + # Fix dims on 1st received frame + self._output_dimensions: tuple[int, int] | None = None # Need to know dimensions of first frame, so set writer then - self._writer: Optional[imageio.plugins.pillowmulti.GIFFormat.Writer] = None - self._gif_file: Optional[str] = None # Set filename based on first file seen + self._writer: imageio.plugins.pillowmulti.GIFFormat.Writer | None = None + self._gif_file: str | None = None # Set filename based on first file seen @property def _gif_params(self) -> dict: @@ -50,7 +52,7 @@ class Writer(Output): @staticmethod def _set_frame_order(total_count: int, - frame_ranges: Optional[List[Tuple[int, int]]]) -> List[int]: + frame_ranges: list[tuple[int, int]] | None) -> list[int]: """ Obtain the full list of frames to be converted in order. Parameters @@ -75,7 +77,7 @@ class Writer(Output): logger.debug("frame_order: %s", retval) return retval - def _get_writer(self) -> "im_format.Format.Writer": + def _get_writer(self) -> im_format.Format.Writer: """ Obtain the GIF writer with the requested GIF encoding options. Returns @@ -145,7 +147,7 @@ class Writer(Output): self._gif_file = retval logger.info("Outputting to: '%s'", self._gif_file) - def _set_dimensions(self, frame_dims: Tuple[int, int]) -> None: + def _set_dimensions(self, frame_dims: tuple[int, int]) -> None: """ Set the attribute :attr:`_output_dimensions` based on the first frame received. This protects against different sized images coming in and ensure all images get written to the Gif at the sema dimensions. """ diff --git a/plugins/convert/writer/opencv.py b/plugins/convert/writer/opencv.py index 2f3b91ec..d18d71a6 100644 --- a/plugins/convert/writer/opencv.py +++ b/plugins/convert/writer/opencv.py @@ -2,8 +2,6 @@ """ Image output writer for faceswap.py converter Uses cv2 for writing as in testing this was a lot faster than both Pillow and ImageIO """ -from typing import List, Tuple - import cv2 import numpy as np @@ -37,7 +35,7 @@ class Writer(Output): "transparency. Changing output format to 'png'") self.config["format"] = "png" - def _get_save_args(self) -> Tuple[int, ...]: + def _get_save_args(self) -> tuple[int, ...]: """ Obtain the save parameters for the file format. Returns @@ -46,7 +44,7 @@ class Writer(Output): The OpenCV specific arguments for the selected file format """ filetype = self.config["format"] - args: Tuple[int, ...] = tuple() + args: tuple[int, ...] 
= tuple() if filetype == "jpg" and self.config["jpg_quality"] > 0: args = (cv2.IMWRITE_JPEG_QUALITY, self.config["jpg_quality"]) @@ -56,7 +54,7 @@ class Writer(Output): logger.debug(args) return args - def write(self, filename: str, image: List[bytes]) -> None: + def write(self, filename: str, image: list[bytes]) -> None: """ Write out the pre-encoded image to disk. If separate mask has been selected, write out the encoded mask to a sub-folder in the output directory. @@ -77,7 +75,7 @@ class Writer(Output): except Exception as err: # pylint: disable=broad-except logger.error("Failed to save image '%s'. Original Error: %s", filename, err) - def pre_encode(self, image: np.ndarray) -> List[bytes]: + def pre_encode(self, image: np.ndarray) -> list[bytes]: """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker. Parameters diff --git a/plugins/convert/writer/pillow.py b/plugins/convert/writer/pillow.py index 92eb4a08..a9ffb0aa 100644 --- a/plugins/convert/writer/pillow.py +++ b/plugins/convert/writer/pillow.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 """ Image output writer for faceswap.py converter """ - -from typing import Dict, List, Union from io import BytesIO from PIL import Image @@ -25,7 +23,7 @@ class Writer(Output): super().__init__(output_folder, **kwargs) self._check_transparency_format() # Correct format namings for writing to byte stream - self._format_dict = dict(jpg="JPEG", jp2="JPEG 2000", tif="TIFF") + self._format_dict = {"jpg": "JPEG", "jp2": "JPEG 2000", "tif": "TIFF"} self._separate_mask = self.config["draw_transparent"] and self.config["separate_mask"] self._kwargs = self._get_save_kwargs() @@ -38,7 +36,7 @@ class Writer(Output): "transparency. Changing output format to 'png'") self.config["format"] = "png" - def _get_save_kwargs(self) -> Dict[str, Union[bool, int, str]]: + def _get_save_kwargs(self) -> dict[str, bool | int | str]: """ Return the save parameters for the file format Returns @@ -59,7 +57,7 @@ class Writer(Output): logger.debug(kwargs) return kwargs - def write(self, filename: str, image: List[BytesIO]) -> None: + def write(self, filename: str, image: list[BytesIO]) -> None: """ Write out the pre-encoded image to disk. If separate mask has been selected, write out the encoded mask to a sub-folder in the output directory. @@ -80,7 +78,7 @@ class Writer(Output): except Exception as err: # pylint: disable=broad-except logger.error("Failed to save image '%s'. 
Original Error: %s", filename, err) - def pre_encode(self, image: np.ndarray) -> List[BytesIO]: + def pre_encode(self, image: np.ndarray) -> list[BytesIO]: """ Pre_encode the image in lib/convert.py threads as it is a LOT quicker Parameters diff --git a/plugins/extract/_base.py b/plugins/extract/_base.py index efb95b7e..4abe8a5e 100644 --- a/plugins/extract/_base.py +++ b/plugins/extract/_base.py @@ -2,12 +2,11 @@ """ Base class for Faceswap :mod:`~plugins.extract.detect`, :mod:`~plugins.extract.align` and :mod:`~plugins.extract.mask` Plugins """ +from __future__ import annotations import logging -import sys +import typing as T from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, Generator, List, Optional, - Sequence, Union, Tuple, TYPE_CHECKING) import numpy as np from tensorflow.python.framework import errors_impl as tf_errors # pylint:disable=no-name-in-module # noqa @@ -18,12 +17,8 @@ from lib.utils import GetModel, FaceswapError from ._config import Config from .pipeline import ExtractMedia -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Callable, Generator, Sequence from queue import Queue import cv2 from lib.align import DetectedFace @@ -37,7 +32,7 @@ logger = logging.getLogger(__name__) # TODO Run with warnings mode -def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str, Any]: +def _get_config(plugin_name: str, configfile: str | None = None) -> dict[str, T.Any]: """ Return the configuration for the requested model Parameters @@ -56,7 +51,7 @@ def _get_config(plugin_name: str, configfile: Optional[str] = None) -> Dict[str, return Config(plugin_name, configfile=configfile).config_dict -BatchType = Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"] +BatchType = T.Union["DetectorBatch", "AlignerBatch", "MaskerBatch", "RecogBatch"] @dataclass @@ -84,13 +79,12 @@ class ExtractorBatch: data: dict Any specific data required during the processing phase for a particular plugin """ - image: List[np.ndarray] = field(default_factory=list) - detected_faces: Sequence[Union["DetectedFace", - List["DetectedFace"]]] = field(default_factory=list) - filename: List[str] = field(default_factory=list) + image: list[np.ndarray] = field(default_factory=list) + detected_faces: Sequence[DetectedFace | list[DetectedFace]] = field(default_factory=list) + filename: list[str] = field(default_factory=list) feed: np.ndarray = np.array([]) prediction: np.ndarray = np.array([]) - data: List[Dict[str, Any]] = field(default_factory=list) + data: list[dict[str, T.Any]] = field(default_factory=list) class Extractor(): @@ -157,10 +151,10 @@ class Extractor(): """ def __init__(self, - git_model_id: Optional[int] = None, - model_filename: Optional[Union[str, List[str]]] = None, - exclude_gpus: Optional[List[int]] = None, - configfile: Optional[str] = None, + git_model_id: int | None = None, + model_filename: str | list[str] | None = None, + exclude_gpus: list[int] | None = None, + configfile: str | None = None, instance: int = 0) -> None: logger.debug("Initializing %s: (git_model_id: %s, model_filename: %s, exclude_gpus: %s, " "configfile: %s, instance: %s, )", self.__class__.__name__, git_model_id, @@ -176,9 +170,9 @@ class Extractor(): be a list of strings """ # << SET THE FOLLOWING IN PLUGINS __init__ IF DIFFERENT FROM DEFAULT >> # - self.name: Optional[str] = None + self.name: str | None = None self.input_size = 0 - 
self.color_format: Literal["BGR", "RGB", "GRAY"] = "BGR" + self.color_format: T.Literal["BGR", "RGB", "GRAY"] = "BGR" self.vram = 0 self.vram_warnings = 0 # Will run at this with warnings self.vram_per_batch = 0 @@ -187,7 +181,7 @@ class Extractor(): self.queue_size = 1 """ int: Queue size for all internal queues. Set in :func:`initialize()` """ - self.model: Optional[Union["KSession", "cv2.dnn.Net"]] = None + self.model: KSession | cv2.dnn.Net | None = None """varies: The model for this plugin. Set in the plugin's :func:`init_model()` method """ # For detectors that support batching, this should be set to the calculated batch size @@ -196,26 +190,26 @@ class Extractor(): """ int: Batchsize for feeding this model. The number of images the model should feed through at once. """ - self._queues: Dict[str, "Queue"] = {} + self._queues: dict[str, Queue] = {} """ dict: in + out queues and internal queues for this plugin, """ - self._threads: List[MultiThread] = [] + self._threads: list[MultiThread] = [] """ list: Internal threads for this plugin """ - self._extract_media: Dict[str, ExtractMedia] = {} + self._extract_media: dict[str, ExtractMedia] = {} """ dict: The :class:`plugins.extract.pipeline.ExtractMedia` objects currently being processed. Stored at input for pairing back up on output of extractor process """ # << THE FOLLOWING PROTECTED ATTRIBUTES ARE SET IN PLUGIN TYPE _base.py >>> # - self._plugin_type: Optional[Literal["align", "detect", "recognition", "mask"]] = None + self._plugin_type: T.Literal["align", "detect", "recognition", "mask"] | None = None """ str: Plugin type. ``detect`, ``align``, ``recognise`` or ``mask`` set in ``._base`` """ # << Objects for splitting frame's detected faces and rejoining them >> # << for post-detector pliugins >> - self._faces_per_filename: Dict[str, int] = {} # Tracking for recompiling batches - self._rollover: Optional[ExtractMedia] = None # batch rollover items - self._output_faces: List["DetectedFace"] = [] # Recompiled output faces from plugin + self._faces_per_filename: dict[str, int] = {} # Tracking for recompiling batches + self._rollover: ExtractMedia | None = None # batch rollover items + self._output_faces: list[DetectedFace] = [] # Recompiled output faces from plugin logger.debug("Initialized _base %s", self.__class__.__name__) @@ -361,7 +355,7 @@ class Extractor(): """ raise NotImplementedError - def get_batch(self, queue: "Queue") -> Tuple[bool, BatchType]: + def get_batch(self, queue: Queue) -> tuple[bool, BatchType]: """ **Override method** (at `` level) This method should be overridden at the `` level (IE. @@ -403,7 +397,7 @@ class Extractor(): for thread in self._threads: thread.check_and_raise_error() - def rollover_collector(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia]: + def rollover_collector(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia: """ For extractors after the Detectors, the number of detected faces per frame vs extractor batch size mean that faces will need to be split/re-joined with frames. The rollover collector can be used to rollover items that don't fit in a batch. 
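The rollover_collector described above exists because a frame can contain more (or fewer) detected faces than fit in a plugin batch, so leftover items have to be carried into the next batch. A toy illustration of that batching idea (not the plugin's implementation):

def fill_batch(incoming: list, rollover: list, batch_size: int) -> tuple[list, list]:
    # Combine carried-over items with the new frame's faces, take one batch,
    # and return the remainder to roll over into the next batch
    pool = rollover + incoming
    return pool[:batch_size], pool[batch_size:]

batch, rollover = fill_batch(incoming=["face_1", "face_2", "face_3", "face_4"],
                             rollover=[],
                             batch_size=3)
# batch == ["face_1", "face_2", "face_3"]; rollover == ["face_4"]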
@@ -425,7 +419,7 @@ class Extractor(): if self._rollover is not None: logger.trace("Getting from _rollover: (filename: `%s`, faces: %s)", # type:ignore self._rollover.filename, len(self._rollover.detected_faces)) - item: Union[Literal["EOF"], ExtractMedia] = self._rollover + item: T.Literal["EOF"] | ExtractMedia = self._rollover self._rollover = None else: next_item = self._get_item(queue) @@ -442,9 +436,8 @@ class Extractor(): # <<< INIT METHODS >>> # @classmethod def _get_model(cls, - git_model_id: Optional[int], - model_filename: Optional[Union[str, List[str]]] - ) -> Optional[Union[str, List[str]]]: + git_model_id: int | None, + model_filename: str | list[str] | None) -> str | list[str] | None: """ Check if model is available, if not, download and unzip it """ if model_filename is None: logger.debug("No model_filename specified. Returning None") @@ -496,9 +489,9 @@ class Extractor(): self.name, self._plugin_type.title(), self.batchsize) def _add_queues(self, - in_queue: "Queue", - out_queue: "Queue", - queues: List[str]) -> None: + in_queue: Queue, + out_queue: Queue, + queues: list[str]) -> None: """ Add the queues in_queue and out_queue should be previously created queue manager queues. queues should be a list of queue names """ @@ -533,8 +526,8 @@ class Extractor(): def _add_thread(self, name: str, function: Callable[[BatchType], BatchType], - in_queue: "Queue", - out_queue: "Queue") -> None: + in_queue: Queue, + out_queue: Queue) -> None: """ Add a MultiThread thread to self._threads """ logger.debug("Adding thread: (name: %s, function: %s, in_queue: %s, out_queue: %s)", name, function, in_queue, out_queue) @@ -546,8 +539,8 @@ class Extractor(): logger.debug("Added thread: %s", name) def _obtain_batch_item(self, function: Callable[[BatchType], BatchType], - in_queue: "Queue", - out_queue: "Queue") -> Optional[BatchType]: + in_queue: Queue, + out_queue: Queue) -> BatchType | None: """ Obtain the batch item from the in queue for the current process. 
Parameters @@ -564,7 +557,7 @@ class Extractor(): :class:`ExtractorBatch` or ``None`` The batch, if one exists, or ``None`` if queue is exhausted """ - batch: Union[Literal["EOF"], BatchType, ExtractMedia] + batch: T.Literal["EOF"] | BatchType | ExtractMedia if function.__name__ == "_process_input": # Process input items to batches exhausted, batch = self.get_batch(in_queue) if exhausted: @@ -585,8 +578,8 @@ class Extractor(): def _thread_process(self, function: Callable[[BatchType], BatchType], - in_queue: "Queue", - out_queue: "Queue") -> None: + in_queue: Queue, + out_queue: Queue) -> None: """ Perform a plugin function in a thread Parameters @@ -629,7 +622,7 @@ class Extractor(): out_queue.put("EOF") # <<< QUEUE METHODS >>> # - def _get_item(self, queue: "Queue") -> Union[Literal["EOF"], ExtractMedia, BatchType]: + def _get_item(self, queue: Queue) -> T.Literal["EOF"] | ExtractMedia | BatchType: """ Yield one item from a queue """ item = queue.get() if isinstance(item, ExtractMedia): diff --git a/plugins/extract/align/_base/aligner.py b/plugins/extract/align/_base/aligner.py index 9f15dbef..75dae9bb 100644 --- a/plugins/extract/align/_base/aligner.py +++ b/plugins/extract/align/_base/aligner.py @@ -12,12 +12,12 @@ For each source item, the plugin must pass a dict to finalize containing: >>> "landmarks": [list of 68 point face landmarks] >>> "detected_faces": []} """ +from __future__ import annotations import logging -import sys +import typing as T from dataclasses import dataclass, field from time import sleep -from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING import cv2 import numpy as np @@ -28,12 +28,8 @@ from lib.utils import FaceswapError from plugins.extract._base import BatchType, Extractor, ExtractMedia, ExtractorBatch from .processing import AlignedFilter, ReAlign -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from queue import Queue from lib.align import DetectedFace from lib.align.aligned_face import CenteringType @@ -77,9 +73,9 @@ class AlignerBatch(ExtractorBatch): The masks used to filter out re-feed values for passing to the re-aligner. 
""" batch_id: int = 0 - detected_faces: List["DetectedFace"] = field(default_factory=list) + detected_faces: list[DetectedFace] = field(default_factory=list) landmarks: np.ndarray = np.array([]) - refeeds: List[np.ndarray] = field(default_factory=list) + refeeds: list[np.ndarray] = field(default_factory=list) second_pass: bool = False second_pass_masks: np.ndarray = np.array([]) @@ -142,11 +138,11 @@ class Aligner(Extractor): # pylint:disable=abstract-method """ def __init__(self, - git_model_id: Optional[int] = None, - model_filename: Optional[str] = None, - configfile: Optional[str] = None, + git_model_id: int | None = None, + model_filename: str | None = None, + configfile: str | None = None, instance: int = 0, - normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None, + normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None, re_feed: int = 0, re_align: bool = False, disable_filter: bool = False, @@ -160,9 +156,9 @@ class Aligner(Extractor): # pylint:disable=abstract-method instance=instance, **kwargs) self._plugin_type = "align" - self.realign_centering: "CenteringType" = "face" # overide for plugin specific centering + self.realign_centering: CenteringType = "face" # overide for plugin specific centering self._eof_seen = False - self._normalize_method: Optional[Literal["clahe", "hist", "mean"]] = None + self._normalize_method: T.Literal["clahe", "hist", "mean"] | None = None self._re_feed = re_feed self._filter = AlignedFilter(feature_filter=self.config["aligner_features"], min_scale=self.config["aligner_min_scale"], @@ -181,8 +177,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method logger.debug("Initialized %s", self.__class__.__name__) - def set_normalize_method(self, - method: Optional[Literal["none", "clahe", "hist", "mean"]]) -> None: + def set_normalize_method(self, method: T.Literal["none", "clahe", "hist", "mean"] | None + ) -> None: """ Set the normalization method for feeding faces into the aligner. Parameters @@ -191,14 +187,14 @@ class Aligner(Extractor): # pylint:disable=abstract-method The normalization method to apply to faces prior to feeding into the model """ method = None if method is None or method.lower() == "none" else method - self._normalize_method = cast(Optional[Literal["clahe", "hist", "mean"]], method) + self._normalize_method = T.cast(T.Literal["clahe", "hist", "mean"] | None, method) def initialize(self, *args, **kwargs) -> None: """ Add a call to add model input size to the re-aligner """ self._re_align.set_input_size_and_centering(self.input_size, self.realign_centering) super().initialize(*args, **kwargs) - def _handle_realigns(self, queue: "Queue") -> Optional[Tuple[bool, AlignerBatch]]: + def _handle_realigns(self, queue: Queue) -> tuple[bool, AlignerBatch] | None: """ Handle any items waiting for a second pass through the aligner. 
If EOF has been received and items are still being processed through the first pass @@ -242,7 +238,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method return None - def get_batch(self, queue: "Queue") -> Tuple[bool, AlignerBatch]: + def get_batch(self, queue: Queue) -> tuple[bool, AlignerBatch]: """ Get items for inputting into the aligner from the queue in batches Items are returned from the ``queue`` in batches of @@ -548,7 +544,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method "\n3) Enable 'Single Process' mode.") raise FaceswapError(msg) from err - def _process_refeeds(self, batch: AlignerBatch) -> List[AlignerBatch]: + def _process_refeeds(self, batch: AlignerBatch) -> list[AlignerBatch]: """ Process the output for each selected re-feed Parameters @@ -562,7 +558,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method List of :class:`AlignerBatch` objects. Each object in the list contains the results for each selected re-feed """ - retval: List[AlignerBatch] = [] + retval: list[AlignerBatch] = [] if batch.second_pass: # Re-insert empty sub-patches for re-population in ReAlign for filtered out batches selected_idx = 0 @@ -605,8 +601,8 @@ class Aligner(Extractor): # pylint:disable=abstract-method return retval def _get_refeed_filter_masks(self, - subbatches: List[AlignerBatch], - original_masks: Optional[np.ndarray] = None) -> np.ndarray: + subbatches: list[AlignerBatch], + original_masks: np.ndarray | None = None) -> np.ndarray: """ Obtain the boolean mask array for masking out failed re-feed results if filter refeed has been selected @@ -663,7 +659,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method landmarks.shape) return np.ma.array(landmarks, mask=masks).mean(axis=0).data.astype("float32") - def _process_output_first_pass(self, subbatches: List[AlignerBatch]) -> Tuple[np.ndarray, + def _process_output_first_pass(self, subbatches: list[AlignerBatch]) -> tuple[np.ndarray, np.ndarray]: """ Process the output from the aligner if this is the first or only pass. @@ -696,7 +692,7 @@ class Aligner(Extractor): # pylint:disable=abstract-method return all_landmarks, masks def _process_output_second_pass(self, - subbatches: List[AlignerBatch], + subbatches: list[AlignerBatch], masks: np.ndarray) -> np.ndarray: """ Process the output from the aligner if this is the second pass.
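The aligner changes above follow the annotation pattern applied throughout this patch: `from __future__ import annotations` plus `import typing as T`, PEP 604 unions (`X | None`) and PEP 585 builtin generics (`list[...]`) in place of `typing.Optional`/`typing.List`, with heavyweight imports deferred behind `T.TYPE_CHECKING`. A minimal sketch of the pattern, not taken from the patch itself (the `DemoBatch` dataclass and `set_normalize_method` helper are illustrative stand-ins):

from __future__ import annotations

import typing as T
from dataclasses import dataclass, field

if T.TYPE_CHECKING:
    # Annotation-only import: safe because annotations stay as strings at runtime
    from lib.align import DetectedFace


@dataclass
class DemoBatch:
    """ Hypothetical batch container in the AlignerBatch style """
    detected_faces: list[DetectedFace] = field(default_factory=list)  # PEP 585 builtin generic
    second_pass: bool = False


def set_normalize_method(method: T.Literal["none", "clahe", "hist", "mean"] | None
                         ) -> T.Literal["clahe", "hist", "mean"] | None:
    """ Collapse 'none'/None to None, in the style of Aligner.set_normalize_method """
    method = None if method is None or method.lower() == "none" else method
    return T.cast(T.Literal["clahe", "hist", "mean"] | None, method)

The runtime `T.cast(... | None, ...)` call mirrors the one in the hunk above; it works because `typing` special forms support the `|` operator from Python 3.10, the minimum version this patch targets.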
diff --git a/plugins/extract/align/_base/processing.py b/plugins/extract/align/_base/processing.py index 8c13171c..efdeec94 100644 --- a/plugins/extract/align/_base/processing.py +++ b/plugins/extract/align/_base/processing.py @@ -1,21 +1,16 @@ #!/usr/bin/env python3 """ Processing methods for aligner plugins """ +from __future__ import annotations import logging -import sys +import typing as T from threading import Lock -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import numpy as np from lib.align import AlignedFace -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align import DetectedFace from .aligner import AlignerBatch from lib.align.aligned_face import CenteringType @@ -72,16 +67,16 @@ class AlignedFilter(): min_scale > 0.0 or distance > 0.0 or roll > 0.0) - self._counts: Dict[str, int] = dict(features=0, - min_scale=0, - max_scale=0, - distance=0, - roll=0) + self._counts: dict[str, int] = {"features": 0, + "min_scale": 0, + "max_scale": 0, + "distance": 0, + "roll": 0} logger.debug("Initialized %s: ", self.__class__.__name__) def _scale_test(self, face: AlignedFace, - minimum_dimension: int) -> Optional[Literal["min", "max"]]: + minimum_dimension: int) -> T.Literal["min", "max"] | None: """ Test if a face is below or above the min/max size thresholds. Returns as soon as a test fails. @@ -116,9 +111,9 @@ class AlignedFilter(): def _handle_filtered(self, key: str, - face: "DetectedFace", - faces: List["DetectedFace"], - sub_folders: List[Optional[str]], + face: DetectedFace, + faces: list[DetectedFace], + sub_folders: list[str | None], sub_folder_index: int) -> None: """ Add the filtered item to the filter counts. @@ -145,8 +140,8 @@ class AlignedFilter(): faces.append(face) sub_folders[sub_folder_index] = f"_align_filt_{key}" - def __call__(self, faces: List["DetectedFace"], minimum_dimension: int - ) -> Tuple[List["DetectedFace"], List[Optional[str]]]: + def __call__(self, faces: list[DetectedFace], minimum_dimension: int + ) -> tuple[list[DetectedFace], list[str | None]]: """ Apply the filter to the incoming batch Parameters @@ -165,11 +160,11 @@ class AlignedFilter(): List of ``Nones`` if saving filtered faces has not been selected or list of ``Nones`` and sub folder names corresponding the filtered face location """ - sub_folders: List[Optional[str]] = [None for _ in range(len(faces))] + sub_folders: list[str | None] = [None for _ in range(len(faces))] if not self._active: return faces, sub_folders - retval: List["DetectedFace"] = [] + retval: list[DetectedFace] = [] for idx, face in enumerate(faces): aligned = AlignedFace(landmarks=face.landmarks_xy, centering="face") @@ -194,8 +189,8 @@ class AlignedFilter(): return retval, sub_folders def filtered_mask(self, - batch: "AlignerBatch", - skip: Optional[Union[np.ndarray, List[int]]] = None) -> np.ndarray: + batch: AlignerBatch, + skip: np.ndarray | list[int] | None = None) -> np.ndarray: """ Obtain a list of boolean values for the given batch indicating whether they pass the filter test. 
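The AlignedFilter hunks above keep its output contract: `__call__` returns the (possibly reduced) face list together with a parallel `list[str | None]` of sub-folder names, where `None` marks a face that passed and a `"_align_filt_<key>"` label marks one that was filtered but retained for saving. A rough sketch of that contract only (the `size_filter` function and its threshold are invented for illustration, not the plugin's code):

from __future__ import annotations


def size_filter(sizes: list[int], min_size: int, save_filtered: bool
                ) -> tuple[list[int], list[str | None]]:
    """ Return kept items plus a parallel sub-folder list, AlignedFilter style """
    kept: list[int] = []
    sub_folders: list[str | None] = [None] * len(sizes)
    for idx, size in enumerate(sizes):
        if size >= min_size:
            kept.append(size)
        elif save_filtered:  # filtered items are kept, but routed to a sub-folder
            kept.append(size)
            sub_folders[idx] = "_align_filt_min_scale"
    return kept, sub_folders


print(size_filter([10, 120, 48], min_size=64, save_filtered=True))
# -> ([10, 120, 48], ['_align_filt_min_scale', None, '_align_filt_min_scale'])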
@@ -262,13 +257,14 @@ class ReAlign(): self._active = active self._do_refeeds = do_refeeds self._do_filter = do_filter - self._centering: "CenteringType" = "face" + self._centering: CenteringType = "face" self._size = 0 self._tracked_lock = Lock() - self._tracked_batchs: Dict[int, Dict[Literal["filtered_landmarks"], List[np.ndarray]]] = {} + self._tracked_batchs: dict[int, + dict[T.Literal["filtered_landmarks"], list[np.ndarray]]] = {} # TODO. Probably does not need to be a list, just alignerbatch self._queue_lock = Lock() - self._queued: List["AlignerBatch"] = [] + self._queued: list[AlignerBatch] = [] logger.debug("Initialized %s", self.__class__.__name__) @property @@ -301,7 +297,7 @@ class ReAlign(): with self._tracked_lock: return bool(self._tracked_batchs) - def set_input_size_and_centering(self, input_size: int, centering: "CenteringType") -> None: + def set_input_size_and_centering(self, input_size: int, centering: CenteringType) -> None: """ Set the input size of the loaded plugin once the model has been loaded Parameters @@ -344,7 +340,7 @@ class ReAlign(): with self._tracked_lock: del self._tracked_batchs[batch_id] - def add_batch(self, batch: "AlignerBatch") -> None: + def add_batch(self, batch: AlignerBatch) -> None: """ Add first pass alignments to the queue for picking up for re-alignment, update their :attr:`second_pass` attribute to ``True`` and clear attributes not required. @@ -362,7 +358,7 @@ class ReAlign(): batch.data = [] self._queued.append(batch) - def get_batch(self) -> "AlignerBatch": + def get_batch(self) -> AlignerBatch: """ Retrieve the next batch currently queued for re-alignment Returns @@ -376,7 +372,7 @@ class ReAlign(): retval.filename) return retval - def process_batch(self, batch: "AlignerBatch") -> List[np.ndarray]: + def process_batch(self, batch: AlignerBatch) -> list[np.ndarray]: """ Pre process a batch object for re-aligning through the aligner. 
Parameters @@ -391,8 +387,8 @@ class ReAlign(): """ logger.trace("Processing batch: %s, landmarks: %s", # type: ignore[attr-defined] batch.filename, [b.shape for b in batch.landmarks]) - retval: List[np.ndarray] = [] - filtered_landmarks: List[np.ndarray] = [] + retval: list[np.ndarray] = [] + filtered_landmarks: list[np.ndarray] = [] for landmarks, masks in zip(batch.landmarks, batch.second_pass_masks): if not np.all(masks): # At least one face has not already been filtered aligned_faces = [AlignedFace(lms, @@ -415,7 +411,7 @@ class ReAlign(): batch.landmarks = np.array([]) # Clear the old landmarks return retval - def _transform_to_frame(self, batch: "AlignerBatch") -> np.ndarray: + def _transform_to_frame(self, batch: AlignerBatch) -> np.ndarray: """ Transform the predicted landmarks from the aligned face image back into frame co-ordinates @@ -430,14 +426,14 @@ class ReAlign(): :class:`numpy.ndarray` The landmarks transformed to frame space """ - faces: List[AlignedFace] = batch.data[0]["aligned_faces"] + faces: list[AlignedFace] = batch.data[0]["aligned_faces"] retval = np.array([aligned.transform_points(landmarks, invert=True) for landmarks, aligned in zip(batch.landmarks, faces)]) logger.trace("Transformed points: original max: %s, " # type: ignore[attr-defined] "new max: %s", batch.landmarks.max(), retval.max()) return retval - def _re_insert_filtered(self, batch: "AlignerBatch", masks: np.ndarray) -> np.ndarray: + def _re_insert_filtered(self, batch: AlignerBatch, masks: np.ndarray) -> np.ndarray: """ Re-insert landmarks that were filtered out from the re-align process back into the landmark results @@ -473,7 +469,7 @@ class ReAlign(): return retval - def process_output(self, subbatches: List["AlignerBatch"], batch_masks: np.ndarray) -> None: + def process_output(self, subbatches: list[AlignerBatch], batch_masks: np.ndarray) -> None: """ Process the output from the re-align pass. - Transform landmarks from aligned face space to face space diff --git a/plugins/extract/align/cv2_dnn.py b/plugins/extract/align/cv2_dnn.py index 9b883c2f..44c41fb6 100644 --- a/plugins/extract/align/cv2_dnn.py +++ b/plugins/extract/align/cv2_dnn.py @@ -23,15 +23,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations import logging -from typing import cast, List, Tuple, TYPE_CHECKING +import typing as T import cv2 import numpy as np from ._base import Aligner, AlignerBatch, BatchType -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.detected_face import DetectedFace logger = logging.getLogger(__name__) @@ -89,9 +90,9 @@ class Align(Aligner): assert isinstance(batch, AlignerBatch) lfaces, roi, offsets = self.align_image(batch) batch.feed = np.array(lfaces)[..., :3] - batch.data.append(dict(roi=roi, offsets=offsets)) + batch.data.append({"roi": roi, "offsets": offsets}) - def _get_box_and_offset(self, face: "DetectedFace") -> Tuple[List[int], int]: + def _get_box_and_offset(self, face: DetectedFace) -> tuple[list[int], int]: """Obtain the bounding box and offset from a detected face. 
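Referring back to the ReAlign changes above: `_re_insert_filtered` is documented as slotting landmarks that were excluded from the second pass back into the re-aligned results. Its body is not shown in this patch, so the following is only a sketch of the idea with made-up shapes, not the plugin's implementation:

from __future__ import annotations

import numpy as np


def re_insert_filtered(new_landmarks: np.ndarray,
                       original_landmarks: np.ndarray,
                       filtered: np.ndarray) -> np.ndarray:
    """ Keep original values where ``filtered`` is True, take re-aligned values elsewhere """
    retval = original_landmarks.copy()
    retval[~filtered] = new_landmarks
    return retval


original = np.arange(8, dtype="float32").reshape(4, 2)   # first-pass results for 4 faces
filtered = np.array([False, True, False, True])          # faces 1 and 3 skipped the second pass
realigned = np.full((2, 2), -1, dtype="float32")         # second-pass results for faces 0 and 2
print(re_insert_filtered(realigned, original, filtered))
# rows 0 and 2 come from the second pass; rows 1 and 3 keep their first-pass values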
@@ -108,17 +109,17 @@ class Align(Aligner): The offset of the box (difference between half width vs height) """ - box = cast(List[int], [face.left, - face.top, - face.right, - face.bottom]) - diff_height_width = cast(int, face.height) - cast(int, face.width) + box = T.cast(list[int], [face.left, + face.top, + face.right, + face.bottom]) + diff_height_width = T.cast(int, face.height) - T.cast(int, face.width) offset = int(abs(diff_height_width / 2)) return box, offset - def align_image(self, batch: AlignerBatch) -> Tuple[List[np.ndarray], - List[List[int]], - List[Tuple[int, int]]]: + def align_image(self, batch: AlignerBatch) -> tuple[list[np.ndarray], + list[list[int]], + list[tuple[int, int]]]: """ Align the incoming image for prediction Parameters @@ -159,8 +160,8 @@ class Align(Aligner): @classmethod def move_box(cls, - box: List[int], - offset: Tuple[int, int]) -> List[int]: + box: list[int], + offset: tuple[int, int]) -> list[int]: """Move the box to direction specified by vector offset Parameters @@ -182,7 +183,7 @@ class Align(Aligner): return [left, top, right, bottom] @staticmethod - def get_square_box(box: List[int]) -> List[int]: + def get_square_box(box: list[int]) -> list[int]: """Get a square box out of the given box, by expanding it. Parameters @@ -226,7 +227,7 @@ class Align(Aligner): return [left, top, right, bottom] @classmethod - def pad_image(cls, box: List[int], image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]: + def pad_image(cls, box: list[int], image: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]: """Pad image if face-box falls outside of boundaries Parameters diff --git a/plugins/extract/align/fan.py b/plugins/extract/align/fan.py index 5a12610a..a829f3bc 100644 --- a/plugins/extract/align/fan.py +++ b/plugins/extract/align/fan.py @@ -3,8 +3,9 @@ Code adapted and modified from: https://github.com/1adrianb/face-alignment """ +from __future__ import annotations import logging -from typing import cast, List, TYPE_CHECKING +import typing as T import cv2 import numpy as np @@ -12,7 +13,7 @@ import numpy as np from lib.model.session import KSession from ._base import Aligner, AlignerBatch, BatchType -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align import DetectedFace logger = logging.getLogger(__name__) @@ -76,10 +77,10 @@ class Align(Aligner): logger.trace("Aligning faces around center") # type:ignore[attr-defined] center_scale = self.get_center_scale(batch.detected_faces) batch.feed = np.array(self.crop(batch, center_scale))[..., :3] - batch.data.append(dict(center_scale=center_scale)) + batch.data.append({"center_scale": center_scale}) logger.trace("Aligned image around center") # type:ignore[attr-defined] - def get_center_scale(self, detected_faces: List["DetectedFace"]) -> np.ndarray: + def get_center_scale(self, detected_faces: list[DetectedFace]) -> np.ndarray: """ Get the center and set scale of bounding box Parameters @@ -95,11 +96,11 @@ class Align(Aligner): logger.trace("Calculating center and scale") # type:ignore[attr-defined] center_scale = np.empty((len(detected_faces), 68, 3), dtype='float32') for index, face in enumerate(detected_faces): - x_center = (cast(int, face.left) + face.right) / 2.0 - y_center = (cast(int, face.top) + face.bottom) / 2.0 - cast(int, face.height) * 0.12 - scale = (cast(int, face.width) + cast(int, face.height)) * self.reference_scale - center_scale[index, :, 0] = np.full(68, x_center, dtype='float32') - center_scale[index, :, 1] = np.full(68, y_center, dtype='float32') + x_ctr = (T.cast(int, face.left) + 
face.right) / 2.0 + y_ctr = (T.cast(int, face.top) + face.bottom) / 2.0 - T.cast(int, face.height) * 0.12 + scale = (T.cast(int, face.width) + T.cast(int, face.height)) * self.reference_scale + center_scale[index, :, 0] = np.full(68, x_ctr, dtype='float32') + center_scale[index, :, 1] = np.full(68, y_ctr, dtype='float32') center_scale[index, :, 2] = np.full(68, scale, dtype='float32') logger.trace("Calculated center and scale: %s", center_scale) # type:ignore[attr-defined] return center_scale @@ -144,7 +145,7 @@ class Align(Aligner): dsize=(self.input_size, self.input_size), interpolation=interp) - def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> List[np.ndarray]: + def crop(self, batch: AlignerBatch, center_scale: np.ndarray) -> list[np.ndarray]: """ Crop image around the center point Parameters diff --git a/plugins/extract/detect/_base.py b/plugins/extract/detect/_base.py index 10488ed5..c85f7577 100644 --- a/plugins/extract/detect/_base.py +++ b/plugins/extract/detect/_base.py @@ -15,9 +15,11 @@ To get a :class:`~lib.align.DetectedFace` object use the function: >>> face = self._to_detected_face(, , , ) """ +from __future__ import annotations import logging +import typing as T + from dataclasses import dataclass, field -from typing import cast, Generator, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 import numpy as np @@ -30,7 +32,8 @@ from lib.utils import FaceswapError from plugins.extract._base import BatchType, Extractor, ExtractorBatch from plugins.extract.pipeline import ExtractMedia -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from queue import Queue logger = logging.getLogger(__name__) @@ -53,10 +56,10 @@ class DetectorBatch(ExtractorBatch): initial_feed: :class:`numpy.ndarray` Used to hold the initial :attr:`feed` when rotate images is enabled """ - detected_faces: List[List["DetectedFace"]] = field(default_factory=list) - rotation_matrix: List[np.ndarray] = field(default_factory=list) - scale: List[float] = field(default_factory=list) - pad: List[Tuple[int, int]] = field(default_factory=list) + detected_faces: list[list["DetectedFace"]] = field(default_factory=list) + rotation_matrix: list[np.ndarray] = field(default_factory=list) + scale: list[float] = field(default_factory=list) + pad: list[tuple[int, int]] = field(default_factory=list) initial_feed: np.ndarray = np.array([]) @@ -95,11 +98,11 @@ class Detector(Extractor): # pylint:disable=abstract-method """ def __init__(self, - git_model_id: Optional[int] = None, - model_filename: Optional[Union[str, List[str]]] = None, - configfile: Optional[str] = None, + git_model_id: int | None = None, + model_filename: str | list[str] | None = None, + configfile: str | None = None, instance: int = 0, - rotation: Optional[str] = None, + rotation: str | None = None, min_size: int = 0, **kwargs) -> None: logger.debug("Initializing %s: (rotation: %s, min_size: %s)", self.__class__.__name__, @@ -117,7 +120,7 @@ class Detector(Extractor): # pylint:disable=abstract-method logger.debug("Initialized _base %s", self.__class__.__name__) # <<< QUEUE METHODS >>> # - def get_batch(self, queue: "Queue") -> Tuple[bool, DetectorBatch]: + def get_batch(self, queue: Queue) -> tuple[bool, DetectorBatch]: """ Get items for inputting to the detector plugin in batches Items are received as :class:`~plugins.extract.pipeline.ExtractMedia` objects and converted @@ -271,7 +274,7 @@ class Detector(Extractor): # pylint:disable=abstract-method """ Wrap models predict function in rotations """ assert 
isinstance(batch, DetectorBatch) batch.rotation_matrix = [np.array([]) for _ in range(len(batch.feed))] - found_faces: List[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))] + found_faces: list[np.ndarray] = [np.array([]) for _ in range(len(batch.feed))] for angle in self.rotation: # Rotate the batch and insert placeholders for already found faces self._rotate_batch(batch, angle) @@ -301,9 +304,9 @@ class Detector(Extractor): # pylint:disable=abstract-method "degrees", angle) - found_faces = cast(List[np.ndarray], ([face if not found.any() else found - for face, found in zip(batch.prediction, - found_faces)])) + found_faces = T.cast(list[np.ndarray], ([face if not found.any() else found + for face, found in zip(batch.prediction, + found_faces)])) if all(face.any() for face in found_faces): logger.trace("Faces found for all images") # type:ignore[attr-defined] @@ -317,7 +320,7 @@ class Detector(Extractor): # pylint:disable=abstract-method # <<< DETECTION IMAGE COMPILATION METHODS >>> # def _compile_detection_image(self, item: ExtractMedia - ) -> Tuple[np.ndarray, float, Tuple[int, int]]: + ) -> tuple[np.ndarray, float, tuple[int, int]]: """ Compile the detection image for feeding into the model Parameters @@ -345,7 +348,7 @@ class Detector(Extractor): # pylint:disable=abstract-method image.shape, scale, pad) return image, scale, pad - def _set_scale(self, image_size: Tuple[int, int]) -> float: + def _set_scale(self, image_size: tuple[int, int]) -> float: """ Set the scale factor for incoming image Parameters @@ -362,7 +365,7 @@ class Detector(Extractor): # pylint:disable=abstract-method logger.trace("Detector scale: %s", scale) # type:ignore[attr-defined] return scale - def _set_padding(self, image_size: Tuple[int, int], scale: float) -> Tuple[int, int]: + def _set_padding(self, image_size: tuple[int, int], scale: float) -> tuple[int, int]: """ Set the image padding for non-square images Parameters @@ -382,7 +385,7 @@ class Detector(Extractor): # pylint:disable=abstract-method return pad_left, pad_top @staticmethod - def _scale_image(image: np.ndarray, image_size: Tuple[int, int], scale: float) -> np.ndarray: + def _scale_image(image: np.ndarray, image_size: tuple[int, int], scale: float) -> np.ndarray: """ Scale the image and optional pad to given size Parameters @@ -439,8 +442,8 @@ class Detector(Extractor): # pylint:disable=abstract-method return image # <<< FINALIZE METHODS >>> # - def _remove_zero_sized_faces(self, batch_faces: List[List[DetectedFace]] - ) -> List[List[DetectedFace]]: + def _remove_zero_sized_faces(self, batch_faces: list[list[DetectedFace]] + ) -> list[list[DetectedFace]]: """ Remove items from batch_faces where detected face is of zero size or face falls entirely outside of image @@ -463,8 +466,8 @@ class Detector(Extractor): # pylint:disable=abstract-method logger.trace("Output sizes: %s", [len(face) for face in retval]) # type: ignore return retval - def _filter_small_faces(self, detected_faces: List[List[DetectedFace]] - ) -> List[List[DetectedFace]]: + def _filter_small_faces(self, detected_faces: list[list[DetectedFace]] + ) -> list[list[DetectedFace]]: """ Filter out any faces smaller than the min size threshold Parameters @@ -493,7 +496,7 @@ class Detector(Extractor): # pylint:disable=abstract-method # <<< IMAGE ROTATION METHODS >>> # @staticmethod - def _get_rotation_angles(rotation: Optional[str]) -> List[int]: + def _get_rotation_angles(rotation: str | None) -> list[int]: """ Set the rotation angles. 
Parameters @@ -544,8 +547,8 @@ class Detector(Extractor): # pylint:disable=abstract-method batch.initial_feed = batch.feed.copy() return - feeds: List[np.ndarray] = [] - rotmats: List[np.ndarray] = [] + feeds: list[np.ndarray] = [] + rotmats: list[np.ndarray] = [] for img, faces, rotmat in zip(batch.initial_feed, batch.prediction, batch.rotation_matrix): @@ -605,7 +608,7 @@ class Detector(Extractor): # pylint:disable=abstract-method def _rotate_image_by_angle(self, image: np.ndarray, - angle: int) -> Tuple[np.ndarray, np.ndarray]: + angle: int) -> tuple[np.ndarray, np.ndarray]: """ Rotate an image by a given angle. Parameters diff --git a/plugins/extract/detect/mtcnn.py b/plugins/extract/detect/mtcnn.py index 4c14d122..8af8a41b 100644 --- a/plugins/extract/detect/mtcnn.py +++ b/plugins/extract/detect/mtcnn.py @@ -34,7 +34,7 @@ class Detect(Detector): self.kwargs = self._validate_kwargs() self.color_format = "RGB" - def _validate_kwargs(self) -> T.Dict[str, T.Union[int, float, T.List[float]]]: + def _validate_kwargs(self) -> dict[str, int | float | list[float]]: """ Validate that config options are correct. If not reset to default """ valid = True threshold = [self.config["threshold_1"], @@ -164,7 +164,7 @@ class PNet(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, cpu_mode: bool, input_size: int, min_size: int, @@ -185,10 +185,10 @@ class PNet(KSession): self._pnet_scales = self._calculate_scales(min_size, factor) self._pnet_sizes = [(int(input_size * scale), int(input_size * scale)) for scale in self._pnet_scales] - self._pnet_input: T.Optional[T.List[np.ndarray]] = None + self._pnet_input: list[np.ndarray] | None = None @staticmethod - def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]: + def model_definition() -> tuple[list[Tensor], list[Tensor]]: """ Keras P-Network Definition for MTCNN """ input_ = Input(shape=(None, None, 3)) var_x = Conv2D(10, (3, 3), strides=1, padding='valid', name='conv1')(input_) @@ -204,7 +204,7 @@ class PNet(KSession): def _calculate_scales(self, minsize: int, - factor: float) -> T.List[float]: + factor: float) -> list[float]: """ Calculate multi-scale Parameters @@ -231,7 +231,7 @@ class PNet(KSession): logger.trace(scales) # type:ignore return scales - def __call__(self, images: np.ndarray) -> T.List[np.ndarray]: + def __call__(self, images: np.ndarray) -> list[np.ndarray]: """ first stage - fast proposal network (p-net) to obtain face candidates Parameters @@ -245,8 +245,8 @@ class PNet(KSession): List of face candidates from P-Net """ batch_size = images.shape[0] - rectangles: T.List[T.List[T.List[T.Union[int, float]]]] = [[] for _ in range(batch_size)] - scores: T.List[T.List[np.ndarray]] = [[] for _ in range(batch_size)] + rectangles: list[list[list[int | float]]] = [[] for _ in range(batch_size)] + scores: list[list[np.ndarray]] = [[] for _ in range(batch_size)] if self._pnet_input is None: self._pnet_input = [np.empty((batch_size, rheight, rwidth, 3), dtype="float32") @@ -278,7 +278,7 @@ class PNet(KSession): class_probabilities: np.ndarray, roi: np.ndarray, size: int, - scale: float) -> T.Tuple[np.ndarray, np.ndarray]: + scale: float) -> tuple[np.ndarray, np.ndarray]: """ Detect face position and calibrate bounding box on 12net feature map(matrix version) Parameters @@ -344,7 +344,7 @@ class RNet(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, 
cpu_mode: bool, input_size: int, threshold: float) -> None: @@ -360,7 +360,7 @@ class RNet(KSession): self._threshold = threshold @staticmethod - def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]: + def model_definition() -> tuple[list[Tensor], list[Tensor]]: """ Keras R-Network Definition for MTCNN """ input_ = Input(shape=(24, 24, 3)) var_x = Conv2D(28, (3, 3), strides=1, padding='valid', name='conv1')(input_) @@ -383,8 +383,8 @@ class RNet(KSession): def __call__(self, images: np.ndarray, - rectangle_batch: T.List[np.ndarray], - ) -> T.List[np.ndarray]: + rectangle_batch: list[np.ndarray], + ) -> list[np.ndarray]: """ second stage - refinement of face candidates with r-net Parameters @@ -399,7 +399,7 @@ class RNet(KSession): List List of :class:`numpy.ndarray` refined face candidates from R-Net """ - ret: T.List[np.ndarray] = [] + ret: list[np.ndarray] = [] for idx, (rectangles, image) in enumerate(zip(rectangle_batch, images)): if not np.any(rectangles): ret.append(np.array([])) @@ -474,7 +474,7 @@ class ONet(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, cpu_mode: bool, input_size: int, threshold: float) -> None: @@ -490,7 +490,7 @@ class ONet(KSession): self._threshold = threshold @staticmethod - def model_definition() -> T.Tuple[T.List[Tensor], T.List[Tensor]]: + def model_definition() -> tuple[list[Tensor], list[Tensor]]: """ Keras O-Network for MTCNN """ input_ = Input(shape=(48, 48, 3)) var_x = Conv2D(32, (3, 3), strides=1, padding='valid', name='conv1')(input_) @@ -516,8 +516,8 @@ class ONet(KSession): def __call__(self, images: np.ndarray, - rectangle_batch: T.List[np.ndarray] - ) -> T.List[T.Tuple[np.ndarray, np.ndarray]]: + rectangle_batch: list[np.ndarray] + ) -> list[tuple[np.ndarray, np.ndarray]]: """ Third stage - further refinement and facial landmarks positions with o-net Parameters @@ -532,7 +532,7 @@ class ONet(KSession): List List of refined final candidates, scores and landmark points from O-Net """ - ret: T.List[T.Tuple[np.ndarray, np.ndarray]] = [] + ret: list[tuple[np.ndarray, np.ndarray]] = [] for idx, rectangles in enumerate(rectangle_batch): if not np.any(rectangles): ret.append((np.empty((0, 5)), np.empty(0))) @@ -552,7 +552,7 @@ class ONet(KSession): def _filter_face_48net(self, class_probabilities: np.ndarray, roi: np.ndarray, points: np.ndarray, - rectangles: np.ndarray) -> T.Tuple[np.ndarray, np.ndarray]: + rectangles: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Filter face position and calibrate bounding box on 12net's output Parameters @@ -623,13 +623,13 @@ class MTCNN(): # pylint: disable=too-few-public-methods Default: `0.709` """ def __init__(self, - model_path: T.List[str], + model_path: list[str], allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, cpu_mode: bool, input_size: int = 640, minsize: int = 20, - threshold: T.Optional[T.List[float]] = None, + threshold: list[float] | None = None, factor: float = 0.709) -> None: logger.debug("Initializing: %s: (model_path: '%s', allow_growth: %s, exclude_gpus: %s, " "input_size: %s, minsize: %s, threshold: %s, factor: %s)", @@ -660,7 +660,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods logger.debug("Initialized: %s", self.__class__.__name__) - def detect_faces(self, batch: np.ndarray) -> T.Tuple[np.ndarray, T.Tuple[np.ndarray]]: + def detect_faces(self, batch: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray]]: """Detects faces in an 
image, and returns bounding boxes and points for them. Parameters @@ -684,7 +684,7 @@ class MTCNN(): # pylint: disable=too-few-public-methods def nms(rectangles: np.ndarray, scores: np.ndarray, threshold: float, - method: str = "iom") -> T.Tuple[np.ndarray, np.ndarray]: + method: str = "iom") -> tuple[np.ndarray, np.ndarray]: """ apply non-maximum suppression on ROIs in same scale(matrix version) Parameters diff --git a/plugins/extract/detect/s3fd.py b/plugins/extract/detect/s3fd.py index 853eeed5..89d538b7 100644 --- a/plugins/extract/detect/s3fd.py +++ b/plugins/extract/detect/s3fd.py @@ -125,10 +125,10 @@ class L2Norm(keras.layers.Layer): class SliceO2K(keras.layers.Layer): """ Custom Keras Slice layer generated by onnx2keras. """ def __init__(self, - starts: T.List[int], - ends: T.List[int], - axes: T.Optional[T.List[int]] = None, - steps: T.Optional[T.List[int]] = None, + starts: list[int], + ends: list[int], + axes: list[int] | None = None, + steps: list[int] | None = None, **kwargs) -> None: self._starts = starts self._ends = ends @@ -136,7 +136,7 @@ class SliceO2K(keras.layers.Layer): self._steps = steps super().__init__(**kwargs) - def _get_slices(self, dimensions: int) -> T.List[T.Tuple[int, ...]]: + def _get_slices(self, dimensions: int) -> list[tuple[int, ...]]: """ Obtain slices for the given number of dimensions. Parameters @@ -154,7 +154,7 @@ class SliceO2K(keras.layers.Layer): assert len(axes) == len(steps) == len(self._starts) == len(self._ends) return list(zip(axes, self._starts, self._ends, steps)) - def compute_output_shape(self, input_shape: T.Tuple[int, ...]) -> T.Tuple[int, ...]: + def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: """Computes the output shape of the layer. Assumes that the layer will be built to match that input shape provided. @@ -230,7 +230,7 @@ class S3fd(KSession): model_path: str, model_kwargs: dict, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, confidence: float) -> None: logger.debug("Initializing: %s: (model_path: '%s', model_kwargs: %s, allow_growth: %s, " "exclude_gpus: %s, confidence: %s)", self.__class__.__name__, model_path, @@ -246,7 +246,7 @@ class S3fd(KSession): self.average_img = np.array([104.0, 117.0, 123.0]) logger.debug("Initialized: %s", self.__class__.__name__) - def model_definition(self) -> T.Tuple[T.List[Tensor], T.List[Tensor]]: + def model_definition(self) -> tuple[list[Tensor], list[Tensor]]: """ Keras S3FD Model Definition, adapted from FAN pytorch implementation. """ input_ = Input(shape=(640, 640, 3)) var_x = self.conv_block(input_, 64, 1, 2) @@ -396,7 +396,7 @@ class S3fd(KSession): batch = batch - self.average_img return batch - def finalize_predictions(self, bounding_boxes_scales: T.List[np.ndarray]) -> np.ndarray: + def finalize_predictions(self, bounding_boxes_scales: list[np.ndarray]) -> np.ndarray: """ Process the output from the model to obtain faces Parameters @@ -413,7 +413,7 @@ class S3fd(KSession): ret.append(finallist) return np.array(ret, dtype="object") - def _post_process(self, bboxlist: T.List[np.ndarray]) -> np.ndarray: + def _post_process(self, bboxlist: list[np.ndarray]) -> np.ndarray: """ Perform post processing on output TODO: do this on the batch. 
""" diff --git a/plugins/extract/mask/_base.py b/plugins/extract/mask/_base.py index 34cd0c06..837b6812 100644 --- a/plugins/extract/mask/_base.py +++ b/plugins/extract/mask/_base.py @@ -12,9 +12,11 @@ For each source item, the plugin must pass a dict to finalize containing: >>> {"filename": , >>> "detected_faces": } """ +from __future__ import annotations import logging +import typing as T + from dataclasses import dataclass, field -from typing import Generator, List, Optional, Tuple, TYPE_CHECKING import cv2 import numpy as np @@ -25,7 +27,8 @@ from lib.align import AlignedFace, transform_image from lib.utils import FaceswapError from plugins.extract._base import BatchType, Extractor, ExtractorBatch, ExtractMedia -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from queue import Queue from lib.align import DetectedFace from lib.align.aligned_face import CenteringType @@ -44,9 +47,9 @@ class MaskerBatch(ExtractorBatch): roi_masks: list The region of interest masks for the batch """ - detected_faces: List["DetectedFace"] = field(default_factory=list) - roi_masks: List[np.ndarray] = field(default_factory=list) - feed_faces: List[AlignedFace] = field(default_factory=list) + detected_faces: list[DetectedFace] = field(default_factory=list) + roi_masks: list[np.ndarray] = field(default_factory=list) + feed_faces: list[AlignedFace] = field(default_factory=list) class Masker(Extractor): # pylint:disable=abstract-method @@ -77,9 +80,9 @@ class Masker(Extractor): # pylint:disable=abstract-method """ def __init__(self, - git_model_id: Optional[int] = None, - model_filename: Optional[str] = None, - configfile: Optional[str] = None, + git_model_id: int | None = None, + model_filename: str | None = None, + configfile: str | None = None, instance: int = 0, **kwargs) -> None: logger.debug("Initializing %s: (configfile: %s)", self.__class__.__name__, configfile) @@ -93,11 +96,11 @@ class Masker(Extractor): # pylint:disable=abstract-method self._plugin_type = "mask" self._storage_name = self.__module__.rsplit(".", maxsplit=1)[-1].replace("_", "-") - self._storage_centering: "CenteringType" = "face" # Centering to store the mask at + self._storage_centering: CenteringType = "face" # Centering to store the mask at self._storage_size = 128 # Size to store masks at. Leave this at default logger.debug("Initialized %s", self.__class__.__name__) - def get_batch(self, queue: "Queue") -> Tuple[bool, MaskerBatch]: + def get_batch(self, queue: Queue) -> tuple[bool, MaskerBatch]: """ Get items for inputting into the masker from the queue in batches Items are returned from the ``queue`` in batches of diff --git a/plugins/extract/mask/bisenet_fp.py b/plugins/extract/mask/bisenet_fp.py index 79781de3..cf8a177f 100644 --- a/plugins/extract/mask/bisenet_fp.py +++ b/plugins/extract/mask/bisenet_fp.py @@ -49,7 +49,7 @@ class Mask(Masker): # Separate storage for face and head masks self._storage_name = f"{self._storage_name}_{self._storage_centering}" - def _check_weights_selection(self, configfile: T.Optional[str]) -> T.Tuple[bool, int]: + def _check_weights_selection(self, configfile: str | None) -> tuple[bool, int]: """ Check which weights have been selected. 
This is required for passing along the correct file name for the corresponding weights @@ -73,7 +73,7 @@ class Mask(Masker): version = 1 if not is_faceswap else 2 if config.get("include_hair") else 3 return is_faceswap, version - def _get_segment_indices(self) -> T.List[int]: + def _get_segment_indices(self) -> list[int]: """ Obtain the segment indices to include within the face mask area based on user configuration settings. @@ -163,7 +163,7 @@ class Mask(Masker): # SOFTWARE. -_NAME_TRACKER: T.Set[str] = set() +_NAME_TRACKER: set[str] = set() def _get_name(name: str, start_idx: int = 1) -> str: @@ -554,7 +554,7 @@ class BiSeNet(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]], + exclude_gpus: list[int] | None, input_size: int, num_classes: int, cpu_mode: bool) -> None: @@ -569,7 +569,7 @@ class BiSeNet(KSession): self.define_model(self._model_definition) self.load_model_weights() - def _model_definition(self) -> T.Tuple[Tensor, T.List[Tensor]]: + def _model_definition(self) -> tuple[Tensor, list[Tensor]]: """ Definition of the VGG Obstructed Model. Returns diff --git a/plugins/extract/mask/components.py b/plugins/extract/mask/components.py index 0dc35e5e..6ba0b540 100644 --- a/plugins/extract/mask/components.py +++ b/plugins/extract/mask/components.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 """ Components Mask for faceswap.py """ +from __future__ import annotations import logging -from typing import List, Tuple, TYPE_CHECKING +import typing as T import cv2 import numpy as np from ._base import BatchType, Masker -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.aligned_face import AlignedFace logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ class Mask(Masker): def predict(self, feed: np.ndarray) -> np.ndarray: """ Run model to get predictions """ - faces: List["AlignedFace"] = feed[1] + faces: list[AlignedFace] = feed[1] feed = feed[0] for mask, face in zip(feed, faces): parts = self.parse_parts(np.array(face.landmarks)) @@ -51,7 +52,7 @@ class Mask(Masker): return @staticmethod - def parse_parts(landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]: + def parse_parts(landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]: """ Component face hull mask """ r_jaw = (landmarks[0:9], landmarks[17:18]) l_jaw = (landmarks[8:17], landmarks[26:27]) diff --git a/plugins/extract/mask/extended.py b/plugins/extract/mask/extended.py index fa253ba1..0755e794 100644 --- a/plugins/extract/mask/extended.py +++ b/plugins/extract/mask/extended.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """ Extended Mask for faceswap.py """ +from __future__ import annotations import logging -from typing import List, Tuple, TYPE_CHECKING +import typing as T import cv2 import numpy as np @@ -9,7 +10,7 @@ from ._base import BatchType, Masker logger = logging.getLogger(__name__) -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.aligned_face import AlignedFace @@ -35,7 +36,7 @@ class Mask(Masker): def predict(self, feed: np.ndarray) -> np.ndarray: """ Run model to get predictions """ - faces: List["AlignedFace"] = feed[1] + faces: list[AlignedFace] = feed[1] feed = feed[0] for mask, face in zip(feed, faces): parts = self.parse_parts(np.array(face.landmarks)) @@ -78,7 +79,7 @@ class Mask(Masker): landmarks[17:22] = top_l + ((top_l - bot_l) // 2) landmarks[22:27] = top_r + ((top_r - bot_r) // 2) - def parse_parts(self, landmarks: np.ndarray) -> List[Tuple[np.ndarray, ...]]: + def parse_parts(self, landmarks: np.ndarray) -> list[tuple[np.ndarray, ...]]: """ 
Extended face hull mask """ self._adjust_mask_top(landmarks) diff --git a/plugins/extract/mask/unet_dfl.py b/plugins/extract/mask/unet_dfl.py index 930b074c..4ca2f3dc 100644 --- a/plugins/extract/mask/unet_dfl.py +++ b/plugins/extract/mask/unet_dfl.py @@ -13,7 +13,7 @@ Model file sourced from... https://github.com/iperov/DeepFaceLab/blob/master/nnlib/FANSeg_256_full_face.h5 """ import logging -from typing import cast +import typing as T import numpy as np from lib.model.session import KSession @@ -52,7 +52,7 @@ class Mask(Masker): def process_input(self, batch: BatchType) -> None: """ Compile the detected faces for prediction """ assert isinstance(batch, MaskerBatch) - batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3] + batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3] for feed in batch.feed_faces], dtype="float32") / 255.0 logger.trace("feed shape: %s", batch.feed.shape) # type: ignore diff --git a/plugins/extract/mask/vgg_clear.py b/plugins/extract/mask/vgg_clear.py index 9ab009e1..50165f80 100644 --- a/plugins/extract/mask/vgg_clear.py +++ b/plugins/extract/mask/vgg_clear.py @@ -94,7 +94,7 @@ class VGGClear(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]]): + exclude_gpus: list[int] | None): super().__init__("VGG Obstructed", model_path, allow_growth=allow_growth, @@ -103,7 +103,7 @@ class VGGClear(KSession): self.load_model_weights() @classmethod - def _model_definition(cls) -> T.Tuple[Tensor, Tensor]: + def _model_definition(cls) -> tuple[Tensor, Tensor]: """ Definition of the VGG Obstructed Model. Returns @@ -210,7 +210,7 @@ class _ScorePool(): # pylint:disable=too-few-public-methods crop: tuple The amount of 2D cropping to apply. Tuple of `ints` """ - def __init__(self, level: int, scale: float, crop: T.Tuple[int, int]): + def __init__(self, level: int, scale: float, crop: tuple[int, int]): self._name = f"_pool{level}" self._cropping = (crop, crop) self._scale = scale diff --git a/plugins/extract/mask/vgg_obstructed.py b/plugins/extract/mask/vgg_obstructed.py index e7a5fa80..a3f543d7 100644 --- a/plugins/extract/mask/vgg_obstructed.py +++ b/plugins/extract/mask/vgg_obstructed.py @@ -90,7 +90,7 @@ class VGGObstructed(KSession): def __init__(self, model_path: str, allow_growth: bool, - exclude_gpus: T.Optional[T.List[int]]) -> None: + exclude_gpus: list[int] | None) -> None: super().__init__("VGG Obstructed", model_path, allow_growth=allow_growth, @@ -99,7 +99,7 @@ class VGGObstructed(KSession): self.load_model_weights() @classmethod - def _model_definition(cls) -> T.Tuple[Tensor, Tensor]: + def _model_definition(cls) -> tuple[Tensor, Tensor]: """ Definition of the VGG Obstructed Model. Returns diff --git a/plugins/extract/pipeline.py b/plugins/extract/pipeline.py index 70b37229..80f598ac 100644 --- a/plugins/extract/pipeline.py +++ b/plugins/extract/pipeline.py @@ -8,10 +8,9 @@ together. This module sets up a pipeline for the extraction workflow, loading detect, align and mask plugins either in parallel or in series, giving easy access to input and output. 
""" - +from __future__ import annotations import logging -import sys -from typing import cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union +import typing as T import cv2 @@ -20,13 +19,9 @@ from lib.queue_manager import EventQueue, queue_manager, QueueEmpty from lib.utils import get_backend from plugins.plugin_loader import PluginLoader -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: import numpy as np + from collections.abc import Generator from lib.align.alignments import PNGHeaderSourceDict from lib.align.detected_face import DetectedFace from plugins.extract._base import Extractor as PluginExtractor @@ -102,16 +97,16 @@ class Extractor(): :attr:`final_pass` to indicate to the caller which phase is being processed """ def __init__(self, - detector: Optional[str], - aligner: Optional[str], - masker: Optional[Union[str, List[str]]], - recognition: Optional[str] = None, - configfile: Optional[str] = None, + detector: str | None, + aligner: str | None, + masker: str | list[str] | None, + recognition: str | None = None, + configfile: str | None = None, multiprocess: bool = False, - exclude_gpus: Optional[List[int]] = None, - rotate_images: Optional[str] = None, + exclude_gpus: list[int] | None = None, + rotate_images: str | None = None, min_size: int = 0, - normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]] = None, + normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None = None, re_feed: int = 0, re_align: bool = False, disable_filter: bool = False) -> None: @@ -122,8 +117,9 @@ class Extractor(): recognition, configfile, multiprocess, exclude_gpus, rotate_images, min_size, normalize_method, re_feed, re_align, disable_filter) self._instance = _get_instance() - maskers = [cast(Optional[str], - masker)] if not isinstance(masker, list) else cast(List[Optional[str]], masker) + maskers = [T.cast(str | None, + masker)] if not isinstance(masker, list) else T.cast(list[str | None], + masker) self._flow = self._set_flow(detector, aligner, maskers, recognition) self._exclude_gpus = exclude_gpus # We only ever need 1 item in each queue. This is 2 items cached (1 in queue 1 waiting @@ -220,13 +216,13 @@ class Extractor(): return retval @property - def aligner(self) -> "Aligner": + def aligner(self) -> Aligner: """ The currently selected aligner plugin """ assert self._align is not None return self._align @property - def recognition(self) -> "Identity": + def recognition(self) -> Identity: """ The currently selected recognition plugin """ assert self._recognition is not None return self._recognition @@ -237,7 +233,7 @@ class Extractor(): self._phase_index = 0 def set_batchsize(self, - plugin_type: Literal["align", "detect"], + plugin_type: T.Literal["align", "detect"], batchsize: int) -> None: """ Set the batch size of a given :attr:`plugin_type` to the given :attr:`batchsize`. @@ -311,7 +307,7 @@ class Extractor(): # <<< INTERNAL METHODS >>> # @property - def _parallel_scaling(self) -> Dict[int, float]: + def _parallel_scaling(self) -> dict[int, float]: """ dict: key is number of parallel plugins being loaded, value is the scaling factor that the total base vram for those plugins should be scaled by @@ -335,7 +331,7 @@ class Extractor(): return retval @property - def _vram_per_phase(self) -> Dict[str, float]: + def _vram_per_phase(self) -> dict[str, float]: """ dict: The amount of vram required for each phase in :attr:`_flow`. 
""" retval = {} for phase in self._flow: @@ -359,7 +355,7 @@ class Extractor(): return retval @property - def _current_phase(self) -> List[str]: + def _current_phase(self) -> list[str]: """ list: The current phase from :attr:`_phases` that is running through the extractor. """ retval = self._phases[self._phase_index] logger.trace(retval) # type: ignore @@ -384,7 +380,7 @@ class Extractor(): return retval @property - def _all_plugins(self) -> List["PluginExtractor"]: + def _all_plugins(self) -> list[PluginExtractor]: """ Return list of all plugin objects in this pipeline """ retval = [] for phase in self._flow: @@ -396,7 +392,7 @@ class Extractor(): return retval @property - def _active_plugins(self) -> List["PluginExtractor"]: + def _active_plugins(self) -> list[PluginExtractor]: """ Return the plugins that are currently active based on pass """ retval = [] for phase in self._current_phase: @@ -407,10 +403,10 @@ class Extractor(): return retval @staticmethod - def _set_flow(detector: Optional[str], - aligner: Optional[str], - masker: List[Optional[str]], - recognition: Optional[str]) -> List[str]: + def _set_flow(detector: str | None, + aligner: str | None, + masker: list[str | None], + recognition: str | None) -> list[str]: """ Set the flow list based on the input plugins Parameters @@ -441,7 +437,7 @@ class Extractor(): return retval @staticmethod - def _get_plugin_type_and_index(flow_phase: str) -> Tuple[str, Optional[int]]: + def _get_plugin_type_and_index(flow_phase: str) -> tuple[str, int | None]: """ Obtain the plugin type and index for the plugin for the given flow phase. When multiple plugins for the same phase are allowed (e.g. Mask) this will return @@ -463,14 +459,14 @@ class Extractor(): """ sidx = flow_phase.split("_")[-1] if sidx.isdigit(): - idx: Optional[int] = int(sidx) + idx: int | None = int(sidx) plugin_type = "_".join(flow_phase.split("_")[:-1]) else: plugin_type = flow_phase idx = None return plugin_type, idx - def _add_queues(self) -> Dict[str, EventQueue]: + def _add_queues(self) -> dict[str, EventQueue]: """ Add the required processing queues to Queue Manager """ queues = {} tasks = [f"extract{self._instance}_{phase}_in" for phase in self._flow] @@ -483,7 +479,7 @@ class Extractor(): return queues @staticmethod - def _get_vram_stats() -> Dict[str, Union[int, str]]: + def _get_vram_stats() -> dict[str, int | str]: """ Obtain statistics on available VRAM and subtract a constant buffer from available vram. Returns @@ -494,10 +490,10 @@ class Extractor(): vram_buffer = 256 # Leave a buffer for VRAM allocation gpu_stats = GPUStats() stats = gpu_stats.get_card_most_free() - retval: Dict[str, Union[int, str]] = {"count": gpu_stats.device_count, - "device": stats.device, - "vram_free": int(stats.free - vram_buffer), - "vram_total": int(stats.total)} + retval: dict[str, int | str] = {"count": gpu_stats.device_count, + "device": stats.device, + "vram_free": int(stats.free - vram_buffer), + "vram_total": int(stats.total)} logger.debug(retval) return retval @@ -521,13 +517,13 @@ class Extractor(): self._vram_stats["device"], self._vram_stats["vram_free"], self._vram_stats["vram_total"]) - if cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required: + if T.cast(int, self._vram_stats["vram_free"]) <= self._total_vram_required: logger.warning("Not enough free VRAM for parallel processing. 
" "Switching to serial") return False return True - def _set_phases(self, multiprocess: bool) -> List[List[str]]: + def _set_phases(self, multiprocess: bool) -> list[list[str]]: """ If not enough VRAM is available, then chunk :attr:`_flow` up into phases that will fit into VRAM, otherwise return the single flow. @@ -541,9 +537,9 @@ class Extractor(): list: The jobs to be undertaken split into phases that fit into GPU RAM """ - phases: List[List[str]] = [] - current_phase: List[str] = [] - available = cast(int, self._vram_stats["vram_free"]) + phases: list[list[str]] = [] + current_phase: list[str] = [] + available = T.cast(int, self._vram_stats["vram_free"]) for phase in self._flow: num_plugins = len([p for p in current_phase if self._vram_per_phase[p] > 0]) num_plugins += 1 if self._vram_per_phase[phase] > 0 else 0 @@ -576,12 +572,12 @@ class Extractor(): # << INTERNAL PLUGIN HANDLING >> # def _load_align(self, - aligner: Optional[str], - configfile: Optional[str], - normalize_method: Optional[Literal["none", "clahe", "hist", "mean"]], + aligner: str | None, + configfile: str | None, + normalize_method: T.Literal["none", "clahe", "hist", "mean"] | None, re_feed: int, re_align: bool, - disable_filter: bool) -> Optional["Aligner"]: + disable_filter: bool) -> Aligner | None: """ Set global arguments and load aligner plugin Parameters @@ -619,10 +615,10 @@ class Extractor(): return plugin def _load_detect(self, - detector: Optional[str], - rotation: Optional[str], + detector: str | None, + rotation: str | None, min_size: int, - configfile: Optional[str]) -> Optional["Detector"]: + configfile: str | None) -> Detector | None: """ Set global arguments and load detector plugin """ if detector is None or detector.lower() == "none": logger.debug("No detector selected. Returning None") @@ -637,8 +633,8 @@ class Extractor(): return plugin def _load_mask(self, - masker: Optional[str], - configfile: Optional[str]) -> Optional["Masker"]: + masker: str | None, + configfile: str | None) -> Masker | None: """ Set global arguments and load masker plugin Parameters @@ -664,8 +660,8 @@ class Extractor(): return plugin def _load_recognition(self, - recognition: Optional[str], - configfile: Optional[str]) -> Optional["Identity"]: + recognition: str | None, + configfile: str | None) -> Identity | None: """ Set global arguments and load recognition plugin """ if recognition is None or recognition.lower() == "none": logger.debug("No recognition selected. 
Returning None") @@ -716,16 +712,16 @@ class Extractor(): gpu_plugins = [p for p in self._current_phase if self._vram_per_phase[p] > 0] scaling = self._parallel_scaling.get(len(gpu_plugins), self._scaling_fallback) plugins_required = sum(self._vram_per_phase[p] for p in gpu_plugins) * scaling - if plugins_required + batch_required <= cast(int, self._vram_stats["vram_free"]): + if plugins_required + batch_required <= T.cast(int, self._vram_stats["vram_free"]): logger.debug("Plugin requirements within threshold: (plugins_required: %sMB, " "vram_free: %sMB)", plugins_required, self._vram_stats["vram_free"]) return # Hacky split across plugins that use vram - available_vram = (cast(int, self._vram_stats["vram_free"]) + available_vram = (T.cast(int, self._vram_stats["vram_free"]) - plugins_required) // len(gpu_plugins) self._set_plugin_batchsize(gpu_plugins, available_vram) - def _set_plugin_batchsize(self, gpu_plugins: List[str], available_vram: float) -> None: + def _set_plugin_batchsize(self, gpu_plugins: list[str], available_vram: float) -> None: """ Set the batch size for the given plugin based on given available vram. Do not update plugins which have a vram_per_batch of 0 (CPU plugins) due to zero division error. @@ -802,20 +798,20 @@ class ExtractMedia(): def __init__(self, filename: str, - image: "np.ndarray", - detected_faces: Optional[List["DetectedFace"]] = None, + image: np.ndarray, + detected_faces: list[DetectedFace] | None = None, is_aligned: bool = False) -> None: logger.trace("Initializing %s: (filename: '%s', image shape: %s, " # type: ignore "detected_faces: %s, is_aligned: %s)", self.__class__.__name__, filename, image.shape, detected_faces, is_aligned) self._filename = filename - self._image: Optional["np.ndarray"] = image - self._image_shape = cast(Tuple[int, int, int], image.shape) - self._detected_faces: List["DetectedFace"] = ([] if detected_faces is None - else detected_faces) + self._image: np.ndarray | None = image + self._image_shape = T.cast(tuple[int, int, int], image.shape) + self._detected_faces: list[DetectedFace] = ([] if detected_faces is None + else detected_faces) self._is_aligned = is_aligned - self._frame_metadata: Optional["PNGHeaderSourceDict"] = None - self._sub_folders: List[Optional[str]] = [] + self._frame_metadata: PNGHeaderSourceDict | None = None + self._sub_folders: list[str | None] = [] @property def filename(self) -> str: @@ -823,23 +819,23 @@ class ExtractMedia(): return self._filename @property - def image(self) -> "np.ndarray": + def image(self) -> np.ndarray: """ :class:`numpy.ndarray`: The source frame for this object. """ assert self._image is not None return self._image @property - def image_shape(self) -> Tuple[int, int, int]: + def image_shape(self) -> tuple[int, int, int]: """ tuple: The shape of the stored :attr:`image`. """ return self._image_shape @property - def image_size(self) -> Tuple[int, int]: + def image_size(self) -> tuple[int, int]: """ tuple: The (`height`, `width`) of the stored :attr:`image`. """ return self._image_shape[:2] @property - def detected_faces(self) -> List["DetectedFace"]: + def detected_faces(self) -> list[DetectedFace]: """list: A list of :class:`~lib.align.DetectedFace` objects in the :attr:`image`. """ return self._detected_faces @@ -849,7 +845,7 @@ class ExtractMedia(): return self._is_aligned @property - def frame_metadata(self) -> "PNGHeaderSourceDict": + def frame_metadata(self) -> PNGHeaderSourceDict: """ dict: The frame metadata that has been added from an aligned image. 
This property should only be called after :func:`add_frame_metadata` has been called when processing an aligned face. For all other instances an assertion error will be raised. @@ -863,13 +859,13 @@ class ExtractMedia(): return self._frame_metadata @property - def sub_folders(self) -> List[Optional[str]]: + def sub_folders(self) -> list[str | None]: """ list: The sub_folders that the faces should be output to. Used when binning filter output is enabled. The list corresponds to the list of detected faces """ return self._sub_folders - def get_image_copy(self, color_format: Literal["BGR", "RGB", "GRAY"]) -> "np.ndarray": + def get_image_copy(self, color_format: T.Literal["BGR", "RGB", "GRAY"]) -> np.ndarray: """ Get a copy of the image in the requested color format. Parameters @@ -887,7 +883,7 @@ class ExtractMedia(): image = getattr(self, f"_image_as_{color_format.lower()}")() return image - def add_detected_faces(self, faces: List["DetectedFace"]) -> None: + def add_detected_faces(self, faces: list[DetectedFace]) -> None: """ Add detected faces to the object. Called at the end of each extraction phase. Parameters @@ -900,7 +896,7 @@ class ExtractMedia(): [(face.left, face.right, face.top, face.bottom) for face in faces]) self._detected_faces = faces - def add_sub_folders(self, folders: List[Optional[str]]) -> None: + def add_sub_folders(self, folders: list[str | None]) -> None: """ Add detected faces to the object. Called at the end of each extraction phase. Parameters @@ -922,7 +918,7 @@ class ExtractMedia(): del self._image self._image = None - def set_image(self, image: "np.ndarray") -> None: + def set_image(self, image: np.ndarray) -> None: """ Add the image back into :attr:`image` Required for multi-phase extraction adds the image back to this object. @@ -936,7 +932,7 @@ class ExtractMedia(): self._filename, image.shape) self._image = image - def add_frame_metadata(self, metadata: "PNGHeaderSourceDict") -> None: + def add_frame_metadata(self, metadata: PNGHeaderSourceDict) -> None: """ Add the source frame metadata from an aligned PNG's header data. metadata: dict @@ -944,11 +940,11 @@ class ExtractMedia(): """ logger.trace("Adding PNG Source data for '%s': %s", # type:ignore self._filename, metadata) - dims = cast(Tuple[int, int], metadata["source_frame_dims"]) + dims = T.cast(tuple[int, int], metadata["source_frame_dims"]) self._image_shape = (*dims, 3) self._frame_metadata = metadata - def _image_as_bgr(self) -> "np.ndarray": + def _image_as_bgr(self) -> np.ndarray: """ Get a copy of the source frame in BGR format. Returns @@ -957,7 +953,7 @@ class ExtractMedia(): A copy of :attr:`image` in BGR color format """ return self.image[..., :3].copy() - def _image_as_rgb(self) -> "np.ndarray": + def _image_as_rgb(self) -> np.ndarray: """ Get a copy of the source frame in RGB format. Returns @@ -966,7 +962,7 @@ class ExtractMedia(): A copy of :attr:`image` in RGB color format """ return self.image[..., 2::-1].copy() - def _image_as_gray(self) -> "np.ndarray": + def _image_as_gray(self) -> np.ndarray: """ Get a copy of the source frame in gray-scale format. 
Returns diff --git a/plugins/extract/recognition/_base.py b/plugins/extract/recognition/_base.py index bf5a7372..3630607b 100644 --- a/plugins/extract/recognition/_base.py +++ b/plugins/extract/recognition/_base.py @@ -17,7 +17,6 @@ To get a :class:`~lib.align.DetectedFace` object use the function: """ from __future__ import annotations import logging -import sys import typing as T from dataclasses import dataclass, field @@ -31,13 +30,8 @@ from lib.utils import FaceswapError from plugins.extract._base import BatchType, Extractor, ExtractorBatch from plugins.extract.pipeline import ExtractMedia -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - - if T.TYPE_CHECKING: + from collections.abc import Generator from queue import Queue from lib.align.aligned_face import CenteringType @@ -50,8 +44,8 @@ class RecogBatch(ExtractorBatch): Inherits from :class:`~plugins.extract._base.ExtractorBatch` """ - detected_faces: T.List["DetectedFace"] = field(default_factory=list) - feed_faces: T.List[AlignedFace] = field(default_factory=list) + detected_faces: list[DetectedFace] = field(default_factory=list) + feed_faces: list[AlignedFace] = field(default_factory=list) class Identity(Extractor): # pylint:disable=abstract-method @@ -82,9 +76,9 @@ class Identity(Extractor): # pylint:disable=abstract-method """ def __init__(self, - git_model_id: T.Optional[int] = None, - model_filename: T.Optional[str] = None, - configfile: T.Optional[str] = None, + git_model_id: int | None = None, + model_filename: str | None = None, + configfile: str | None = None, instance: int = 0, **kwargs): logger.debug("Initializing %s", self.__class__.__name__) @@ -119,7 +113,7 @@ class Identity(Extractor): # pylint:disable=abstract-method logger.debug("Obtained detected face: (filename: %s, detected_face: %s)", item.filename, item.detected_faces) - def get_batch(self, queue: Queue) -> T.Tuple[bool, RecogBatch]: + def get_batch(self, queue: Queue) -> tuple[bool, RecogBatch]: """ Get items for inputting into the recognition from the queue in batches Items are returned from the ``queue`` in batches of @@ -226,7 +220,7 @@ class Identity(Extractor): # pylint:disable=abstract-method "\n3) Enable 'Single Process' mode.") raise FaceswapError(msg) from err - def finalize(self, batch: BatchType) -> T.Generator[ExtractMedia, None, None]: + def finalize(self, batch: BatchType) -> Generator[ExtractMedia, None, None]: """ Finalize the output from Masker This should be called as the final task of each `plugin`. 
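The recognition hunks above follow the same typing modernisation as the rest of the commit: the pre-3.8 `typing_extensions` fallback is dropped, `T.List`/`T.Optional`/`T.Tuple` become the built-in generics and `X | None` unions of Python 3.10, and `Generator` moves to a `collections.abc` import guarded by `T.TYPE_CHECKING`. A minimal sketch of the before/after style, with illustrative names that are not from the codebase:

from __future__ import annotations

import typing as T
from dataclasses import dataclass, field

if T.TYPE_CHECKING:
    from collections.abc import Generator


@dataclass
class _ExampleBatch:
    """Illustrative only: mirrors the RecogBatch-style annotation change."""
    filenames: list[str] = field(default_factory=list)  # was T.List["..."]
    threshold: float | None = None                      # was T.Optional[float]


def _example_finalize(items: list[str]) -> Generator[str, None, None]:
    """With lazy annotations the TYPE_CHECKING-only Generator import is enough."""
    yield from items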
@@ -301,8 +295,8 @@ class IdentityFilter(): def __init__(self, save_output: bool) -> None: logger.debug("Initializing %s: (save_output: %s)", self.__class__.__name__, save_output) self._save_output = save_output - self._filter: T.Optional[np.ndarray] = None - self._nfilter: T.Optional[np.ndarray] = None + self._filter: np.ndarray | None = None + self._nfilter: np.ndarray | None = None self._threshold = 0.0 self._filter_enabled: bool = False self._nfilter_enabled: bool = False @@ -357,7 +351,7 @@ class IdentityFilter(): return retval def _get_matches(self, - filter_type: Literal["filter", "nfilter"], + filter_type: T.Literal["filter", "nfilter"], identities: np.ndarray) -> np.ndarray: """ Obtain the average and minimum distances for each face against the source identities to test against @@ -386,9 +380,9 @@ class IdentityFilter(): return retval def _filter_faces(self, - faces: T.List[DetectedFace], - sub_folders: T.List[T.Optional[str]], - should_filter: T.List[bool]) -> T.List[DetectedFace]: + faces: list[DetectedFace], + sub_folders: list[str | None], + should_filter: list[bool]) -> list[DetectedFace]: """ Filter the detected faces, either removing filtered faces from the list of detected faces or setting the output subfolder to `"_identity_filt"` for any filtered faces if saving output is enabled. @@ -410,7 +404,7 @@ class IdentityFilter(): The filtered list of detected face objects, if saving filtered faces has not been selected or the full list of detected faces """ - retval: T.List[DetectedFace] = [] + retval: list[DetectedFace] = [] self._counts += sum(should_filter) for idx, face in enumerate(faces): fldr = sub_folders[idx] @@ -429,8 +423,8 @@ class IdentityFilter(): return retval def __call__(self, - faces: T.List[DetectedFace], - sub_folders: T.List[T.Optional[str]]) -> T.List[DetectedFace]: + faces: list[DetectedFace], + sub_folders: list[str | None]) -> list[DetectedFace]: """ Call the identity filter function Parameters @@ -459,14 +453,14 @@ class IdentityFilter(): logger.trace("All faces already filtered: %s", sub_folders) # type: ignore return faces - should_filter: T.List[np.ndarray] = [] - for f_type in get_args(Literal["filter", "nfilter"]): + should_filter: list[np.ndarray] = [] + for f_type in T.get_args(T.Literal["filter", "nfilter"]): if not getattr(self, f"_{f_type}_enabled"): continue should_filter.append(self._get_matches(f_type, identities)) # If any of the filter or nfilter evaluate to 'should filter' then filter out face - final_filter: T.List[bool] = np.array(should_filter).max(axis=0).tolist() + final_filter: list[bool] = np.array(should_filter).max(axis=0).tolist() logger.trace("should_filter: %s, final_filter: %s", # type: ignore should_filter, final_filter) return self._filter_faces(faces, sub_folders, final_filter) diff --git a/plugins/extract/recognition/vgg_face2.py b/plugins/extract/recognition/vgg_face2.py index f7c26427..ae717c75 100644 --- a/plugins/extract/recognition/vgg_face2.py +++ b/plugins/extract/recognition/vgg_face2.py @@ -1,10 +1,9 @@ #!/usr/bin python3 """ VGG_Face2 inference and sorting """ +from __future__ import annotations import logging -import sys - -from typing import cast, Dict, Generator, List, Tuple, Optional +import typing as T import numpy as np import psutil @@ -15,11 +14,8 @@ from lib.model.session import KSession from lib.utils import FaceswapError from ._base import BatchType, RecogBatch, Identity - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal +if T.TYPE_CHECKING: 
+ from collections.abc import Generator logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -64,7 +60,7 @@ class Recognition(Identity): def init_model(self) -> None: """ Initialize VGG Face 2 Model. """ assert isinstance(self.model_path, str) - model_kwargs = dict(custom_objects={'L2_normalize': L2_normalize}) + model_kwargs = {"custom_objects": {"L2_normalize": L2_normalize}} self.model = KSession(self.name, self.model_path, model_kwargs=model_kwargs, @@ -76,7 +72,7 @@ class Recognition(Identity): def process_input(self, batch: BatchType) -> None: """ Compile the detected faces for prediction """ assert isinstance(batch, RecogBatch) - batch.feed = np.array([cast(np.ndarray, feed.face)[..., :3] + batch.feed = np.array([T.cast(np.ndarray, feed.face)[..., :3] for feed in batch.feed_faces], dtype="float32") - self._average_img logger.trace("feed shape: %s", batch.feed.shape) # type:ignore @@ -121,15 +117,15 @@ class Cluster(): # pylint: disable=too-few-public-methods def __init__(self, predictions: np.ndarray, - method: Literal["single", "centroid", "median", "ward"], - threshold: Optional[float] = None) -> None: + method: T.Literal["single", "centroid", "median", "ward"], + threshold: float | None = None) -> None: logger.debug("Initializing: %s (predictions: %s, method: %s, threshold: %s)", self.__class__.__name__, predictions.shape, method, threshold) self._num_predictions = predictions.shape[0] self._should_output_bins = threshold is not None self._threshold = 0.0 if threshold is None else threshold - self._bins: Dict[int, int] = {} + self._bins: dict[int, int] = {} self._iterator = self._integer_iterator() self._result_linkage = self._do_linkage(predictions, method) @@ -192,7 +188,7 @@ class Cluster(): # pylint: disable=too-few-public-methods def _do_linkage(self, predictions: np.ndarray, - method: Literal["single", "centroid", "median", "ward"]) -> np.ndarray: + method: T.Literal["single", "centroid", "median", "ward"]) -> np.ndarray: """ Use FastCluster to perform vector or standard linkage Parameters @@ -218,7 +214,7 @@ class Cluster(): # pylint: disable=too-few-public-methods def _process_leaf_node(self, current_index: int, - current_bin: int) -> List[Tuple[int, int]]: + current_bin: int) -> list[tuple[int, int]]: """ Process the output when we have hit a leaf node """ if not self._should_output_bins: return [(current_index, 0)] @@ -263,7 +259,7 @@ class Cluster(): # pylint: disable=too-few-public-methods tree: np.ndarray, points: int, current_index: int, - current_bin: int = 0) -> List[Tuple[int, int]]: + current_bin: int = 0) -> list[tuple[int, int]]: """ Seriation method for sorted similarity. Seriation computes the order implied by a hierarchical tree (dendrogram). @@ -298,7 +294,7 @@ class Cluster(): # pylint: disable=too-few-public-methods return serate_left + serate_right # type: ignore - def __call__(self) -> List[Tuple[int, int]]: + def __call__(self) -> list[tuple[int, int]]: """ Process the linkages. 
Transforms a distance matrix into a sorted distance matrix according to the order implied diff --git a/plugins/plugin_loader.py b/plugins/plugin_loader.py index 6c60b359..30b97621 100644 --- a/plugins/plugin_loader.py +++ b/plugins/plugin_loader.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 """ Plugin loader for Faceswap extract, training and convert tasks """ - +from __future__ import annotations import logging import os -import sys -from importlib import import_module -from typing import Callable, List, Type, TYPE_CHECKING +import typing as T -if TYPE_CHECKING: +from importlib import import_module + +if T.TYPE_CHECKING: + from collections.abc import Callable from plugins.extract.detect._base import Detector from plugins.extract.align._base import Aligner from plugins.extract.mask._base import Masker @@ -15,11 +16,6 @@ if TYPE_CHECKING: from plugins.train.model._base import ModelBase from plugins.train.trainer._base import TrainerBase -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -36,7 +32,7 @@ class PluginLoader(): >>> aligner = PluginLoader.get_aligner('cv2-dnn') """ @staticmethod - def get_detector(name: str, disable_logging: bool = False) -> Type["Detector"]: + def get_detector(name: str, disable_logging: bool = False) -> type[Detector]: """ Return requested detector plugin Parameters @@ -55,7 +51,7 @@ class PluginLoader(): return PluginLoader._import("extract.detect", name, disable_logging) @staticmethod - def get_aligner(name: str, disable_logging: bool = False) -> Type["Aligner"]: + def get_aligner(name: str, disable_logging: bool = False) -> type[Aligner]: """ Return requested aligner plugin Parameters @@ -74,7 +70,7 @@ class PluginLoader(): return PluginLoader._import("extract.align", name, disable_logging) @staticmethod - def get_masker(name: str, disable_logging: bool = False) -> Type["Masker"]: + def get_masker(name: str, disable_logging: bool = False) -> type[Masker]: """ Return requested masker plugin Parameters @@ -93,7 +89,7 @@ class PluginLoader(): return PluginLoader._import("extract.mask", name, disable_logging) @staticmethod - def get_recognition(name: str, disable_logging: bool = False) -> Type["Identity"]: + def get_recognition(name: str, disable_logging: bool = False) -> type[Identity]: """ Return requested recognition plugin Parameters @@ -112,7 +108,7 @@ class PluginLoader(): return PluginLoader._import("extract.recognition", name, disable_logging) @staticmethod - def get_model(name: str, disable_logging: bool = False) -> Type["ModelBase"]: + def get_model(name: str, disable_logging: bool = False) -> type[ModelBase]: """ Return requested training model plugin Parameters @@ -131,7 +127,7 @@ class PluginLoader(): return PluginLoader._import("train.model", name, disable_logging) @staticmethod - def get_trainer(name: str, disable_logging: bool = False) -> Type["TrainerBase"]: + def get_trainer(name: str, disable_logging: bool = False) -> type[TrainerBase]: """ Return requested training trainer plugin Parameters @@ -198,9 +194,9 @@ class PluginLoader(): return getattr(module, ttl) @staticmethod - def get_available_extractors(extractor_type: Literal["align", "detect", "mask"], + def get_available_extractors(extractor_type: T.Literal["align", "detect", "mask"], add_none: bool = False, - extend_plugin: bool = False) -> List[str]: + extend_plugin: bool = False) -> list[str]: """ Return a list of available extractors of the given type 
Parameters @@ -243,7 +239,7 @@ class PluginLoader(): return extractors @staticmethod - def get_available_models() -> List[str]: + def get_available_models() -> list[str]: """ Return a list of available training models Returns @@ -273,7 +269,7 @@ class PluginLoader(): return 'original' if 'original' in models else models[0] @staticmethod - def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> List[str]: + def get_available_convert_plugins(convert_category: str, add_none: bool = True) -> list[str]: """ Return a list of available converter plugins in the given category Parameters diff --git a/plugins/train/_config.py b/plugins/train/_config.py index 0e8efe37..dbfc0ffe 100644 --- a/plugins/train/_config.py +++ b/plugins/train/_config.py @@ -249,6 +249,31 @@ class Config(FaceswapConfig): "NB: The value given here is the 'exponent' to the epsilon. For example, " "choosing '-7' will set the epsilon to 1e-7. Choosing '-3' will set the epsilon " "to 0.001 (1e-3).")) + self.add_item( + section=section, + title="save_optimizer", + datatype=str, + group=_("optimizer"), + default="exit", + fixed=False, + gui_radio=True, + choices=["never", "always", "exit"], + info=_( + "When to save the Optimizer Weights. Saving the optimizer weights is not " + "necessary and will increase the model file size 3x (and by extension the amount " + "of time it takes to save the model). However, it can be useful to save these " + "weights if you want to guarantee that a resumed model carries off exactly from " + "where it left off, rather than spending a few hundred iterations catching up." + "\n\t never - Don't save optimizer weights." + "\n\t always - Save the optimizer weights at every save iteration. Model saving " + "will take longer, due to the increased file size, but you will always have the " + "last saved optimizer state in your model file." + "\n\t exit - Only save the optimizer weights when explicitly terminating a " + "model. This can be when the model is actively stopped or when the target " + "iterations are met. Note: If the training session ends because of another " + "reason (e.g. power outage, Out of Memory Error, NaN detected) then the " + "optimizer weights will NOT be saved.")) + self.add_item( section=section, title="autoclip", diff --git a/plugins/train/model/_base/io.py b/plugins/train/model/_base/io.py index d8d85157..6bbd0c44 100644 --- a/plugins/train/model/_base/io.py +++ b/plugins/train/model/_base/io.py @@ -21,11 +21,6 @@ from tensorflow.keras.models import load_model, Model as KModel # noqa:E501 # from lib.model.backup_restore import Backup from lib.utils import FaceswapError -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - if T.TYPE_CHECKING: from tensorflow import keras from .model import ModelBase @@ -35,7 +30,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name def get_all_sub_models( model: keras.models.Model, - models: T.Optional[T.List[keras.models.Model]] = None) -> T.List[keras.models.Model]: + models: list[keras.models.Model] | None = None) -> list[keras.models.Model]: """ For a given model, return all sub-models that occur (recursively) as children. 
Parameters @@ -85,12 +80,12 @@ class IO(): plugin: ModelBase, model_dir: str, is_predict: bool, - save_optimizer: Literal["never", "always", "exit"]) -> None: + save_optimizer: T.Literal["never", "always", "exit"]) -> None: self._plugin = plugin self._is_predict = is_predict self._model_dir = model_dir self._save_optimizer = save_optimizer - self._history: T.List[T.List[float]] = [[], []] # Loss histories per save iteration + self._history: list[list[float]] = [[], []] # Loss histories per save iteration self._backup = Backup(self._model_dir, self._plugin.name) @property @@ -106,12 +101,12 @@ class IO(): return os.path.isfile(self._filename) @property - def history(self) -> T.List[T.List[float]]: + def history(self) -> list[list[float]]: """ list: list of loss histories per side for the current save iteration. """ return self._history @property - def multiple_models_in_folder(self) -> T.Optional[T.List[str]]: + def multiple_models_in_folder(self) -> list[str] | None: """ :list: or ``None`` If there are multiple model types in the requested folder, or model types that don't correspond to the requested plugin type, then returns the list of plugin names that exist in the folder, otherwise returns ``None`` """ @@ -210,7 +205,7 @@ class IO(): msg += f" - Average loss since last save: {', '.join(lossmsg)}" logger.info(msg) - def _get_save_averages(self) -> T.List[float]: + def _get_save_averages(self) -> list[float]: """ Return the average loss since the last save iteration and reset historical loss """ logger.debug("Getting save averages") if not all(loss for loss in self._history): @@ -222,7 +217,7 @@ class IO(): logger.debug("Average losses since last save: %s", retval) return retval - def _should_backup(self, save_averages: T.List[float]) -> bool: + def _should_backup(self, save_averages: list[float]) -> bool: """ Check whether the loss averages for this save iteration is the lowest that has been seen. @@ -301,7 +296,7 @@ class Weights(): logger.debug("Initialized %s", self.__class__.__name__) @classmethod - def _check_weights_file(cls, weights_file: str) -> T.Optional[str]: + def _check_weights_file(cls, weights_file: str) -> str | None: """ Validate that we have a valid path to a .h5 file. Parameters @@ -403,7 +398,7 @@ class Weights(): "different settings than you have set for your current model.", skipped_ops) - def _get_weights_model(self) -> T.List[keras.models.Model]: + def _get_weights_model(self) -> list[keras.models.Model]: """ Obtain a list of all sub-models contained within the weights model. Returns @@ -429,7 +424,7 @@ class Weights(): def _load_layer_weights(self, layer: keras.layers.Layer, sub_weights: keras.layers.Layer, - model_name: str) -> Literal[-1, 0, 1]: + model_name: str) -> T.Literal[-1, 0, 1]: """ Load the weights for a single layer. 
Parameters diff --git a/plugins/train/model/_base/model.py b/plugins/train/model/_base/model.py index 6eba5f33..21f0cca6 100644 --- a/plugins/train/model/_base/model.py +++ b/plugins/train/model/_base/model.py @@ -29,18 +29,12 @@ from plugins.train._config import Config from .io import IO, get_all_sub_models, Weights from .settings import Loss, Optimizer, Settings - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - if T.TYPE_CHECKING: import argparse from lib.config import ConfigValueType logger = logging.getLogger(__name__) # pylint: disable=invalid-name -_CONFIG: T.Dict[str, ConfigValueType] = {} +_CONFIG: dict[str, ConfigValueType] = {} class ModelBase(): @@ -79,13 +73,13 @@ class ModelBase(): self.__class__.__name__, model_dir, arguments, predict) # Input shape must be set within the plugin after initializing - self.input_shape: T.Tuple[int, ...] = () + self.input_shape: tuple[int, ...] = () self.trainer = "original" # Override for plugin specific trainer - self.color_order: Literal["bgr", "rgb"] = "bgr" # Override for image color channel order + self.color_order: T.Literal["bgr", "rgb"] = "bgr" # Override for image color channel order self._args = arguments self._is_predict = predict - self._model: T.Optional[tf.keras.models.Model] = None + self._model: tf.keras.models.Model | None = None self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None self._load_config() @@ -100,14 +94,7 @@ class ModelBase(): "use. Please select a mask or disable 'Learn Mask'.") self._mixed_precision = self.config["mixed_precision"] - # self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"]) - # TODO - Re-enable saving of optimizer once this bug is fixed: - # File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper - # File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper - # File "h5py/h5d.pyx", line 87, in h5py.h5d.create - # ValueError: Unable to create dataset (name already exists) - - self._io = IO(self, model_dir, self._is_predict, "never") + self._io = IO(self, model_dir, self._is_predict, self.config["save_optimizer"]) self._check_multiple_models() self._state = State(model_dir, @@ -175,16 +162,16 @@ class ModelBase(): return self.name @property - def input_shapes(self) -> T.List[T.Tuple[None, int, int, int]]: + def input_shapes(self) -> list[tuple[None, int, int, int]]: """ list: A flattened list corresponding to all of the inputs to the model. """ - shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(inputs)) + shapes = [T.cast(tuple[None, int, int, int], K.int_shape(inputs)) for inputs in self.model.inputs] return shapes @property - def output_shapes(self) -> T.List[T.Tuple[None, int, int, int]]: + def output_shapes(self) -> list[tuple[None, int, int, int]]: """ list: A flattened list corresponding to all of the outputs of the model. """ - shapes = [T.cast(T.Tuple[None, int, int, int], K.int_shape(output)) + shapes = [T.cast(tuple[None, int, int, int], K.int_shape(output)) for output in self.model.outputs] return shapes @@ -333,7 +320,7 @@ class ModelBase(): a list of 2 shape tuples of 3 dimensions. """ assert len(self.input_shape) == 3, "Input shape should be a 3 dimensional shape tuple" - def _get_inputs(self) -> T.List[tf.keras.layers.Input]: + def _get_inputs(self) -> list[tf.keras.layers.Input]: """ Obtain the standardized inputs for the model. 
The inputs will be returned for the "A" and "B" sides in the shape as defined by @@ -352,7 +339,7 @@ class ModelBase(): logger.debug("inputs: %s", inputs) return inputs - def build_model(self, inputs: T.List[tf.keras.layers.Input]) -> tf.keras.models.Model: + def build_model(self, inputs: list[tf.keras.layers.Input]) -> tf.keras.models.Model: """ Override for Model Specific autoencoder builds. Parameters @@ -427,7 +414,7 @@ class ModelBase(): self._state.add_session_loss_names(self._loss.names) logger.debug("Compiled Model: %s", self.model) - def _legacy_mapping(self) -> T.Optional[dict]: + def _legacy_mapping(self) -> dict | None: """ The mapping of separate model files to single model layers for transferring of legacy weights. @@ -439,7 +426,7 @@ class ModelBase(): """ return None - def add_history(self, loss: T.List[float]) -> None: + def add_history(self, loss: list[float]) -> None: """ Add the current iteration's loss history to :attr:`_io.history`. Called from the trainer after each iteration, for tracking loss drop over time between @@ -482,18 +469,18 @@ class State(): self._filename = os.path.join(model_dir, filename) self._name = model_name self._iterations = 0 - self._mixed_precision_layers: T.List[str] = [] + self._mixed_precision_layers: list[str] = [] self._rebuild_model = False - self._sessions: T.Dict[int, dict] = {} - self._lowest_avg_loss: T.Dict[str, float] = {} - self._config: T.Dict[str, ConfigValueType] = {} + self._sessions: dict[int, dict] = {} + self._lowest_avg_loss: dict[str, float] = {} + self._config: dict[str, ConfigValueType] = {} self._load(config_changeable_items) self._session_id = self._new_session_id() self._create_new_session(no_logs, config_changeable_items) logger.debug("Initialized %s:", self.__class__.__name__) @property - def loss_names(self) -> T.List[str]: + def loss_names(self) -> list[str]: """ list: The loss names for the current session """ return self._sessions[self._session_id]["loss_names"] @@ -518,7 +505,7 @@ class State(): return self._session_id @property - def mixed_precision_layers(self) -> T.List[str]: + def mixed_precision_layers(self) -> list[str]: """list: Layers that can be switched between mixed-float16 and float32. """ return self._mixed_precision_layers @@ -564,7 +551,7 @@ class State(): "iterations": 0, "config": config_changeable_items} - def add_session_loss_names(self, loss_names: T.List[str]) -> None: + def add_session_loss_names(self, loss_names: list[str]) -> None: """ Add the session loss names to the sessions dictionary. The loss names are used for Tensorboard logging @@ -593,7 +580,7 @@ class State(): self._iterations += 1 self._sessions[self._session_id]["iterations"] += 1 - def add_mixed_precision_layers(self, layers: T.List[str]) -> None: + def add_mixed_precision_layers(self, layers: list[str]) -> None: """ Add the list of model's layers that are compatible for mixed precision to the state dictionary """ logger.debug("Storing mixed precision layers: %s", layers) @@ -655,11 +642,11 @@ class State(): legacy_update = self._update_legacy_config() # Add any new items to state config for legacy purposes where the new default may be # detrimental to an existing model. 
- legacy_defaults: T.Dict[str, T.Union[str, int, bool]] = {"centering": "legacy", - "mask_loss_function": "mse", - "l2_reg_term": 100, - "optimizer": "adam", - "mixed_precision": False} + legacy_defaults: dict[str, str | int | bool] = {"centering": "legacy", + "mask_loss_function": "mse", + "l2_reg_term": 100, + "optimizer": "adam", + "mixed_precision": False} for key, val in _CONFIG.items(): if key not in self._config.keys(): setting: ConfigValueType = legacy_defaults.get(key, val) @@ -807,7 +794,7 @@ class _Inference(): # pylint:disable=too-few-public-methods """ :class:`keras.models.Model`: The Faceswap model, compiled for inference. """ return self._model - def _get_nodes(self, nodes: np.ndarray) -> T.List[T.Tuple[str, int]]: + def _get_nodes(self, nodes: np.ndarray) -> list[tuple[str, int]]: """ Given in input list of nodes from a :attr:`keras.models.Model.get_config` dictionary, filters the layer name(s) and output index of the node, splitting to the correct output index in the event of multiple inputs. @@ -849,7 +836,7 @@ class _Inference(): # pylint:disable=too-few-public-methods logger.debug("Compiling inference model. saved_model: %s", saved_model) struct = self._get_filtered_structure() model_inputs = self._get_inputs(saved_model.inputs) - compiled_layers: T.Dict[str, tf.keras.layers.Layer] = {} + compiled_layers: dict[str, tf.keras.layers.Layer] = {} for layer in saved_model.layers: if layer.name not in struct: logger.debug("Skipping unused layer: '%s'", layer.name) diff --git a/plugins/train/model/_base/settings.py b/plugins/train/model/_base/settings.py index 84ac102b..0513ff57 100644 --- a/plugins/train/model/_base/settings.py +++ b/plugins/train/model/_base/settings.py @@ -14,7 +14,6 @@ from __future__ import annotations from dataclasses import dataclass, field import logging import platform -import sys import typing as T from contextlib import nullcontext @@ -28,12 +27,9 @@ from lib.model import losses, optimizers from lib.model.autoclip import AutoClipper from lib.utils import get_backend -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - if T.TYPE_CHECKING: + from collections.abc import Callable + from contextlib import AbstractContextManager as ContextManager from argparse import Namespace from .model import State @@ -58,9 +54,9 @@ class LossClass: kwargs: dict Any keyword arguments to supply to the loss function at initialization. """ - function: T.Union[T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor], T.Any] = k_losses.mae + function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor] | T.Any = k_losses.mae init: bool = True - kwargs: T.Dict[str, T.Any] = field(default_factory=dict) + kwargs: dict[str, T.Any] = field(default_factory=dict) class Loss(): @@ -73,13 +69,13 @@ class Loss(): color_order: str Color order of the model. 
One of `"BGR"` or `"RGB"` """ - def __init__(self, config: dict, color_order: Literal["bgr", "rgb"]) -> None: + def __init__(self, config: dict, color_order: T.Literal["bgr", "rgb"]) -> None: logger.debug("Initializing %s: (color_order: %s)", self.__class__.__name__, color_order) self._config = config self._mask_channels = self._get_mask_channels() - self._inputs: T.List[tf.keras.layers.Layer] = [] - self._names: T.List[str] = [] - self._funcs: T.Dict[str, T.Callable] = {} + self._inputs: list[tf.keras.layers.Layer] = [] + self._names: list[str] = [] + self._funcs: dict[str, Callable] = {} self._loss_dict = {"ffl": LossClass(function=losses.FocalFrequencyLoss), "flip": LossClass(function=losses.LDRFLIPLoss, @@ -104,7 +100,7 @@ class Loss(): logger.debug("Initialized: %s", self.__class__.__name__) @property - def names(self) -> T.List[str]: + def names(self) -> list[str]: """ list: The list of loss names for the model. """ return self._names @@ -114,14 +110,14 @@ class Loss(): return self._funcs @property - def _mask_inputs(self) -> T.Optional[list]: + def _mask_inputs(self) -> list | None: """ list: The list of input tensors to the model that contain the mask. Returns ``None`` if there is no mask input to the model. """ mask_inputs = [inp for inp in self._inputs if inp.name.startswith("mask")] return None if not mask_inputs else mask_inputs @property - def _mask_shapes(self) -> T.Optional[T.List[tuple]]: + def _mask_shapes(self) -> list[tuple] | None: """ list: The list of shape tuples for the mask input tensors for the model. Returns ``None`` if there is no mask input. """ if self._mask_inputs is None: @@ -141,7 +137,7 @@ class Loss(): self._set_loss_functions(model.output_names) self._names.insert(0, "total") - def _set_loss_names(self, outputs: T.List[tf.Tensor]) -> None: + def _set_loss_names(self, outputs: list[tf.Tensor]) -> None: """ Name the losses based on model output. This is used for correct naming in the state file, for display purposes only. @@ -173,7 +169,7 @@ class Loss(): self._names.append(f"{name}_{side}{suffix}") logger.debug(self._names) - def _get_function(self, name: str) -> T.Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: + def _get_function(self, name: str) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: """ Obtain the requested Loss function Parameters @@ -191,7 +187,7 @@ class Loss(): logger.debug("Obtained loss function `%s` (%s)", name, retval) return retval - def _set_loss_functions(self, output_names: T.List[str]): + def _set_loss_functions(self, output_names: list[str]): """ Set the loss functions and their associated weights. Adds the loss functions to the :attr:`functions` dictionary. @@ -251,7 +247,7 @@ class Loss(): mask_channel=mask_channel) channel_idx += 1 - def _get_mask_channels(self) -> T.List[int]: + def _get_mask_channels(self) -> list[int]: """ Obtain the channels from the face targets that the masks reside in from the training data generator. 
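The `LossClass` dataclass above pairs a loss callable with an `init` flag and optional `kwargs`, so the loss registry can hold both plain keras loss functions and loss classes that need constructing. A hedged sketch of how such an entry can be resolved (the names below are illustrative; the plugin's own `_get_function` body is not shown in this hunk):

from __future__ import annotations

import typing as T
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class _LossEntry:
    """Illustrative stand-in for LossClass."""
    function: Callable | T.Any
    init: bool = True
    kwargs: dict[str, T.Any] = field(default_factory=dict)


def _resolve(entry: _LossEntry) -> Callable:
    # Instantiate the loss object with its kwargs when requested, otherwise
    # use the callable directly (e.g. a plain function such as k_losses.mae).
    return entry.function(**entry.kwargs) if entry.init else entry.function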
@@ -311,8 +307,8 @@ class Optimizer(): # pylint:disable=too-few-public-methods {"beta_1": 0.5, "beta_2": 0.99, "epsilon": epsilon}), "rms-prop": (optimizers.RMSprop, {"epsilon": epsilon})} optimizer_info = valid_optimizers[optimizer] - self._optimizer: T.Callable = optimizer_info[0] - self._kwargs: T.Dict[str, T.Any] = optimizer_info[1] + self._optimizer: Callable = optimizer_info[0] + self._kwargs: dict[str, T.Any] = optimizer_info[1] self._configure(learning_rate, autoclip) logger.verbose("Using %s optimizer", optimizer.title()) # type:ignore[attr-defined] @@ -411,7 +407,7 @@ class Settings(): return mixedprecision.LossScaleOptimizer(optimizer) # pylint:disable=no-member @classmethod - def _set_tf_settings(cls, allow_growth: bool, exclude_devices: T.List[int]) -> None: + def _set_tf_settings(cls, allow_growth: bool, exclude_devices: list[int]) -> None: """ Specify Devices to place operations on and Allow TensorFlow to manage VRAM growth. Enables the Tensorflow allow_growth option if requested in the command line arguments @@ -480,8 +476,8 @@ class Settings(): return True def _get_strategy(self, - strategy: Literal["default", "central-storage", "mirrored"] - ) -> T.Optional[tf.distribute.Strategy]: + strategy: T.Literal["default", "central-storage", "mirrored"] + ) -> tf.distribute.Strategy | None: """ If we are running on Nvidia backend and the strategy is not ``None`` then return the correct tensorflow distribution strategy, otherwise return ``None``. @@ -565,7 +561,7 @@ class Settings(): return tf.distribute.experimental.CentralStorageStrategy(parameter_device="/cpu:0") - def _get_mixed_precision_layers(self, layers: T.List[dict]) -> T.List[str]: + def _get_mixed_precision_layers(self, layers: list[dict]) -> list[str]: """ Obtain the names of the layers in a mixed precision model that have their dtype policy explicitly set to mixed-float16. @@ -595,7 +591,7 @@ class Settings(): logger.debug("Skipping unsupported layer: %s %s", layer["name"], dtype) return retval - def _switch_precision(self, layers: T.List[dict], compatible: T.List[str]) -> None: + def _switch_precision(self, layers: list[dict], compatible: list[str]) -> None: """ Switch a model's datatype between mixed-float16 and float32. Parameters @@ -624,9 +620,9 @@ class Settings(): config["dtype"] = policy def get_mixed_precision_layers(self, - build_func: T.Callable[[T.List[tf.keras.layers.Layer]], - tf.keras.models.Model], - inputs: T.List[tf.keras.layers.Layer]) -> T.List[str]: + build_func: Callable[[list[tf.keras.layers.Layer]], + tf.keras.models.Model], + inputs: list[tf.keras.layers.Layer]) -> list[str]: """ Get and store the mixed precision layers from a full precision enabled model. Parameters @@ -699,7 +695,7 @@ class Settings(): del model return new_model - def strategy_scope(self) -> T.ContextManager: + def strategy_scope(self) -> ContextManager: """ Return the strategy scope if we have set a strategy, otherwise return a null context. 
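`strategy_scope` returns the distribution strategy's scope when one is configured and a null context otherwise, so the model build and compile code stays identical with or without `tf.distribute`. A minimal sketch of that pattern, assuming TensorFlow is installed (the build call is a placeholder, not the plugin's API):

from __future__ import annotations

from contextlib import AbstractContextManager, nullcontext

import tensorflow as tf


def _scope(strategy: tf.distribute.Strategy | None) -> AbstractContextManager:
    """Return the strategy scope if one is set, otherwise a no-op context."""
    return strategy.scope() if strategy is not None else nullcontext()


# Usage: building happens inside the scope either way
# with _scope(maybe_strategy):
#     model = build_model()  # placeholder for the plugin's build step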
diff --git a/plugins/train/model/phaze_a.py b/plugins/train/model/phaze_a.py index 689b9c62..402c071a 100644 --- a/plugins/train/model/phaze_a.py +++ b/plugins/train/model/phaze_a.py @@ -4,7 +4,6 @@ # pylint: disable=too-many-lines from __future__ import annotations import logging -import sys import typing as T from dataclasses import dataclass @@ -27,16 +26,10 @@ from lib.utils import get_tf_version, FaceswapError from ._base import ModelBase, get_all_sub_models -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - if T.TYPE_CHECKING: from tensorflow import keras from tensorflow import Tensor - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -65,14 +58,14 @@ class _EncoderInfo: """ keras_name: str default_size: int - tf_min: T.Tuple[int, int] = (2, 0) - scaling: T.Tuple[int, int] = (0, 1) + tf_min: tuple[int, int] = (2, 0) + scaling: tuple[int, int] = (0, 1) min_size: int = 32 enforce_for_weights: bool = False - color_order: Literal["bgr", "rgb"] = "rgb" + color_order: T.Literal["bgr", "rgb"] = "rgb" -_MODEL_MAPPING: T.Dict[str, _EncoderInfo] = { +_MODEL_MAPPING: dict[str, _EncoderInfo] = { "densenet121": _EncoderInfo( keras_name="DenseNet121", default_size=224), "densenet169": _EncoderInfo( @@ -238,7 +231,7 @@ class Model(ModelBase): model = new_model return model - def _select_freeze_layers(self) -> T.List[str]: + def _select_freeze_layers(self) -> list[str]: """ Process the selected frozen layers and replace the `keras_encoder` option with the actual keras model name @@ -262,7 +255,7 @@ class Model(ModelBase): logger.debug("Removing 'keras_encoder' for '%s'", arch) return retval - def _get_input_shape(self) -> T.Tuple[int, int, int]: + def _get_input_shape(self) -> tuple[int, int, int]: """ Obtain the input shape for the model. Input shape is calculated from the selected Encoder's input size, scaled to the user @@ -316,7 +309,7 @@ class Model(ModelBase): f"minimum version required is {tf_min} whilst you have version " f"{tf_ver} installed.") - def build_model(self, inputs: T.List[Tensor]) -> keras.models.Model: + def build_model(self, inputs: list[Tensor]) -> keras.models.Model: """ Create the model's structure. Parameters @@ -341,7 +334,7 @@ class Model(ModelBase): autoencoder = KModel(inputs, outputs, name=self.model_name) return autoencoder - def _build_encoders(self, inputs: T.List[Tensor]) -> T.Dict[str, keras.models.Model]: + def _build_encoders(self, inputs: list[Tensor]) -> dict[str, keras.models.Model]: """ Build the encoders for Phaze-A Parameters @@ -362,7 +355,7 @@ class Model(ModelBase): def _build_fully_connected( self, - inputs: T.Dict[str, keras.models.Model]) -> T.Dict[str, T.List[keras.models.Model]]: + inputs: dict[str, keras.models.Model]) -> dict[str, list[keras.models.Model]]: """ Build the fully connected layers for Phaze-A Parameters @@ -407,8 +400,8 @@ class Model(ModelBase): def _build_g_blocks( self, - inputs: T.Dict[str, T.List[keras.models.Model]] - ) -> T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]]: + inputs: dict[str, list[keras.models.Model]] + ) -> dict[str, list[keras.models.Model] | keras.models.Model]: """ Build the g-block layers for Phaze-A. 
If a g-block has not been selected for this model, then the original `inters` models are @@ -440,10 +433,9 @@ class Model(ModelBase): logger.debug("G-Blocks: %s", retval) return retval - def _build_decoders( - self, - inputs: T.Dict[str, T.Union[T.List[keras.models.Model], keras.models.Model]] - ) -> T.Dict[str, keras.models.Model]: + def _build_decoders(self, + inputs: dict[str, list[keras.models.Model] | keras.models.Model] + ) -> dict[str, keras.models.Model]: """ Build the encoders for Phaze-A Parameters @@ -519,12 +511,12 @@ def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str) return var_x -def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny", "upscale_fast", - "upscale_hybrid", "upsample2d"], +def _get_upscale_layer(method: T.Literal["resize_images", "subpixel", "upscale_dny", + "upscale_fast", "upscale_hybrid", "upsample2d"], filters: int, - activation: T.Optional[str] = None, - upsamples: T.Optional[int] = None, - interpolation: T.Optional[str] = None) -> keras.layers.Layer: + activation: str | None = None, + upsamples: int | None = None, + interpolation: str | None = None) -> keras.layers.Layer: """ Obtain an instance of the requested upscale method. Parameters @@ -550,7 +542,7 @@ def _get_upscale_layer(method: Literal["resize_images", "subpixel", "upscale_dny The selected configured upscale layer """ if method == "upsample2d": - kwargs: T.Dict[str, T.Union[str, int]] = {} + kwargs: dict[str, str | int] = {} if upsamples: kwargs["size"] = upsamples if interpolation: @@ -571,7 +563,7 @@ def _get_curve(start_y: int, end_y: int, num_points: int, scale: float, - mode: Literal["full", "cap_max", "cap_min"] = "full") -> T.List[int]: + mode: T.Literal["full", "cap_max", "cap_min"] = "full") -> list[int]: """ Obtain a curve. For the given start and end y values, return the y co-ordinates of a curve for the given @@ -660,13 +652,13 @@ class Encoder(): # pylint:disable=too-few-public-methods config: dict The model configuration options """ - def __init__(self, input_shape: T.Tuple[int, ...], config: dict) -> None: + def __init__(self, input_shape: tuple[int, ...], config: dict) -> None: self.input_shape = input_shape self._config = config self._input_shape = input_shape @property - def _model_kwargs(self) -> T.Dict[str, T.Dict[str, T.Union[str, bool]]]: + def _model_kwargs(self) -> dict[str, dict[str, str | bool]]: """ dict: Configuration option for architecture mapped to optional kwargs. """ return {"mobilenet": {"alpha": self._config["mobilenet_width"], "depth_multiplier": self._config["mobilenet_depth"], @@ -677,7 +669,7 @@ class Encoder(): # pylint:disable=too-few-public-methods "include_preprocessing": False}} @property - def _selected_model(self) -> T.Tuple[_EncoderInfo, dict]: + def _selected_model(self) -> tuple[_EncoderInfo, dict]: """ tuple(dict, :class:`_EncoderInfo`): The selected encoder model and it's associated keyword arguments """ arch = self._config["enc_architecture"] @@ -832,7 +824,7 @@ class FullyConnected(): # pylint:disable=too-few-public-methods The user configuration dictionary """ def __init__(self, - side: Literal["a", "b", "both", "gblock", "shared"], + side: T.Literal["a", "b", "both", "gblock", "shared"], input_shape: tuple, config: dict) -> None: logger.debug("Initializing: %s (side: %s, input_shape: %s)", @@ -992,12 +984,12 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods and the Decoder. ``None`` will generate the full Upscale chain. 
An end index of -1 will generate the layers from the starting index to the final upscale. Default: ``None`` """ - _filters: T.List[int] = [] + _filters: list[int] = [] def __init__(self, - side: Literal["a", "b", "both", "shared"], + side: T.Literal["a", "b", "both", "shared"], config: dict, - layer_indicies: T.Optional[T.Tuple[int, int]] = None) -> None: + layer_indicies: tuple[int, int] | None = None) -> None: logger.debug("Initializing: %s (side: %s, layer_indicies: %s)", self.__class__.__name__, side, layer_indicies) self._side = side @@ -1126,7 +1118,7 @@ class UpscaleBlocks(): # pylint: disable=too-few-public-methods relu_alpha=0.2)(var_x) return var_x - def __call__(self, inputs: T.Union[Tensor, T.List[Tensor]]) -> T.Union[Tensor, T.List[Tensor]]: + def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]: """ Upscale Network. Parameters @@ -1203,8 +1195,8 @@ class GBlock(): # pylint:disable=too-few-public-methods The user configuration dictionary """ def __init__(self, - side: Literal["a", "b", "both"], - input_shapes: T.Union[list, tuple], + side: T.Literal["a", "b", "both"], + input_shapes: list | tuple, config: dict) -> None: logger.debug("Initializing: %s (side: %s, input_shapes: %s)", self.__class__.__name__, side, input_shapes) @@ -1284,8 +1276,8 @@ class Decoder(): # pylint:disable=too-few-public-methods The user configuration dictionary """ def __init__(self, - side: Literal["a", "b", "both"], - input_shape: T.Tuple[int, int, int], + side: T.Literal["a", "b", "both"], + input_shape: tuple[int, int, int], config: dict) -> None: logger.debug("Initializing: %s (side: %s, input_shape: %s)", self.__class__.__name__, side, input_shape) diff --git a/plugins/train/model/phaze_a_defaults.py b/plugins/train/model/phaze_a_defaults.py index c741ae09..9468609d 100644 --- a/plugins/train/model/phaze_a_defaults.py +++ b/plugins/train/model/phaze_a_defaults.py @@ -39,8 +39,6 @@ " the value saved in the state file with the updated value in config. If not " provided this will default to True. """ -from typing import List - _HELPTEXT: str = ( "Phaze-A Model by TorzDF, with thanks to BirbFakes.\n" @@ -48,7 +46,7 @@ _HELPTEXT: str = ( "inspiration from Nvidia's StyleGAN for the Decoder. It is highly recommended to research to " "understand the parameters better.") -_ENCODERS: List[str] = sorted([ +_ENCODERS: list[str] = sorted([ "densenet121", "densenet169", "densenet201", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3", "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7", "efficientnet_v2_b0", "efficientnet_v2_b1", "efficientnet_v2_b2", diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py index f0feebb4..146f7b70 100644 --- a/plugins/train/trainer/_base.py +++ b/plugins/train/trainer/_base.py @@ -9,7 +9,6 @@ with "original" unique code split out to the original plugin. 
from __future__ import annotations import logging import os -import sys import time import typing as T @@ -23,23 +22,19 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module from lib.image import hex_to_rgb from lib.training import PreviewDataGenerator, TrainingDataGenerator from lib.training.generator import BatchType, DataGenerator -from lib.utils import FaceswapError, get_folder, get_image_paths, get_tf_version +from lib.utils import FaceswapError, get_folder, get_image_paths from plugins.train._config import Config if T.TYPE_CHECKING: + from collections.abc import Callable, Generator from plugins.train.model._base import ModelBase from lib.config import ConfigValueType -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - logger = logging.getLogger(__name__) # pylint: disable=invalid-name def _get_config(plugin_name: str, - configfile: T.Optional[str] = None) -> T.Dict[str, ConfigValueType]: + configfile: str | None = None) -> dict[str, ConfigValueType]: """ Return the configuration for the requested trainer. Parameters @@ -80,9 +75,9 @@ class TrainerBase(): def __init__(self, model: ModelBase, - images: T.Dict[Literal["a", "b"], T.List[str]], + images: dict[T.Literal["a", "b"], list[str]], batch_size: int, - configfile: T.Optional[str]) -> None: + configfile: str | None) -> None: logger.debug("Initializing %s: (model: '%s', batch_size: %s)", self.__class__.__name__, model, batch_size) self._model = model @@ -111,7 +106,7 @@ class TrainerBase(): self._images) logger.debug("Initialized %s", self.__class__.__name__) - def _get_config(self, configfile: T.Optional[str]) -> T.Dict[str, ConfigValueType]: + def _get_config(self, configfile: str | None) -> dict[str, ConfigValueType]: """ Get the saved training config options. Override any global settings with the setting provided from the model's saved config. @@ -173,10 +168,9 @@ class TrainerBase(): self._samples.toggle_mask_display() def train_one_step(self, - viewer: T.Optional[T.Callable[[np.ndarray, str], None]], - timelapse_kwargs: T.Optional[T.Dict[Literal["input_a", - "input_b", - "output"], str]]) -> None: + viewer: Callable[[np.ndarray, str], None] | None, + timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"], + str] | None) -> None: """ Running training on a batch of images for each side. Triggered from the training cycle in :class:`scripts.train.Train`. @@ -217,7 +211,7 @@ class TrainerBase(): model_inputs, model_targets = self._feeder.get_batch() try: - loss: T.List[float] = self._model.model.train_on_batch(model_inputs, y=model_targets) + loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets) except tf_errors.ResourceExhaustedError as err: msg = ("You do not have enough GPU memory available to train the selected model at " "the selected settings. You can try a number of things:" @@ -236,7 +230,7 @@ class TrainerBase(): self._model.snapshot() self._update_viewers(viewer, timelapse_kwargs) - def _log_tensorboard(self, loss: T.List[float]) -> None: + def _log_tensorboard(self, loss: list[float]) -> None: """ Log current loss to Tensorboard log files Parameters @@ -250,19 +244,18 @@ class TrainerBase(): logs = {log[0]: log[1] for log in zip(self._model.state.loss_names, loss)} - if get_tf_version() > (2, 7): - # Bug in TF 2.8/2.9/2.10 where batch recording got deleted. 
- # ref: https://github.com/keras-team/keras/issues/16173 - with tf.summary.record_if(True), self._tensorboard._train_writer.as_default(): # noqa pylint:disable=protected-access,not-context-manager - for name, value in logs.items(): - tf.summary.scalar( - "batch_" + name, - value, - step=self._tensorboard._train_step) # pylint:disable=protected-access - else: - self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs) + # Bug in TF 2.8/2.9/2.10 where batch recording got deleted. + # ref: https://github.com/keras-team/keras/issues/16173 + with tf.summary.record_if(True), self._tensorboard._train_writer.as_default(): # noqa:E501 pylint:disable=protected-access,not-context-manager + for name, value in logs.items(): + tf.summary.scalar( + "batch_" + name, + value, + step=self._tensorboard._train_step) # pylint:disable=protected-access + # TODO revert this code if fixed in tensorflow + # self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs) - def _collate_and_store_loss(self, loss: T.List[float]) -> T.List[float]: + def _collate_and_store_loss(self, loss: list[float]) -> list[float]: """ Collate the loss into totals for each side. The losses are summed into a total for each side. Loss totals are added to @@ -298,7 +291,7 @@ class TrainerBase(): logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore return combined_loss - def _print_loss(self, loss: T.List[float]) -> None: + def _print_loss(self, loss: list[float]) -> None: """ Outputs the loss for the current iteration to the console. Parameters @@ -318,10 +311,9 @@ class TrainerBase(): "line: %s, error: %s", output, str(err)) def _update_viewers(self, - viewer: T.Optional[T.Callable[[np.ndarray, str], None]], - timelapse_kwargs: T.Optional[T.Dict[Literal["input_a", - "input_b", - "output"], str]]) -> None: + viewer: Callable[[np.ndarray, str], None] | None, + timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"], + str] | None) -> None: """ Update the preview viewer and timelapse output Parameters @@ -371,10 +363,10 @@ class _Feeder(): The configuration for this trainer """ def __init__(self, - images: T.Dict[Literal["a", "b"], T.List[str]], + images: dict[T.Literal["a", "b"], list[str]], model: ModelBase, batch_size: int, - config: T.Dict[str, ConfigValueType]) -> None: + config: dict[str, ConfigValueType]) -> None: logger.debug("Initializing %s: num_images: %s, batch_size: %s, config: %s)", self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size, config) @@ -383,16 +375,16 @@ class _Feeder(): self._batch_size = batch_size self._config = config self._feeds = {side: self._load_generator(side, False).minibatch_ab() - for side in get_args(Literal["a", "b"])} + for side in T.get_args(T.Literal["a", "b"])} self._display_feeds = {"preview": self._set_preview_feed(), "timelapse": {}} logger.debug("Initialized %s:", self.__class__.__name__) def _load_generator(self, - side: Literal["a", "b"], + side: T.Literal["a", "b"], is_display: bool, - batch_size: T.Optional[int] = None, - images: T.Optional[T.List[str]] = None) -> DataGenerator: + batch_size: int | None = None, + images: list[str] | None = None) -> DataGenerator: """ Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder. 
Parameters @@ -424,7 +416,7 @@ class _Feeder(): self._batch_size if batch_size is None else batch_size) return retval - def _set_preview_feed(self) -> T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]]: + def _set_preview_feed(self) -> dict[T.Literal["a", "b"], Generator[BatchType, None, None]]: """ Set the preview feed for this feeder. Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically @@ -436,10 +428,10 @@ class _Feeder(): The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as value. """ - retval: T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]] = {} + retval: dict[T.Literal["a", "b"], Generator[BatchType, None, None]] = {} num_images = self._config.get("preview_images", 14) assert isinstance(num_images, int) - for side in get_args(Literal["a", "b"]): + for side in T.get_args(T.Literal["a", "b"]): logger.debug("Setting preview feed: (side: '%s')", side) preview_images = min(max(num_images, 2), 16) batchsize = min(len(self._images[side]), preview_images) @@ -448,7 +440,7 @@ class _Feeder(): batch_size=batchsize).minibatch_ab() return retval - def get_batch(self) -> T.Tuple[T.List[T.List[np.ndarray]], ...]: + def get_batch(self) -> tuple[list[list[np.ndarray]], ...]: """ Get the feed data and the targets for each training side for feeding into the model's train function. @@ -459,8 +451,8 @@ class _Feeder(): model_targets: list The targets for the model for each side A and B """ - model_inputs: T.List[T.List[np.ndarray]] = [] - model_targets: T.List[T.List[np.ndarray]] = [] + model_inputs: list[list[np.ndarray]] = [] + model_targets: list[list[np.ndarray]] = [] for side in ("a", "b"): side_feed, side_targets = next(self._feeds[side]) if self._model.config["learn_mask"]: # Add the face mask as it's own target @@ -473,7 +465,7 @@ class _Feeder(): return model_inputs, model_targets def generate_preview(self, is_timelapse: bool = False - ) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]: + ) -> dict[T.Literal["a", "b"], list[np.ndarray]]: """ Generate the images for preview window or timelapse Parameters @@ -490,15 +482,15 @@ class _Feeder(): """ logger.debug("Generating preview (is_timelapse: %s)", is_timelapse) - batchsizes: T.List[int] = [] - feed: T.Dict[Literal["a", "b"], np.ndarray] = {} - samples: T.Dict[Literal["a", "b"], np.ndarray] = {} - masks: T.Dict[Literal["a", "b"], np.ndarray] = {} + batchsizes: list[int] = [] + feed: dict[T.Literal["a", "b"], np.ndarray] = {} + samples: dict[T.Literal["a", "b"], np.ndarray] = {} + masks: dict[T.Literal["a", "b"], np.ndarray] = {} # MyPy can't recurse into nested dicts to get the type :( - iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]], + iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"], self._display_feeds["timelapse" if is_timelapse else "preview"]) - for side in get_args(Literal["a", "b"]): + for side in T.get_args(T.Literal["a", "b"]): side_feed, side_samples = next(iterator[side]) batchsizes.append(len(side_samples[0])) samples[side] = side_samples[0] @@ -513,10 +505,10 @@ class _Feeder(): def compile_sample(self, image_count: int, - feed: T.Dict[Literal["a", "b"], np.ndarray], - samples: T.Dict[Literal["a", "b"], np.ndarray], - masks: T.Dict[Literal["a", "b"], np.ndarray] - ) -> T.Dict[Literal["a", "b"], T.List[np.ndarray]]: + feed: dict[T.Literal["a", "b"], np.ndarray], + samples: dict[T.Literal["a", "b"], np.ndarray], + masks: dict[T.Literal["a", "b"], np.ndarray] + ) -> 
dict[T.Literal["a", "b"], list[np.ndarray]]: """ Compile the preview samples for display. Parameters @@ -542,8 +534,8 @@ class _Feeder(): num_images = self._config.get("preview_images", 14) assert isinstance(num_images, int) num_images = min(image_count, num_images) - retval: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {} - for side in get_args(Literal["a", "b"]): + retval: dict[T.Literal["a", "b"], list[np.ndarray]] = {} + for side in T.get_args(T.Literal["a", "b"]): logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images) retval[side] = [feed[side][0:num_images], samples[side][0:num_images], @@ -552,7 +544,7 @@ class _Feeder(): return retval def set_timelapse_feed(self, - images: T.Dict[Literal["a", "b"], T.List[str]], + images: dict[T.Literal["a", "b"], list[str]], batch_size: int) -> None: """ Set the time-lapse feed for this feeder. @@ -570,10 +562,10 @@ class _Feeder(): images, batch_size) # MyPy can't recurse into nested dicts to get the type :( - iterator = T.cast(T.Dict[Literal["a", "b"], T.Generator[BatchType, None, None]], + iterator = T.cast(dict[T.Literal["a", "b"], "Generator[BatchType, None, None]"], self._display_feeds["timelapse"]) - for side in get_args(Literal["a", "b"]): + for side in T.get_args(T.Literal["a", "b"]): imgs = images[side] logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs)) @@ -615,7 +607,7 @@ class _Samples(): # pylint:disable=too-few-public-methods self.__class__.__name__, model, coverage_ratio, mask_opacity, mask_color) self._model = model self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"] - self.images: T.Dict[Literal["a", "b"], T.List[np.ndarray]] = {} + self.images: dict[T.Literal["a", "b"], list[np.ndarray]] = {} self._coverage_ratio = coverage_ratio self._mask_opacity = mask_opacity / 100.0 self._mask_color = np.array(hex_to_rgb(mask_color))[..., 2::-1] / 255. @@ -639,8 +631,8 @@ class _Samples(): # pylint:disable=too-few-public-methods A compiled preview image ready for display or saving """ logger.debug("Showing sample") - feeds: T.Dict[Literal["a", "b"], np.ndarray] = {} - for idx, side in enumerate(get_args(Literal["a", "b"])): + feeds: dict[T.Literal["a", "b"], np.ndarray] = {} + for idx, side in enumerate(T.get_args(T.Literal["a", "b"])): feed = self.images[side][0] input_shape = self._model.model.input_shape[idx][1:] if input_shape[0] / feed.shape[1] != 1.0: @@ -653,7 +645,7 @@ class _Samples(): # pylint:disable=too-few-public-methods @classmethod def _resize_sample(cls, - side: Literal["a", "b"], + side: T.Literal["a", "b"], sample: np.ndarray, target_size: int) -> np.ndarray: """ Resize a given image to the target size. 
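The per-side dictionaries in the feeder and sampler are keyed by `T.Literal["a", "b"]` and iterated with `T.get_args`, which returns the literal's allowed values as a tuple and keeps every loop in sync with the annotation. A small standalone illustration of that idiom:

import typing as T

_Side = T.Literal["a", "b"]


def _init_per_side() -> dict[_Side, list]:
    # T.get_args(_Side) evaluates to ("a", "b"), so extending the Literal
    # automatically extends every loop that iterates over its values.
    return {side: [] for side in T.get_args(_Side)}


assert _init_per_side() == {"a": [], "b": []}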
@@ -684,7 +676,7 @@ class _Samples(): # pylint:disable=too-few-public-methods logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) return retval - def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> T.Dict[str, np.ndarray]: + def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> dict[str, np.ndarray]: """ Feed the samples to the model and return predictions Parameters @@ -700,7 +692,7 @@ class _Samples(): # pylint:disable=too-few-public-methods List of :class:`numpy.ndarray` of predictions received from the model """ logger.debug("Getting Predictions") - preds: T.Dict[str, np.ndarray] = {} + preds: dict[str, np.ndarray] = {} standard = self._model.model.predict([feed_a, feed_b], verbose=0) swapped = self._model.model.predict([feed_b, feed_a], verbose=0) @@ -719,7 +711,7 @@ class _Samples(): # pylint:disable=too-few-public-methods logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()}) return preds - def _compile_preview(self, predictions: T.Dict[str, np.ndarray]) -> np.ndarray: + def _compile_preview(self, predictions: dict[str, np.ndarray]) -> np.ndarray: """ Compile predictions and images into the final preview image. Parameters @@ -732,8 +724,8 @@ class _Samples(): # pylint:disable=too-few-public-methods :class:`numpy.ndarry` A compiled preview image ready for display or saving """ - figures: T.Dict[Literal["a", "b"], np.ndarray] = {} - headers: T.Dict[Literal["a", "b"], np.ndarray] = {} + figures: dict[T.Literal["a", "b"], np.ndarray] = {} + headers: dict[T.Literal["a", "b"], np.ndarray] = {} for side, samples in self.images.items(): other_side = "a" if side == "b" else "b" @@ -761,9 +753,9 @@ class _Samples(): # pylint:disable=too-few-public-methods return np.clip(figure * 255, 0, 255).astype('uint8') def _to_full_frame(self, - side: Literal["a", "b"], - samples: T.List[np.ndarray], - predictions: T.List[np.ndarray]) -> T.List[np.ndarray]: + side: T.Literal["a", "b"], + samples: list[np.ndarray], + predictions: list[np.ndarray]) -> list[np.ndarray]: """ Patch targets and prediction images into images of model output size. Parameters @@ -803,10 +795,10 @@ class _Samples(): # pylint:disable=too-few-public-methods return images def _process_full(self, - side: Literal["a", "b"], + side: T.Literal["a", "b"], images: np.ndarray, prediction_size: int, - color: T.Tuple[float, float, float]) -> np.ndarray: + color: tuple[float, float, float]) -> np.ndarray: """ Add a frame overlay to preview images indicating the region of interest. This applies the red border that appears in the preview images. @@ -847,7 +839,7 @@ class _Samples(): # pylint:disable=too-few-public-methods logger.debug("Overlayed background. Shape: %s", images.shape) return images - def _compile_masked(self, faces: T.List[np.ndarray], masks: np.ndarray) -> T.List[np.ndarray]: + def _compile_masked(self, faces: list[np.ndarray], masks: np.ndarray) -> list[np.ndarray]: """ Add the mask to the faces for masked preview. Places an opaque red layer over areas of the face that are masked out. 
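`_compile_masked` tints the masked-out region of each preview face with a semi-transparent colour, blending at `1 - mask_opacity`. A hedged numpy-only sketch of that kind of overlay (the shapes, the mask convention and the helper name are assumptions for illustration; the plugin's own blending code is not reproduced here):

import numpy as np


def _overlay_mask(face: np.ndarray,   # (H, W, 3) float32 in [0, 1]
                  mask: np.ndarray,   # (H, W, 1) float32, 1.0 = masked out
                  color: np.ndarray,  # (3,) BGR float32 in [0, 1]
                  opacity: float) -> np.ndarray:
    """Blend `color` over the masked-out area at the given opacity."""
    alpha = 1.0 - opacity
    tinted = face * alpha + color * (1.0 - alpha)
    # Only replace pixels the mask marks as masked out
    return np.where(mask > 0.5, tinted, face)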
@@ -866,7 +858,7 @@ class _Samples(): # pylint:disable=too-few-public-methods List of :class:`numpy.ndarray` faces with the opaque mask layer applied """ orig_masks = 1 - np.rint(masks) - masks3: T.Union[T.List[np.ndarray], np.ndarray] = [] + masks3: list[np.ndarray] | np.ndarray = [] if faces[-1].shape[-1] == 4: # Mask contained in alpha channel of predictions pred_masks = [1 - np.rint(face[..., -1])[..., None] for face in faces[-2:]] @@ -875,7 +867,7 @@ class _Samples(): # pylint:disable=too-few-public-methods else: masks3 = np.repeat(np.expand_dims(orig_masks, axis=0), 3, axis=0) - retval: T.List[np.ndarray] = [] + retval: list[np.ndarray] = [] alpha = 1.0 - self._mask_opacity for previews, compiled_masks in zip(faces, masks3): overlays = previews.copy() @@ -910,7 +902,7 @@ class _Samples(): # pylint:disable=too-few-public-methods return backgrounds @classmethod - def _get_headers(cls, side: Literal["a", "b"], width: int) -> np.ndarray: + def _get_headers(cls, side: T.Literal["a", "b"], width: int) -> np.ndarray: """ Set header row for the final preview frame Parameters @@ -958,8 +950,8 @@ class _Samples(): # pylint:disable=too-few-public-methods @classmethod def _duplicate_headers(cls, - headers: T.Dict[Literal["a", "b"], np.ndarray], - columns: int) -> T.Dict[Literal["a", "b"], np.ndarray]: + headers: dict[T.Literal["a", "b"], np.ndarray], + columns: int) -> dict[T.Literal["a", "b"], np.ndarray]: """ Duplicate headers for the number of columns displayed for each side. Parameters @@ -1008,7 +1000,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods mask_opacity: int, mask_color: str, feeder: _Feeder, - image_paths: T.Dict[Literal["a", "b"], T.List[str]]) -> None: + image_paths: dict[T.Literal["a", "b"], list[str]]) -> None: logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, " "mask_opacity: %s, mask_color: %s, feeder: %s, image_paths: %s)", self.__class__.__name__, model, coverage_ratio, image_count, mask_opacity, @@ -1042,8 +1034,8 @@ class _Timelapse(): # pylint:disable=too-few-public-methods logger.debug("Time-lapse output set to '%s'", self._output_file) # Rewrite paths to pull from the training images so mask and face data can be accessed - images: T.Dict[Literal["a", "b"], T.List[str]] = {} - for side, input_ in zip(get_args(Literal["a", "b"]), (input_a, input_b)): + images: dict[T.Literal["a", "b"], list[str]] = {} + for side, input_ in zip(T.get_args(T.Literal["a", "b"]), (input_a, input_b)): training_path = os.path.dirname(self._image_paths[side][0]) images[side] = [os.path.join(training_path, os.path.basename(pth)) for pth in get_image_paths(input_)] @@ -1054,7 +1046,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods self._feeder.set_timelapse_feed(images, batchsize) logger.debug("Set up time-lapse") - def output_timelapse(self, timelapse_kwargs: T.Dict[Literal["input_a", + def output_timelapse(self, timelapse_kwargs: dict[T.Literal["input_a", "input_b", "output"], str]) -> None: """ Generate the time-lapse samples and output the created time-lapse to the specified @@ -1068,7 +1060,7 @@ class _Timelapse(): # pylint:disable=too-few-public-methods """ logger.debug("Ouputting time-lapse") if not self._output_file: - self._setup(**T.cast(T.Dict[str, str], timelapse_kwargs)) + self._setup(**T.cast(dict[str, str], timelapse_kwargs)) logger.debug("Getting time-lapse samples") self._samples.images = self._feeder.generate_preview(is_timelapse=True) diff --git a/requirements/_requirements_base.txt 
b/requirements/_requirements_base.txt index 4054ee36..3d9b36f6 100644 --- a/requirements/_requirements_base.txt +++ b/requirements/_requirements_base.txt @@ -1,19 +1,14 @@ -tqdm>=4.64 +# TESTED WITH PY3.10 +tqdm>=4.65 psutil>=5.9.0 -numexpr>=2.7.3; python_version < '3.9' # >=2.8.0 conflicts in Conda -numexpr>=2.8.3; python_version >= '3.9' -opencv-python>=4.6.0.0 -pillow>=9.2.0 -scikit-learn==1.0.2; python_version < '3.9' # AMD needs version 1.0.2 and 1.1.0 not available in Python 3.7 -scikit-learn>=1.1.0; python_version >= '3.9' +numexpr>=2.8.4 +numpy>=1.25.0 +opencv-python>=4.7.0.0 +pillow>=9.4.0 +scikit-learn>=1.2.2 fastcluster>=1.2.6 -matplotlib>=3.4.3,<3.6.0; python_version < '3.9' # >=3.5.0 conflicts in Conda -matplotlib>=3.5.1,<3.6.0; python_version >= '3.9' -imageio>=2.19.3 -imageio-ffmpeg>=0.4.7 +matplotlib>=3.7.1 +imageio>=2.26.0 +imageio-ffmpeg>=0.4.8 ffmpy>=0.3.0 -# Exclude badly numbered Python2 version of nvidia-ml-py -nvidia-ml-py>=11.515,<300 -tensorflow-probability<0.17 -typing-extensions>=4.0.0 pywin32>=228 ; sys_platform == "win32" diff --git a/requirements/requirements_apple_silicon.txt b/requirements/requirements_apple_silicon.txt index 61251120..5732337e 100644 --- a/requirements/requirements_apple_silicon.txt +++ b/requirements/requirements_apple_silicon.txt @@ -1,11 +1,7 @@ -protobuf>= 3.19.0,<3.20.0 # TF has started pulling in incompatible protobuf -# Pinned TF probability doesn't work with numpy >= 1.24 -numpy>=1.21.0,<1.24.0; python_version < '3.8' -numpy>=1.22.0,<1.24.0; python_version >= '3.8' -tensorflow-macos>=2.8.0,<2.11.0 -tensorflow-deps>=2.8.0,<2.11.0 -tensorflow-metal>=0.4.0,<0.7.0 -libblas # Conda only +-r _requirements_base.txt +tensorflow-macos>=2.10.0,<2.11.0 +tensorflow-deps>=2.10.0,<2.11.0 +tensorflow-metal>=0.6.0,<0.7.0 # These next 2 should have been installed, but some users complain of errors decorator cloudpickle diff --git a/requirements/requirements_cpu.txt b/requirements/requirements_cpu.txt index 52b3315f..873e3d35 100644 --- a/requirements/requirements_cpu.txt +++ b/requirements/requirements_cpu.txt @@ -1,5 +1,2 @@ -r _requirements_base.txt -# Pinned TF probability doesn't work with numpy >= 1.24 -numpy>=1.21.0,<1.24.0; python_version < '3.8' -numpy>=1.22.0,<1.24.0; python_version >= '3.8' -tensorflow-cpu>=2.7.0,<2.11.0 +tensorflow-cpu>=2.10.0,<2.11.0 diff --git a/requirements/requirements_directml.txt b/requirements/requirements_directml.txt index 9c4319ca..d7e0dbc2 100644 --- a/requirements/requirements_directml.txt +++ b/requirements/requirements_directml.txt @@ -1,7 +1,4 @@ -r _requirements_base.txt -# Pinned TF probability doesn't work with numpy >= 1.24 -numpy>=1.21.0,<1.24.0; python_version < '3.8' -numpy>=1.22.0,<1.24.0; python_version >= '3.8' tensorflow-cpu>=2.10.0,<2.11.0 tensorflow-directml-plugin comtypes diff --git a/requirements/requirements_nvidia.txt b/requirements/requirements_nvidia.txt index 829b3a7a..f3a0bc93 100644 --- a/requirements/requirements_nvidia.txt +++ b/requirements/requirements_nvidia.txt @@ -1,6 +1,5 @@ -r _requirements_base.txt -# Pinned TF probability doesn't work with numpy >= 1.24 -numpy>=1.21.0,<1.24.0; python_version < '3.8' -numpy>=1.22.0,<1.24.0; python_version >= '3.8' -tensorflow-gpu>=2.7.0,<2.11.0 +# Exclude badly numbered Python2 version of nvidia-ml-py +nvidia-ml-py>=11.525,<300 pynvx==1.0.0 ; sys_platform == "darwin" +tensorflow>=2.10.0,<2.11.0 diff --git a/requirements/requirements_rocm.txt b/requirements/requirements_rocm.txt index e7bfc6c0..b23ce015 100644 --- 
a/requirements/requirements_rocm.txt +++ b/requirements/requirements_rocm.txt @@ -1,5 +1,2 @@ -r _requirements_base.txt -# Pinned TF probability doesn't work with numpy >= 1.24 -numpy>=1.21.0,<1.24.0; python_version < '3.8' -numpy>=1.22.0,<1.24.0; python_version >= '3.8' tensorflow-rocm>=2.10.0,<2.11.0 diff --git a/scripts/convert.py b/scripts/convert.py index 374bf7fa..89cc56bf 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -26,13 +26,9 @@ from lib.utils import FaceswapError, get_folder, get_image_paths from plugins.extract.pipeline import Extractor, ExtractMedia from plugins.plugin_loader import PluginLoader -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - if T.TYPE_CHECKING: from argparse import Namespace + from collections.abc import Callable from plugins.convert.writer._base import Output from plugins.train.model._base import ModelBase from lib.align.aligned_face import CenteringType @@ -61,8 +57,8 @@ class ConvertItem: The swapped faces returned from the model's predict function """ inbound: ExtractMedia - feed_faces: T.List[AlignedFace] = field(default_factory=list) - reference_faces: T.List[AlignedFace] = field(default_factory=list) + feed_faces: list[AlignedFace] = field(default_factory=list) + reference_faces: list[AlignedFace] = field(default_factory=list) swapped_faces: np.ndarray = np.array([]) @@ -307,8 +303,8 @@ class DiskIO(): # Extractor for on the fly detection self._extractor = self._load_extractor() - self._queues: T.Dict[Literal["load", "save"], EventQueue] = {} - self._threads: T.Dict[Literal["load", "save"], MultiThread] = {} + self._queues: dict[T.Literal["load", "save"], EventQueue] = {} + self._threads: dict[T.Literal["load", "save"], MultiThread] = {} self._init_threads() logger.debug("Initialized %s", self.__class__.__name__) @@ -324,13 +320,13 @@ class DiskIO(): return self._writer.config.get("draw_transparent", False) @property - def pre_encode(self) -> T.Optional[T.Callable[[np.ndarray], T.List[bytes]]]: + def pre_encode(self) -> Callable[[np.ndarray], list[bytes]] | None: """ python function: Selected writer's pre-encode function, if it has one, otherwise ``None`` """ dummy = np.zeros((20, 20, 3), dtype="uint8") test = self._writer.pre_encode(dummy) - retval: T.Optional[T.Callable[[np.ndarray], - T.List[bytes]]] = None if test is None else self._writer.pre_encode + retval: Callable[[np.ndarray], + list[bytes]] | None = None if test is None else self._writer.pre_encode logger.debug("Writer pre_encode function: %s", retval) return retval @@ -384,7 +380,7 @@ class DiskIO(): return PluginLoader.get_converter("writer", self._args.writer)(*args, configfile=configfile) - def _get_frame_ranges(self) -> T.Optional[T.List[T.Tuple[int, int]]]: + def _get_frame_ranges(self) -> list[tuple[int, int]] | None: """ Obtain the frame ranges that are to be converted. If frame ranges have been specified, then split the command line formatted arguments into @@ -422,7 +418,7 @@ class DiskIO(): logger.debug("frame ranges: %s", retval) return retval - def _load_extractor(self) -> T.Optional[Extractor]: + def _load_extractor(self) -> Extractor | None: """ Load the CV2-DNN Face Extractor Chain. For On-The-Fly conversion we use a CPU based extractor to avoid stacking the GPU. @@ -467,12 +463,12 @@ class DiskIO(): Creates the load and save queues and the load and save threads. Starts the threads. 
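DiskIO above wires one queue and one background thread per task by looping over the "load"/"save" Literal. A stripped-down stand-in of that structure using only the stdlib; faceswap's queue_manager and MultiThread wrappers are replaced by queue.Queue and threading.Thread, and the worker bodies are placeholders.

    from __future__ import annotations

    import queue
    import threading
    import typing as T

    TaskT = T.Literal["load", "save"]

    class DiskIOSketch:
        """One queue plus one background thread per task, keyed on a Literal."""

        def __init__(self) -> None:
            self._queues: dict[TaskT, queue.Queue] = {}
            self._threads: dict[TaskT, threading.Thread] = {}
            for task in T.get_args(TaskT):
                self._queues[task] = queue.Queue(maxsize=8)
                thread = threading.Thread(target=self._worker, args=(task,),
                                          name=f"diskio_{task}", daemon=True)
                self._threads[task] = thread
                thread.start()

        def _worker(self, task: TaskT) -> None:
            while self._queues[task].get() != "EOF":   # placeholder for real load/save work
                pass

        def join(self) -> None:
            for task in T.get_args(TaskT):
                self._queues[task].put("EOF")
            for thread in self._threads.values():
                thread.join()

    io = DiskIOSketch()
    io.join()
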
""" logger.debug("Initializing DiskIO Threads") - for task in get_args(Literal["load", "save"]): + for task in T.get_args(T.Literal["load", "save"]): self._add_queue(task) self._start_thread(task) logger.debug("Initialized DiskIO Threads") - def _add_queue(self, task: Literal["load", "save"]) -> None: + def _add_queue(self, task: T.Literal["load", "save"]) -> None: """ Add the queue to queue_manager and to :attr:`self._queues` for the given task. Parameters @@ -490,7 +486,7 @@ class DiskIO(): self._queues[task] = queue_manager.get_queue(q_name) logger.debug("Added queue for task: '%s'", task) - def _start_thread(self, task: Literal["load", "save"]) -> None: + def _start_thread(self, task: T.Literal["load", "save"]) -> None: """ Create the thread for the given task, add it it :attr:`self._threads` and start it. Parameters @@ -571,7 +567,7 @@ class DiskIO(): logger.trace("idx: %s, skipframe: %s", idx, skipframe) # type: ignore return skipframe - def _get_detected_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]: + def _get_detected_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]: """ Return the detected faces for the given image. If we have an alignments file, then the detected faces are created from that file. If @@ -597,7 +593,7 @@ class DiskIO(): logger.trace("Got %s faces for: '%s'", len(detected_faces), filename) # type:ignore return detected_faces - def _alignments_faces(self, frame_name: str, image: np.ndarray) -> T.List[DetectedFace]: + def _alignments_faces(self, frame_name: str, image: np.ndarray) -> list[DetectedFace]: """ Return detected faces from an alignments file. Parameters @@ -644,7 +640,7 @@ class DiskIO(): tqdm.write(f"No alignment found for {frame_name}, skipping") return have_alignments - def _detect_faces(self, filename: str, image: np.ndarray) -> T.List[DetectedFace]: + def _detect_faces(self, filename: str, image: np.ndarray) -> list[DetectedFace]: """ Extract the face from a frame for On-The-Fly conversion. Pulls detected faces out of the Extraction pipeline. @@ -779,7 +775,7 @@ class Predict(): """ int: The size in pixels of the Faceswap model output. """ return self._sizes["output"] - def _get_io_sizes(self) -> T.Dict[str, int]: + def _get_io_sizes(self) -> dict[str, int]: """ Obtain the input size and output size of the model. Returns @@ -896,9 +892,9 @@ class Predict(): """ faces_seen = 0 consecutive_no_faces = 0 - batch: T.List[ConvertItem] = [] + batch: list[ConvertItem] = [] while True: - item: T.Union[Literal["EOF"], ConvertItem] = self._in_queue.get() + item: T.Literal["EOF"] | ConvertItem = self._in_queue.get() if item == "EOF": logger.debug("EOF Received") if batch: # Process out any remaining items @@ -938,7 +934,7 @@ class Predict(): self._out_queue.put("EOF") logger.debug("Load queue complete") - def _process_batch(self, batch: T.List[ConvertItem], faces_seen: int): + def _process_batch(self, batch: list[ConvertItem], faces_seen: int): """ Predict faces on the given batch of images and queue out to patch thread Parameters @@ -1001,7 +997,7 @@ class Predict(): logger.trace("Loaded aligned faces: '%s'", item.inbound.filename) # type:ignore @staticmethod - def _compile_feed_faces(feed_faces: T.List[AlignedFace]) -> np.ndarray: + def _compile_feed_faces(feed_faces: list[AlignedFace]) -> np.ndarray: """ Compile a batch of faces for feeding into the Predictor. Parameters @@ -1020,7 +1016,7 @@ class Predict(): logger.trace("Compiled Feed faces. 
Shape: %s", retval.shape) # type:ignore return retval - def _predict(self, feed_faces: np.ndarray, batch_size: T.Optional[int] = None) -> np.ndarray: + def _predict(self, feed_faces: np.ndarray, batch_size: int | None = None) -> np.ndarray: """ Run the Faceswap models' prediction function. Parameters @@ -1045,7 +1041,7 @@ class Predict(): logger.trace("Input shape(s): %s", [item.shape for item in feed]) # type:ignore inbound = self._model.model.predict(feed, verbose=0, batch_size=batch_size) - predicted: T.List[np.ndarray] = inbound if isinstance(inbound, list) else [inbound] + predicted: list[np.ndarray] = inbound if isinstance(inbound, list) else [inbound] if self._model.color_order.lower() == "rgb": predicted[0] = predicted[0][..., ::-1] @@ -1062,7 +1058,7 @@ class Predict(): logger.trace("Final shape: %s", retval.shape) # type:ignore return retval - def _queue_out_frames(self, batch: T.List[ConvertItem], swapped_faces: np.ndarray) -> None: + def _queue_out_frames(self, batch: list[ConvertItem], swapped_faces: np.ndarray) -> None: """ Compile the batch back to original frames and put to the Out Queue. For batching, faces are split away from their frames. This compiles all detected faces @@ -1108,7 +1104,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods """ def __init__(self, arguments: Namespace, - input_images: T.List[np.ndarray], + input_images: list[np.ndarray], alignments: Alignments) -> None: logger.debug("Initializing %s", self.__class__.__name__) self._args = arguments @@ -1131,7 +1127,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods self._alignments.filter_faces(accept_dict, filter_out=False) logger.info("Faces filtered out: %s", pre_face_count - self._alignments.faces_count) - def _get_face_metadata(self) -> T.Dict[str, T.List[int]]: + def _get_face_metadata(self) -> dict[str, list[int]]: """ Check for the existence of an aligned directory for identifying which faces in the target frames should be swapped. If it exists, scan the folder for face's metadata @@ -1140,7 +1136,7 @@ class OptionalActions(): # pylint:disable=too-few-public-methods dict Dictionary of source frame names with a list of associated face indices to be skipped """ - retval: T.Dict[str, T.List[int]] = {} + retval: dict[str, list[int]] = {} input_aligned_dir = self._args.input_aligned_dir if input_aligned_dir is None: diff --git a/scripts/extract.py b/scripts/extract.py index f20dbb8f..44a7434b 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -2,13 +2,13 @@ """ Main entry point to the extract process of FaceSwap """ from __future__ import annotations - import logging import os import sys +import typing as T + from argparse import Namespace from multiprocessing import Process -from typing import List, Dict, Optional, Tuple, TYPE_CHECKING, Union import numpy as np from tqdm import tqdm @@ -20,7 +20,7 @@ from lib.utils import get_folder, _image_extensions, _video_extensions from plugins.extract.pipeline import Extractor, ExtractMedia from scripts.fsmedia import Alignments, PostProcess, finalize -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.alignments import PNGHeaderAlignmentsDict # tqdm.monitor_interval = 0 # workaround for TqdmSynchronisationWarning # TODO? @@ -75,7 +75,7 @@ class Extract(): # pylint:disable=too-few-public-methods self._args.nfilter, self._extractor) - def _get_input_locations(self) -> List[str]: + def _get_input_locations(self) -> list[str]: """ Obtain the full path to input locations. 
Will be a list of locations if batch mode is selected, or a containing a single location if batch mode is not selected. @@ -194,8 +194,8 @@ class Filter(): """ def __init__(self, threshold: float, - filter_files: Optional[List[str]], - nfilter_files: Optional[List[str]], + filter_files: list[str] | None, + nfilter_files: list[str] | None, extractor: Extractor) -> None: logger.debug("Initializing %s: (threshold: %s, filter_files: %s, nfilter_files: %s " "extractor: %s)", self.__class__.__name__, threshold, filter_files, @@ -208,8 +208,8 @@ class Filter(): logger.debug("Filter not selected. Exiting %s", self.__class__.__name__) return - self._embeddings: List[np.ndarray] = [np.array([]) for _ in self._filter_files] - self._nembeddings: List[np.ndarray] = [np.array([]) for _ in self._nfilter_files] + self._embeddings: list[np.ndarray] = [np.array([]) for _ in self._filter_files] + self._nembeddings: list[np.ndarray] = [np.array([]) for _ in self._nfilter_files] self._extractor = extractor self._get_embeddings() @@ -243,7 +243,7 @@ class Filter(): return retval @classmethod - def _files_from_folder(cls, input_location: List[str]) -> List[str]: + def _files_from_folder(cls, input_location: list[str]) -> list[str]: """ Test whether the input location is a folder and if so, return the list of contained image files, otherwise return the original input location @@ -274,8 +274,8 @@ class Filter(): return retval def _validate_inputs(self, - filter_files: Optional[List[str]], - nfilter_files: Optional[List[str]]) -> Tuple[List[str], List[str]]: + filter_files: list[str] | None, + nfilter_files: list[str] | None) -> tuple[list[str], list[str]]: """ Validates that the given filter/nfilter files exist, are image files and are unique Parameters @@ -293,7 +293,7 @@ class Filter(): List of full paths to nfilter files """ error = False - retval: List[List[str]] = [] + retval: list[list[str]] = [] for files in (filter_files, nfilter_files): filt_files = [] if files is None else self._files_from_folder(files) @@ -322,7 +322,7 @@ class Filter(): return filters, nfilters @classmethod - def _identity_from_extracted(cls, filename) -> Tuple[np.ndarray, bool]: + def _identity_from_extracted(cls, filename) -> tuple[np.ndarray, bool]: """ Test whether the given image is a faceswap extracted face and contains identity information. 
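The Filter class above expands folder inputs into image files, checks that filter/nfilter entries exist and are images, and keeps the two lists distinct. Roughly, under the assumption of a much shorter extension list and no logging or error reporting:

    from __future__ import annotations

    import os

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")   # simplified; faceswap keeps its own list

    def files_from_folder(location: list[str]) -> list[str]:
        """If a single folder was passed, return the image files inside it."""
        if len(location) == 1 and os.path.isdir(location[0]):
            folder = location[0]
            return [os.path.join(folder, f) for f in sorted(os.listdir(folder))
                    if f.lower().endswith(IMAGE_EXTENSIONS)]
        return location

    def validate_inputs(filter_files: list[str] | None,
                        nfilter_files: list[str] | None) -> tuple[list[str], list[str]]:
        retval: list[list[str]] = []
        for files in (filter_files, nfilter_files):
            expanded = [] if files is None else files_from_folder(files)
            retval.append([f for f in expanded
                           if os.path.isfile(f) and f.lower().endswith(IMAGE_EXTENSIONS)])
        filters, nfilters = retval
        return filters, [f for f in nfilters if f not in filters]   # keep the lists unique
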
If so, return the identity embedding @@ -404,7 +404,7 @@ class Filter(): embeddings[idx] = identities return - def _identity_from_extractor(self, file_list: List[str], aligned: List[str]) -> None: + def _identity_from_extractor(self, file_list: list[str], aligned: list[str]) -> None: """ Obtain the identity embeddings from the extraction pipeline Parameters @@ -425,7 +425,7 @@ class Filter(): for phase in range(self._extractor.passes): is_final = self._extractor.final_pass - detected_faces: Dict[str, ExtractMedia] = {} + detected_faces: dict[str, ExtractMedia] = {} self._extractor.launch() desc = "Obtaining reference face Identity" if self._extractor.passes > 1: @@ -450,8 +450,8 @@ class Filter(): def _get_embeddings(self) -> None: """ Obtain the embeddings for the given filter lists """ - needs_extraction: List[str] = [] - aligned: List[str] = [] + needs_extraction: list[str] = [] + aligned: list[str] = [] for files, embed in zip((self._filter_files, self._nfilter_files), (self._embeddings, self._nembeddings)): @@ -494,14 +494,14 @@ class PipelineLoader(): image files that exist in :attr:`path` that are aligned faceswap images """ def __init__(self, - path: Union[str, List[str]], + path: str | list[str], extractor: Extractor, - aligned_filenames: Optional[List[str]] = None) -> None: + aligned_filenames: list[str] | None = None) -> None: logger.debug("Initializing %s: (path: %s, extractor: %s, aligned_filenames: %s)", self.__class__.__name__, path, extractor, aligned_filenames) self._images = ImagesLoader(path, fast_count=True) self._extractor = extractor - self._threads: List[MultiThread] = [] + self._threads: list[MultiThread] = [] self._aligned_filenames = [] if aligned_filenames is None else aligned_filenames logger.debug("Initialized %s", self.__class__.__name__) @@ -512,7 +512,7 @@ class PipelineLoader(): return self._images.is_video @property - def file_list(self) -> List[str]: + def file_list(self) -> list[str]: """ list: A full list of files in the source location. 
If the input is a video then this is a list of dummy filenames as corresponding to an alignments file """ return self._images.file_list @@ -523,7 +523,7 @@ class PipelineLoader(): items that are to be skipped from the :attr:`skip_list`)""" return self._images.process_count - def add_skip_list(self, skip_list: List[int]) -> None: + def add_skip_list(self, skip_list: list[int]) -> None: """ Add a skip list to the :class:`ImagesLoader` Parameters @@ -538,7 +538,7 @@ class PipelineLoader(): """ Launch the image loading pipeline """ self._threaded_redirector("load") - def reload(self, detected_faces: Dict[str, ExtractMedia]) -> None: + def reload(self, detected_faces: dict[str, ExtractMedia]) -> None: """ Reload images for multiple pipeline passes """ self._threaded_redirector("reload", (detected_faces, )) @@ -552,7 +552,7 @@ class PipelineLoader(): for thread in self._threads: thread.join() - def _threaded_redirector(self, task: str, io_args: Optional[tuple] = None) -> None: + def _threaded_redirector(self, task: str, io_args: tuple | None = None) -> None: """ Redirect image input/output tasks to relevant queues in background thread Parameters @@ -587,7 +587,7 @@ class PipelineLoader(): load_queue.put("EOF") logger.debug("Load Images: Complete") - def _reload(self, detected_faces: Dict[str, ExtractMedia]) -> None: + def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None: """ Reload the images and pair to detected face When the extraction pipeline is running in serial mode, images are reloaded from disk, @@ -652,7 +652,7 @@ class _Extract(): # pylint:disable=too-few-public-methods logger.debug("Initialized %s", self.__class__.__name__) @property - def _save_interval(self) -> Optional[int]: + def _save_interval(self) -> int | None: """ int: The number of frames to be processed between each saving of the alignments file if it has been provided, otherwise ``None`` """ if hasattr(self._args, "save_interval"): @@ -718,7 +718,7 @@ class _Extract(): # pylint:disable=too-few-public-methods as_bytes=True) for phase in range(self._extractor.passes): is_final = self._extractor.final_pass - detected_faces: Dict[str, ExtractMedia] = {} + detected_faces: dict[str, ExtractMedia] = {} self._extractor.launch() self._loader.check_thread_error() ph_desc = "Extraction" if self._extractor.passes == 1 else self._extractor.phase_text @@ -774,7 +774,7 @@ class _Extract(): # pylint:disable=too-few-public-methods if not self._verify_output and faces_count > 1: self._verify_output = True - def _output_faces(self, saver: Optional[ImagesSaver], extract_media: ExtractMedia) -> None: + def _output_faces(self, saver: ImagesSaver | None, extract_media: ExtractMedia) -> None: """ Output faces to save thread Set the face filename based on the frame name and put the face to the @@ -798,14 +798,14 @@ class _Extract(): # pylint:disable=too-few-public-methods output_filename = f"{filename}_{real_face_id}.png" aligned = face.aligned.face assert aligned is not None - meta: PNGHeaderDict = dict( - alignments=face.to_png_meta(), - source=dict(alignments_version=self._alignments.version, - original_filename=output_filename, - face_index=real_face_id, - source_filename=os.path.basename(extract_media.filename), - source_is_video=self._loader.is_video, - source_frame_dims=extract_media.image_size)) + meta: PNGHeaderDict = { + "alignments": face.to_png_meta(), + "source": {"alignments_version": self._alignments.version, + "original_filename": output_filename, + "face_index": real_face_id, + "source_filename": 
os.path.basename(extract_media.filename), + "source_is_video": self._loader.is_video, + "source_frame_dims": extract_media.image_size}} image = encode_image(aligned, ".png", metadata=meta) sub_folder = extract_media.sub_folders[face_id] @@ -820,6 +820,6 @@ class _Extract(): # pylint:disable=too-few-public-methods continue final_faces.append(face.to_alignment()) - self._alignments.data[os.path.basename(extract_media.filename)] = dict(faces=final_faces, - video_meta={}) + self._alignments.data[os.path.basename(extract_media.filename)] = {"faces": final_faces, + "video_meta": {}} del extract_media diff --git a/scripts/fsmedia.py b/scripts/fsmedia.py index 92e95eac..0b4ea078 100644 --- a/scripts/fsmedia.py +++ b/scripts/fsmedia.py @@ -5,12 +5,13 @@ Holds the classes for the 2 main Faceswap 'media' objects: Images and Alignments Holds optional pre/post processing functions for convert and extract. """ - +from __future__ import annotations import logging import os import sys -from typing import (Any, cast, Dict, Generator, Iterator, List, - Optional, Tuple, TYPE_CHECKING, Union) +import typing as T + +from collections.abc import Iterator import cv2 import numpy as np @@ -20,7 +21,8 @@ from lib.align import Alignments as AlignmentsBase, get_centered_size from lib.image import count_frames, read_image from lib.utils import (camel_case_split, get_image_paths, _video_extensions) -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from argparse import Namespace from lib.align import AlignedFace from plugins.extract.pipeline import ExtractMedia @@ -69,7 +71,7 @@ class Alignments(AlignmentsBase): Default: False """ def __init__(self, - arguments: "Namespace", + arguments: Namespace, is_extract: bool, input_is_video: bool = False) -> None: logger.debug("Initializing %s: (is_extract: %s, input_is_video: %s)", @@ -80,7 +82,7 @@ class Alignments(AlignmentsBase): super().__init__(folder, filename=filename) logger.debug("Initialized %s", self.__class__.__name__) - def _set_folder_filename(self, input_is_video: bool) -> Tuple[str, str]: + def _set_folder_filename(self, input_is_video: bool) -> tuple[str, str]: """ Return the folder and the filename for the alignments file. If the input is a video, the alignments file will be stored in the same folder @@ -115,7 +117,7 @@ class Alignments(AlignmentsBase): logger.debug("Setting Alignments: (folder: '%s' filename: '%s')", folder, filename) return folder, filename - def _load(self) -> Dict[str, Any]: + def _load(self) -> dict[str, T.Any]: """ Override the parent :func:`~lib.align.Alignments._load` to handle skip existing frames and faces on extract. 
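The Alignments hunk above documents that the alignments file lives next to the video when the input is a video, otherwise in the frames folder. A condensed reconstruction of that decision; the argument names and the exact "_alignments.fsa" naming convention are assumptions for illustration, not copied from the hidden method body.

    from __future__ import annotations

    import os

    def set_folder_filename(input_dir: str, alignments_path: str | None,
                            input_is_video: bool) -> tuple[str, str]:
        """Return (folder, filename) for the alignments file."""
        if alignments_path:                    # an explicitly supplied path always wins
            return os.path.dirname(alignments_path), os.path.basename(alignments_path)
        if input_is_video:
            # Store alongside the video, named after it (assumed convention)
            folder = os.path.dirname(input_dir)
            filename = f"{os.path.splitext(os.path.basename(input_dir))[0]}_alignments.fsa"
            return folder, filename
        return input_dir, "alignments.fsa"     # frame folder gets the default name

    print(set_folder_filename("/data/video.mp4", None, True))
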
@@ -128,7 +130,7 @@ class Alignments(AlignmentsBase): Any alignments that have already been extracted if skip existing has been selected otherwise an empty dictionary """ - data: Dict[str, Any] = {} + data: dict[str, T.Any] = {} if not self._is_extract and not self.have_alignments_file: return data if not self._is_extract: @@ -170,7 +172,7 @@ class Images(): arguments: :class:`argparse.Namespace` The command line arguments that were passed to Faceswap """ - def __init__(self, arguments: "Namespace") -> None: + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s", self.__class__.__name__) self._args = arguments self._is_video = self._check_input_folder() @@ -184,7 +186,7 @@ class Images(): return self._is_video @property - def input_images(self) -> Union[str, List[str]]: + def input_images(self) -> str | list[str]: """str or list: Path to the video file if the input is a video otherwise list of image paths. """ return self._input_images @@ -228,7 +230,7 @@ class Images(): retval = False return retval - def _get_input_images(self) -> Union[str, List[str]]: + def _get_input_images(self) -> str | list[str]: """ Return the list of images or path to video file that is to be processed. Returns @@ -243,7 +245,7 @@ class Images(): return input_images - def load(self) -> Generator[Tuple[str, np.ndarray], None, None]: + def load(self) -> Generator[tuple[str, np.ndarray], None, None]: """ Generator to load frames from a folder of images or from a video file. Yields @@ -257,7 +259,7 @@ class Images(): for filename, image in iterator(): yield filename, image - def _load_disk_frames(self) -> Generator[Tuple[str, np.ndarray], None, None]: + def _load_disk_frames(self) -> Generator[tuple[str, np.ndarray], None, None]: """ Generator to load frames from a folder of images. Yields @@ -274,7 +276,7 @@ class Images(): continue yield filename, image - def _load_video_frames(self) -> Generator[Tuple[str, np.ndarray], None, None]: + def _load_video_frames(self) -> Generator[tuple[str, np.ndarray], None, None]: """ Generator to load frames from a video file. Yields @@ -287,7 +289,7 @@ class Images(): logger.debug("Input is video. 
Capturing frames") vidname = os.path.splitext(os.path.basename(self._args.input_dir))[0] reader = imageio.get_reader(self._args.input_dir, "ffmpeg") # type:ignore[arg-type] - for i, frame in enumerate(cast(Iterator[np.ndarray], reader)): + for i, frame in enumerate(T.cast(Iterator[np.ndarray], reader)): # Convert to BGR for cv2 compatibility frame = frame[:, :, ::-1] filename = f"{vidname}_{i + 1:06d}.png" @@ -354,13 +356,13 @@ class PostProcess(): # pylint:disable=too-few-public-methods arguments: :class:`argparse.Namespace` The command line arguments that were passed to Faceswap """ - def __init__(self, arguments: "Namespace") -> None: + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s", self.__class__.__name__) self._args = arguments self._actions = self._set_actions() logger.debug("Initialized %s", self.__class__.__name__) - def _set_actions(self) -> List["PostProcessAction"]: + def _set_actions(self) -> list[PostProcessAction]: """ Compile the requested actions to be performed into a list Returns @@ -369,7 +371,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods The list of :class:`PostProcessAction` to be performed """ postprocess_items = self._get_items() - actions: List["PostProcessAction"] = [] + actions: list["PostProcessAction"] = [] for action, options in postprocess_items.items(): options = {} if options is None else options args = options.get("args", tuple()) @@ -387,7 +389,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods return actions - def _get_items(self) -> Dict[str, Optional[Dict[str, Union[tuple, dict]]]]: + def _get_items(self) -> dict[str, dict[str, tuple | dict] | None]: """ Check the passed in command line arguments for requested actions, For any requested actions, add the item to the actions list along with @@ -399,7 +401,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods The name of the action to be performed as the key. Any action specific arguments and keyword arguments as the value. """ - postprocess_items: Dict[str, Optional[Dict[str, Union[tuple, dict]]]] = {} + postprocess_items: dict[str, dict[str, tuple | dict] | None] = {} # Debug Landmarks if (hasattr(self._args, 'debug_landmarks') and self._args.debug_landmarks): postprocess_items["DebugLandmarks"] = None @@ -407,7 +409,7 @@ class PostProcess(): # pylint:disable=too-few-public-methods logger.debug("Postprocess Items: %s", postprocess_items) return postprocess_items - def do_actions(self, extract_media: "ExtractMedia") -> None: + def do_actions(self, extract_media: ExtractMedia) -> None: """ Perform the requested optional post-processing actions on the given image. 
Parameters @@ -451,7 +453,7 @@ class PostProcessAction(): # pylint: disable=too-few-public-methods otherwise ``False`` """ return self._valid - def process(self, extract_media: "ExtractMedia") -> None: + def process(self, extract_media: ExtractMedia) -> None: """ Override for specific post processing action Parameters @@ -487,8 +489,8 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho def _border_text(self, image: np.ndarray, text: str, - color: Tuple[int, int, int], - position: Tuple[int, int]) -> None: + color: tuple[int, int, int], + position: tuple[int, int]) -> None: """ Create text on an image with a black border Parameters @@ -515,7 +517,7 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho lineType=cv2.LINE_AA) thickness //= 2 - def _annotate_face_box(self, face: "AlignedFace") -> None: + def _annotate_face_box(self, face: AlignedFace) -> None: """ Annotate the face extract box and print the original size in pixels face: :class:`~lib.align.AlignedFace` @@ -543,7 +545,7 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho self._border_text(text_img, text, color, (pos_x, pos_y)) cv2.addWeighted(text_img, 0.75, face.face, 0.25, 0, face.face) - def _print_stats(self, face: "AlignedFace") -> None: + def _print_stats(self, face: AlignedFace) -> None: """ Print various metrics on the output face images Parameters @@ -571,7 +573,7 @@ class DebugLandmarks(PostProcessAction): # pylint: disable=too-few-public-metho # Apply text to face cv2.addWeighted(text_image, 0.75, face.face, 0.25, 0, face.face) - def process(self, extract_media: "ExtractMedia") -> None: + def process(self, extract_media: ExtractMedia) -> None: """ Draw landmarks on a face. Parameters diff --git a/scripts/train.py b/scripts/train.py index 62be0aa0..ddadba07 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,13 +1,13 @@ #!/usr/bin python3 """ Main entry point to the training process of FaceSwap """ - +from __future__ import annotations import logging import os import sys +import typing as T from time import sleep from threading import Event -from typing import cast, Callable, Dict, List, Optional, TYPE_CHECKING import cv2 import numpy as np @@ -21,13 +21,9 @@ from lib.utils import (get_folder, get_image_paths, FaceswapError, _image_extensions) from plugins.plugin_loader import PluginLoader -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: import argparse + from collections.abc import Callable from plugins.train.model._base import ModelBase from plugins.train.trainer._base import TrainerBase @@ -50,7 +46,7 @@ class Train(): # pylint:disable=too-few-public-methods The arguments to be passed to the training process as generated from Faceswap's command line arguments """ - def __init__(self, arguments: "argparse.Namespace") -> None: + def __init__(self, arguments: argparse.Namespace) -> None: logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments) self._args = arguments self._handle_deprecations() @@ -63,9 +59,9 @@ class Train(): # pylint:disable=too-few-public-methods self._timelapse = self._set_timelapse() gui_cache = os.path.join( os.path.realpath(os.path.dirname(sys.argv[0])), "lib", "gui", ".cache") - self._gui_triggers: Dict[Literal["mask", "refresh"], str] = dict( - mask=os.path.join(gui_cache, ".preview_mask_toggle"), - refresh=os.path.join(gui_cache, ".preview_trigger")) + self._gui_triggers: 
dict[T.Literal["mask", "refresh"], str] = { + "mask": os.path.join(gui_cache, ".preview_mask_toggle"), + "refresh": os.path.join(gui_cache, ".preview_trigger")} self._stop: bool = False self._save_now: bool = False self._preview = PreviewInterface(self._args.preview) @@ -76,7 +72,7 @@ class Train(): # pylint:disable=too-few-public-methods """ Handle the update of deprecated arguments and output warnings. """ return - def _get_images(self) -> Dict[Literal["a", "b"], List[str]]: + def _get_images(self) -> dict[T.Literal["a", "b"], list[str]]: """ Check the image folders exist and contains valid extracted faces. Obtain image paths. Returns @@ -88,7 +84,7 @@ class Train(): # pylint:disable=too-few-public-methods logger.debug("Getting image paths") images = {} for side in ("a", "b"): - side = cast(Literal["a", "b"], side) + side = T.cast(T.Literal["a", "b"], side) image_dir = getattr(self._args, f"input_{side}") if not os.path.isdir(image_dir): logger.error("Error: '%s' does not exist", image_dir) @@ -117,7 +113,7 @@ class Train(): # pylint:disable=too-few-public-methods return images @classmethod - def _validate_image_counts(cls, images: Dict[Literal["a", "b"], List[str]]) -> None: + def _validate_image_counts(cls, images: dict[T.Literal["a", "b"], list[str]]) -> None: """ Validate that there are sufficient images to commence training without raising an error. @@ -145,7 +141,7 @@ class Train(): # pylint:disable=too-few-public-methods "Results are likely to be poor.") logger.warning(msg) - def _set_timelapse(self) -> Dict[Literal["input_a", "input_b", "output"], str]: + def _set_timelapse(self) -> dict[T.Literal["input_a", "input_b", "output"], str]: """ Set time-lapse paths if requested. Returns @@ -168,7 +164,7 @@ class Train(): # pylint:disable=too-few-public-methods timelapse_output = get_folder(self._args.timelapse_output) for side in ("a", "b"): - side = cast(Literal["a", "b"], side) + side = T.cast(T.Literal["a", "b"], side) folder = getattr(self._args, f"timelapse_input_{side}") if folder is not None and not os.path.isdir(folder): raise FaceswapError(f"The Timelapse path '{folder}' does not exist") @@ -190,10 +186,10 @@ class Train(): # pylint:disable=too-few-public-methods raise FaceswapError(f"All images in the Timelapse folder '{folder}' must exist in " f"the training folder '{training_folder}'") - TKey = Literal["input_a", "input_b", "output"] - kwargs = {cast(TKey, "input_a"): self._args.timelapse_input_a, - cast(TKey, "input_b"): self._args.timelapse_input_b, - cast(TKey, "output"): timelapse_output} + TKey = T.Literal["input_a", "input_b", "output"] + kwargs = {T.cast(TKey, "input_a"): self._args.timelapse_input_a, + T.cast(TKey, "input_b"): self._args.timelapse_input_b, + T.cast(TKey, "output"): timelapse_output} logger.debug("Timelapse enabled: %s", kwargs) return kwargs @@ -274,7 +270,7 @@ class Train(): # pylint:disable=too-few-public-methods except Exception as err: raise err - def _load_model(self) -> "ModelBase": + def _load_model(self) -> ModelBase: """ Load the model requested for training. 
Returns @@ -284,7 +280,7 @@ class Train(): # pylint:disable=too-few-public-methods """ logger.debug("Loading Model") model_dir = get_folder(self._args.model_dir) - model: "ModelBase" = PluginLoader.get_model(self._args.trainer)( + model: ModelBase = PluginLoader.get_model(self._args.trainer)( model_dir, self._args, predict=False) @@ -292,7 +288,7 @@ class Train(): # pylint:disable=too-few-public-methods logger.debug("Loaded Model") return model - def _load_trainer(self, model: "ModelBase") -> "TrainerBase": + def _load_trainer(self, model: ModelBase) -> TrainerBase: """ Load the trainer requested for training. Parameters @@ -307,14 +303,14 @@ class Train(): # pylint:disable=too-few-public-methods """ logger.debug("Loading Trainer") base = PluginLoader.get_trainer(model.trainer) - trainer: "TrainerBase" = base(model, - self._images, - self._args.batch_size, - self._args.configfile) + trainer: TrainerBase = base(model, + self._images, + self._args.batch_size, + self._args.configfile) logger.debug("Loaded Trainer") return trainer - def _run_training_cycle(self, model: "ModelBase", trainer: "TrainerBase") -> None: + def _run_training_cycle(self, model: ModelBase, trainer: TrainerBase) -> None: """ Perform the training cycle. Handles the background training, updating previews/time-lapse on each save interval, @@ -330,7 +326,7 @@ class Train(): # pylint:disable=too-few-public-methods logger.debug("Running Training Cycle") update_preview_images = False if self._args.write_image or self._args.redirect_gui or self._args.preview: - display_func: Optional[Callable] = self._show + display_func: Callable | None = self._show else: display_func = None @@ -411,7 +407,7 @@ class Train(): # pylint:disable=too-few-public-methods self._save_now = True return retval - def _process_gui_triggers(self) -> Dict[Literal["mask", "refresh"], bool]: + def _process_gui_triggers(self) -> dict[T.Literal["mask", "refresh"], bool]: """ Check whether a file drop has occurred from the GUI to manually update the preview. 
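_process_gui_triggers above polls for trigger files that the GUI drops into a cache folder. The consume-and-clear mechanics are outside this hunk, so the body below is a guess at how such a poll typically works: report which trigger files exist and delete each one so it only fires once.

    import os
    import typing as T

    TriggerT = T.Literal["mask", "refresh"]

    def process_gui_triggers(trigger_files: dict[TriggerT, str]) -> dict[TriggerT, bool]:
        """Return which trigger files exist, removing each so it only fires once."""
        retval: dict[TriggerT, bool] = {key: False for key in trigger_files}
        for key, path in trigger_files.items():
            if os.path.isfile(path):
                os.remove(path)                # consume the trigger
                retval[key] = True
        return retval
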
Returns @@ -419,7 +415,8 @@ class Train(): # pylint:disable=too-few-public-methods dict The trigger name as key and boolean as value """ - retval: Dict[Literal["mask", "refresh"], bool] = {key: False for key in self._gui_triggers} + retval: dict[T.Literal["mask", "refresh"], bool] = {key: False + for key in self._gui_triggers} if not self._args.redirect_gui: return retval @@ -527,11 +524,11 @@ class PreviewInterface(): """ def __init__(self, use_preview: bool) -> None: self._active = use_preview - self._triggers: TriggerType = dict(toggle_mask=Event(), - refresh=Event(), - save=Event(), - quit=Event(), - shutdown=Event()) + self._triggers: TriggerType = {"toggle_mask": Event(), + "refresh": Event(), + "save": Event(), + "quit": Event(), + "shutdown": Event()} self._buffer = PreviewBuffer() self._thread = self._launch_thread() @@ -596,7 +593,7 @@ class PreviewInterface(): logger.debug("Sending should stop") return retval - def _launch_thread(self) -> Optional[FSThread]: + def _launch_thread(self) -> FSThread | None: """ Launch the preview viewer in it's own thread if preview has been selected Returns @@ -609,7 +606,7 @@ class PreviewInterface(): thread = FSThread(target=Preview, name="preview", args=(self._buffer, ), - kwargs=dict(triggers=self._triggers)) + kwargs={"triggers": self._triggers}) thread.start() return thread diff --git a/setup.cfg b/setup.cfg index 7dc02608..6427fca9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,8 +49,6 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-tensorflow.*] ignore_missing_imports = True -[mypy-tensorflow_probability.*] -ignore_missing_imports = True [mypy-tqdm.*] ignore_missing_imports = True [mypy-win32console.*] diff --git a/setup.py b/setup.py index e711f0d8..d17f4c8a 100755 --- a/setup.py +++ b/setup.py @@ -11,43 +11,44 @@ import operator import os import re import sys +import typing as T from shutil import which from subprocess import list2cmdline, PIPE, Popen, run, STDOUT -from typing import Any, Dict, List, Optional, Set, Tuple, Type -from pkg_resources import parse_requirements, Requirement +from pkg_resources import parse_requirements from lib.logger import log_setup -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - - logger = logging.getLogger(__name__) # pylint: disable=invalid-name +backend_type: T.TypeAlias = T.Literal['nvidia', 'apple_silicon', 'directml', 'cpu', 'rocm'] + _INSTALL_FAILED = False +# Packages that are explicitly required for setup.py +_INSTALLER_REQUIREMENTS: list[tuple[str, str]] = [("pexpect>=4.8.0", "!Windows"), + ("pywinpty==2.0.2", "Windows")] +# Conda packages that are required for a specific backend +_BACKEND_SPECIFIC_CONDA: dict[backend_type, list[str]] = {"nvidia": ["cudatoolkit", "cudnn"], + "apple_silicon": ["libblas"]} +# Packages that should only be installed through pip +_FORCE_PIP: dict[backend_type, list[str]] = {"nvidia": ["tensorflow"]} # Revisions of tensorflow GPU and cuda/cudnn requirements. 
These relate specifically to the # Tensorflow builds available from pypi -_TENSORFLOW_REQUIREMENTS = {">=2.7.0,<2.11.0": ["11.2", "8.1"]} +_TENSORFLOW_REQUIREMENTS = {">=2.10.0,<2.11.0": [">=11.0,<12.0", ">=8.0,<9.0"]} # ROCm min/max version requirements for Tensorflow _TENSORFLOW_ROCM_REQUIREMENTS = {">=2.10.0,<2.11.0": ((5, 2, 0), (5, 4, 0))} # TODO tensorflow-metal versioning -# Packages that are explicitly required for setup.py -_INSTALLER_REQUIREMENTS: List[Tuple[str, str]] = [("pexpect>=4.8.0", "!Windows"), - ("pywinpty==2.0.2", "Windows")] # Mapping of Python packages to their conda names if different from pip or in non-default channel -_CONDA_MAPPING: Dict[str, Tuple[str, str]] = { - # "opencv-python": ("opencv", "conda-forge"), # Periodic issues with conda-forge opencv +_CONDA_MAPPING: dict[str, tuple[str, str]] = { "fastcluster": ("fastcluster", "conda-forge"), + "ffmpy": ("ffmpy", "conda-forge"), "imageio-ffmpeg": ("imageio-ffmpeg", "conda-forge"), - "scikit-learn": ("scikit-learn", "conda-forge"), # Exists in Default but is dependency hell + "nvidia-ml-py": ("nvidia-ml-py", "conda-forge"), "tensorflow-deps": ("tensorflow-deps", "apple"), "libblas": ("libblas", "conda-forge")} -# Packages that should be installed first to prevent version conflicts -_PRIORITY = ["numpy"] +# Force output to utf-8 +sys.stdout.reconfigure(encoding="utf-8", errors="replace") # type:ignore[attr-defined] class Environment(): @@ -66,11 +67,10 @@ class Environment(): self.updater = updater # Flag that setup is being run by installer so steps can be skipped self.is_installer: bool = False - self.backend: Optional[Literal["nvidia", "apple_silicon", - "directml", "cpu", "rocm"]] = None + self.backend: backend_type | None = None self.enable_docker: bool = False self.cuda_cudnn = ["", ""] - self.rocm_version: Tuple[int, ...] = (0, 0, 0) + self.rocm_version: tuple[int, ...] = (0, 0, 0) self._process_arguments() self._check_permission() @@ -88,12 +88,12 @@ class Environment(): return locale.getpreferredencoding() @property - def os_version(self) -> Tuple[str, str]: + def os_version(self) -> tuple[str, str]: """ Get OS Version """ return platform.system(), platform.release() @property - def py_version(self) -> Tuple[str, str]: + def py_version(self) -> tuple[str, str]: """ Get Python Version """ return platform.python_version(), platform.architecture()[0] @@ -181,8 +181,8 @@ class Environment(): if self.updater: return - if not ((3, 7) <= sys.version_info < (3, 10) and self.py_version[1] == "64bit"): - logger.error("Please run this script with Python version 3.7 to 3.9 64bit and try " + if not ((3, 10) <= sys.version_info < (3, 11) and self.py_version[1] == "64bit"): + logger.error("Please run this script with Python version 3.10 64bit and try " "again.") sys.exit(1) @@ -292,22 +292,16 @@ class Packages(): """ def __init__(self, environment: Environment) -> None: self._env = environment - self._conda_required_packages: List[Tuple[str, ...]] = [("tk", )] - if self._env.os_version[0] == "Linux": - # TODO Put these kind of dependencies somewhere more visible or remove when not needed - # conda-forge scipy requires GLIBCXX_3.4.30. Some Linux install do not have the - # specific version, so we install it just in case. 
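The new _TENSORFLOW_REQUIREMENTS mapping above stores the Cuda and cuDNN constraints as requirement-style spec strings (">=11.0,<12.0"), which later get parsed with pkg_resources and compared against the detected versions. A self-contained illustration of that idea; validate_spec below is a simplified stand-in for Packages._validate_spec and ignores versions of unequal length.

    import operator

    from pkg_resources import parse_requirements

    _OPS = {"==": "eq", ">=": "ge", "<=": "le", "<": "lt", ">": "gt"}

    def validate_spec(required: list[tuple[str, str]], existing: str) -> bool:
        """True if the installed version string satisfies every (op, version) spec."""
        installed = [int(s) for s in existing.split(".")]
        return all(getattr(operator, _OPS[op])(installed, [int(s) for s in ver.split(".")])
                   for op, ver in required)

    cuda_req = next(parse_requirements("cuda>=11.0,<12.0")).specs   # [('>=', '11.0'), ('<', '12.0')]
    print(validate_spec(cuda_req, "11.2"))                          # True
    print(validate_spec(cuda_req, "12.1"))                          # False
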
- # Ref: https://forum.faceswap.dev/viewtopic.php?f=7&t=2247 - self._conda_required_packages.append(("gcc=12.1.0", "conda-forge")) - + self._conda_required_packages: list[tuple[str, ...]] = [("tk", ), ("git", )] + self._update_backend_specific_conda() self._installed_packages = self._get_installed_packages() self._conda_installed_packages = self._get_installed_conda_packages() - self._required_packages: List[Tuple[str, List[Tuple[str, str]]]] = [] - self._missing_packages: List[Tuple[str, List[Tuple[str, str]]]] = [] - self._conda_missing_packages: List[Tuple[str, ...]] = [] + self._required_packages: list[tuple[str, list[tuple[str, str]]]] = [] + self._missing_packages: list[tuple[str, list[tuple[str, str]]]] = [] + self._conda_missing_packages: list[tuple[str, ...]] = [] @property - def prerequisites(self) -> List[Tuple[str, List[Tuple[str, str]]]]: + def prerequisites(self) -> list[tuple[str, list[tuple[str, str]]]]: """ list: Any required packages that the installer needs prior to installing the faceswap environment on the specific platform that are not already installed """ all_installed = self._all_installed_packages @@ -328,23 +322,46 @@ class Packages(): return bool(self._missing_packages or self._conda_missing_packages) @property - def to_install(self) -> List[Tuple[str, List[Tuple[str, str]]]]: + def to_install(self) -> list[tuple[str, list[tuple[str, str]]]]: """ list: The required packages that need to be installed """ return self._missing_packages @property - def to_install_conda(self) -> List[Tuple[str, ...]]: + def to_install_conda(self) -> list[tuple[str, ...]]: """ list: The required conda packages that need to be installed """ return self._conda_missing_packages @property - def _all_installed_packages(self) -> Dict[str, str]: + def _all_installed_packages(self) -> dict[str, str]: """ dict[str, str]: The package names and version string for all installed packages across pip and conda """ return {**self._installed_packages, **self._conda_installed_packages} + def _update_backend_specific_conda(self) -> None: + """ Add backend specific packages to Conda required packages """ + assert self._env.backend is not None + to_add = _BACKEND_SPECIFIC_CONDA.get(self._env.backend) + if not to_add: + logger.debug("No backend packages to add for '%s'. 
All optional packages: %s", + self._env.backend, _BACKEND_SPECIFIC_CONDA) + return + for pkg in to_add: + pkg, channel = _CONDA_MAPPING.get(pkg, (pkg, "")) + if pkg in ("cudatoolkit", "cudnn"): # TODO Handle multiple cuda/cudnn requirements + idx = 0 if pkg == "cudatoolkit" else 1 + pkg = f"{pkg}{list(_TENSORFLOW_REQUIREMENTS.values())[0][idx]}" + if pkg.startswith("cudnn"): + # We add cudnn first so that dependency resolver does not need to re-download cuda + # if an incompatible version was installed + self._conda_required_packages.insert(0, (pkg, channel)) + else: + self._conda_required_packages.append((pkg, channel)) + logger.debug("Adding conda required package '%s' for backend '%s')", + pkg, self._env.backend) + @classmethod - def _format_requirements(cls, packages: List[str]) -> List[Tuple[str, List[Tuple[str, str]]]]: + def _format_requirements(cls, packages: list[str] + ) -> list[tuple[str, list[tuple[str, str]]]]: """ Parse a list of requirements.txt formatted package strings to a list of pkgresource formatted requirements """ return [(package.unsafe_name, package.specs) @@ -353,7 +370,7 @@ class Packages(): @classmethod def _validate_spec(cls, - required: List[Tuple[str, str]], + required: list[tuple[str, str]], existing: str) -> bool: """ Validate whether the required specification for a package is met by the installed version. @@ -377,7 +394,7 @@ class Packages(): [int(s) for s in spec[1].split(".")]) for spec in required) - def _get_installed_packages(self) -> Dict[str, str]: + def _get_installed_packages(self) -> dict[str, str]: """ Get currently installed packages and add to :attr:`_installed_packages` Returns @@ -398,7 +415,7 @@ class Packages(): logger.debug(installed_packages) return installed_packages - def _get_installed_conda_packages(self) -> Dict[str, str]: + def _get_installed_conda_packages(self) -> dict[str, str]: """ Get currently installed conda packages Returns @@ -439,31 +456,32 @@ class Packages(): if self._env.is_conda: # Conda handles Cuda and cuDNN so nothing to do here return tf_ver = None - cudnn_inst = self._env.cudnn_version.split(".") + cuda_inst = self._env.cuda_version + cudnn_inst = self._env.cudnn_version + if len(cudnn_inst) == 1: # Sometimes only major version is reported + cudnn_inst = f"{cudnn_inst}.0" for key, val in _TENSORFLOW_REQUIREMENTS.items(): - cuda_req = val[0] - cudnn_req = val[1].split(".") - if cuda_req == self._env.cuda_version and (cudnn_req[0] == cudnn_inst[0] and - cudnn_req[1] <= cudnn_inst[1]): + cuda_req = next(parse_requirements(f"cuda{val[0]}")).specs + cudnn_req = next(parse_requirements(f"cudnn{val[1]}")).specs + if (self._validate_spec(cuda_req, cuda_inst) + and self._validate_spec(cudnn_req, cudnn_inst)): tf_ver = key break + if tf_ver: # Remove the version of tensorflow in requirements file and add the correct version # that corresponds to the installed Cuda/cuDNN versions self._required_packages = [pkg for pkg in self._required_packages - if not pkg[0].startswith("tensorflow-gpu")] - tf_ver = f"tensorflow-gpu{tf_ver}" - - tf_ver = f"tensorflow-gpu{tf_ver}" - self._required_packages.append(("tensorflow-gpu", - next(parse_requirements(tf_ver)).specs)) + if pkg[0] != "tensorflow"] + tf_ver = f"tensorflow{tf_ver}" + self._required_packages.append(("tensorflow", next(parse_requirements(tf_ver)).specs)) return logger.warning( - "The minimum Tensorflow requirement is 2.8 \n" + "The minimum Tensorflow requirement is 2.10 \n" "Tensorflow currently has no official prebuild for your CUDA, cuDNN combination.\n" "Either install 
a combination that Tensorflow supports or build and install your own " - "tensorflow-gpu.\r\n" + "tensorflow.\r\n" "CUDA Version: %s\r\n" "cuDNN Version: %s\r\n" "Help:\n" @@ -472,8 +490,8 @@ class Packages(): "https://www.tensorflow.org/install/source#tested_build_configurations", self._env.cuda_version, self._env.cudnn_version) - custom_tf = input("Location of custom tensorflow-gpu wheel (leave " - "blank to manually install): ") + custom_tf = input("Location of custom tensorflow wheel (leave blank to manually " + "install): ") if not custom_tf: return @@ -529,14 +547,16 @@ class Packages(): if not self._env.is_conda: return for pkg in self._conda_required_packages: - key = pkg[0].split("==", maxsplit=1)[0] + reqs = next(parse_requirements(pkg[0])) # TODO Handle '=' vs '==' for conda + key = reqs.unsafe_name + specs = reqs.specs + if key not in self._conda_installed_packages: self._conda_missing_packages.append(pkg) continue - if len(pkg[0].split("==")) > 1: - if pkg[0].split("==")[1] != self._conda_installed_packages.get(key): - self._conda_missing_packages.append(pkg) - continue + + if not self._validate_spec(specs, self._conda_installed_packages[key]): + self._conda_missing_packages.append(pkg) logger.debug(self._conda_missing_packages) def check_missing_dependencies(self) -> None: @@ -554,12 +574,6 @@ class Packages(): if not self._validate_spec(specs, self._all_installed_packages.get(key, "")): self._missing_packages.append((key, specs)) - for priority in reversed(_PRIORITY): - # Put priority packages at beginning of list - package = next((pkg for pkg in self._missing_packages if pkg[0] == priority), None) - if package: - idx = self._missing_packages.index(package) - self._missing_packages.insert(0, self._missing_packages.pop(idx)) logger.debug(self._missing_packages) self._check_conda_missing_dependencies() @@ -732,7 +746,7 @@ class ROCmCheck(): # pylint:disable=too-few-public-methods def __init__(self) -> None: self.version_min = min(v[0] for v in _TENSORFLOW_ROCM_REQUIREMENTS.values()) self.version_max = max(v[1] for v in _TENSORFLOW_ROCM_REQUIREMENTS.values()) - self.rocm_version: Tuple[int, ...] = (0, 0, 0) + self.rocm_version: tuple[int, ...] = (0, 0, 0) if platform.system() == "Linux": self._rocm_check() @@ -771,15 +785,15 @@ class CudaCheck(): # pylint:disable=too-few-public-methods """ Find the location of system installed Cuda and cuDNN on Windows and Linux. """ def __init__(self) -> None: - self.cuda_path: Optional[str] = None - self.cuda_version: Optional[str] = None - self.cudnn_version: Optional[str] = None + self.cuda_path: str | None = None + self.cuda_version: str | None = None + self.cudnn_version: str | None = None self._os: str = platform.system().lower() - self._cuda_keys: List[str] = [key + self._cuda_keys: list[str] = [key for key in os.environ if key.lower().startswith("cuda_path_v")] - self._cudnn_header_files: List[str] = ["cudnn_version.h", "cudnn.h"] + self._cudnn_header_files: list[str] = ["cudnn_version.h", "cudnn.h"] logger.debug("cuda keys: %s, cudnn header files: %s", self._cuda_keys, self._cudnn_header_files) if self._os in ("windows", "linux"): @@ -839,13 +853,14 @@ class CudaCheck(): # pylint:disable=too-few-public-methods self.cuda_version = self._cuda_keys[0].lower().replace("cuda_path_v", "").replace("_", ".") self.cuda_path = os.environ[self._cuda_keys[0][0]] - def _cudnn_check(self): - """ Check Linux or Windows cuDNN Version from cudnn.h and add to :attr:`cudnn_version`. 
""" + def _cudnn_check_files(self) -> bool: + """ Check header files for cuDNN version """ cudnn_checkfiles = getattr(self, f"_get_checkfiles_{self._os}")() cudnn_checkfile = next((hdr for hdr in cudnn_checkfiles if os.path.isfile(hdr)), None) logger.debug("cudnn checkfiles: %s", cudnn_checkfile) if not cudnn_checkfile: - return + return False + found = 0 with open(cudnn_checkfile, "r", encoding="utf8") as ofile: for line in ofile: @@ -860,12 +875,31 @@ class CudaCheck(): # pylint:disable=too-few-public-methods found += 1 if found == 3: break - if found != 3: # Full version could not be determined - return + if found != 3: # Full version not determined + return False + self.cudnn_version = ".".join([str(major), str(minor), str(patchlevel)]) logger.debug("cudnn version: %s", self.cudnn_version) + return True - def _get_checkfiles_linux(self) -> List[str]: + def _cudnn_check(self) -> None: + """ Check Linux or Windows cuDNN Version from cudnn.h and add to :attr:`cudnn_version`. """ + if self._cudnn_check_files(): + return + if self._os == "windows": + return + + chk = os.popen("ldconfig -p | grep -P \"libcudnn.so.\" | head -n 1").read() + if not chk: + return + cudnnvers = chk.strip().replace("libcudnn.so.", "").split()[0] + if not cudnnvers: + return + + self.cudnn_version = cudnnvers + logger.debug("cudnn version: %s", self.cudnn_version) + + def _get_checkfiles_linux(self) -> list[str]: """ Return the the files to check for cuDNN locations for Linux by querying the dynamic link loader. @@ -887,7 +921,7 @@ class CudaCheck(): # pylint:disable=too-few-public-methods cudnn_checkfiles = [os.path.join(cudnn_path, header) for header in header_files] return cudnn_checkfiles - def _get_checkfiles_windows(self) -> List[str]: + def _get_checkfiles_windows(self) -> list[str]: """ Return the check-file locations for Windows. Just looks inside the include folder of the discovered :attr:`cuda_path` @@ -921,7 +955,7 @@ class Install(): # pylint:disable=too-few-public-methods self._is_gui = is_gui if self._env.os_version[0] == "Windows": - self._installer: Type[Installer] = WinPTYInstaller + self._installer: type[Installer] = WinPTYInstaller else: self._installer = PexpectInstaller @@ -964,7 +998,7 @@ class Install(): # pylint:disable=too-few-public-methods sys.exit(1) @classmethod - def _format_package(cls, package: str, version: List[Tuple[str, str]]) -> str: + def _format_package(cls, package: str, version: list[tuple[str, str]]) -> str: """ Format a parsed requirement package and version string to a format that can be used by the installer. @@ -1006,51 +1040,37 @@ class Install(): # pylint:disable=too-few-public-methods logger.error("Unable to install package: %s. Process aborted", clean_pkg) sys.exit(1) + def _install_conda_packages(self) -> None: + """ Install required conda packages """ + logger.info("Installing Required Conda Packages. 
This may take some time...") + for pkg in self._packages.to_install_conda: + channel = "" if len(pkg) != 2 else pkg[1] + self._from_conda(pkg[0], channel=channel, conda_only=True) + + def _install_python_packages(self) -> None: + """ Install required pip packages """ + conda_only = False + assert self._env.backend is not None + for pkg, version in self._packages.to_install: + if self._env.is_conda: + mapping = _CONDA_MAPPING.get(pkg, (pkg, "")) + channel = "" if mapping[1] is None else mapping[1] + pkg = mapping[0] + pip_only = pkg in _FORCE_PIP.get(self._env.backend, []) + pkg = self._format_package(pkg, version) if version else pkg + if self._env.is_conda and not pip_only: + if self._from_conda(pkg, channel=channel, conda_only=conda_only): + continue + self._from_pip(pkg) + def _install_missing_dep(self) -> None: """ Install missing dependencies """ self._install_conda_packages() # Install conda packages first self._install_python_packages() - def _install_python_packages(self) -> None: - """ Install required pip packages """ - conda_only = False - for pkg, version in self._packages.to_install: - if self._env.is_conda: - mapping = _CONDA_MAPPING.get(pkg, (pkg, "")) - channel = None if mapping[1] == "" else mapping[1] - pkg = mapping[0] - pkg = self._format_package(pkg, version) if version else pkg - if self._env.is_conda: - if pkg.startswith("tensorflow-gpu"): - # From TF 2.4 onwards, Anaconda Tensorflow becomes a mess. The version of 2.5 - # installed by Anaconda is compiled against an incorrect numpy version which - # breaks Tensorflow. Coupled with this the versions of cudatoolkit and cudnn - # available in the default Anaconda channel are not compatible with the - # official PyPi versions of Tensorflow. With this in mind we will pull in the - # required Cuda/cuDNN from conda-forge, and install Tensorflow with pip - # TODO Revert to Conda if they get their act together - - # Rewrite tensorflow requirement to versions from highest available cuda/cudnn - highest_cuda = sorted(_TENSORFLOW_REQUIREMENTS.values())[-1] - compat_tf = next(k for k, v in _TENSORFLOW_REQUIREMENTS.items() - if v == highest_cuda) - pkg = f"tensorflow-gpu{compat_tf}" - conda_only = True - - if self._from_conda(pkg, channel=channel, conda_only=conda_only): - continue - self._from_pip(pkg) - - def _install_conda_packages(self) -> None: - """ Install required conda packages """ - logger.info("Installing Required Conda Packages. This may take some time...") - for pkg in self._packages.to_install_conda: - channel = None if len(pkg) != 2 else pkg[1] - self._from_conda(pkg[0], channel=channel, conda_only=True) - def _from_conda(self, package: str, - channel: Optional[str] = None, + channel: str = "", conda_only: bool = False) -> bool: """ Install a conda package @@ -1059,8 +1079,8 @@ class Install(): # pylint:disable=too-few-public-methods package: str The full formatted package, with version, to be installed channel: str, optional - The Conda channel to install from. Select ``None`` for default channel. - Default: ``None`` + The Conda channel to install from. Select empty string for default channel. + Default: ``""`` (empty string) conda_only: bool, optional ``True`` if the package is only available in Conda. 
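# --- Illustrative sketch (not part of the patch) -----------------------------------
# What the ldconfig fallback added to CudaCheck._cudnn_check above is parsing when no
# cuDNN header file can be found. The sample line is made up; real "ldconfig -p"
# output is system dependent. Only the major version can be recovered this way, which
# is why the rewritten _update_tf_dep pads a single-component cuDNN version with ".0".
sample = "\tlibcudnn.so.8 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libcudnn.so.8"
cudnn_version = sample.strip().replace("libcudnn.so.", "").split()[0]
assert cudnn_version == "8"
# ------------------------------------------------------------------------------------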
Default: ``False`` @@ -1075,23 +1095,9 @@ class Install(): # pylint:disable=too-few-public-methods if channel: condaexe.extend(["-c", channel]) - if package.startswith("tensorflow-gpu"): - # Here we will install the cuda/cudnn toolkits, currently only available from - # conda-forge, but fail tensorflow itself so that it can be handled by pip. - specs = Requirement.parse(package).specs - for key, val in _TENSORFLOW_REQUIREMENTS.items(): - req_specs = Requirement.parse("foobar" + key).specs - if all(item in req_specs for item in specs): - cuda, cudnn = val - break - condaexe.extend(["-c", "conda-forge", f"cudatoolkit={cuda}", f"cudnn={cudnn}"]) - package = "Cuda Toolkit" - success = False - - if package != "Cuda Toolkit": - if any(char in package for char in (" ", "<", ">", "*", "|")): - package = f"\"{package}\"" - condaexe.append(package) + if any(char in package for char in (" ", "<", ">", "*", "|")): + package = f"\"{package}\"" + condaexe.append(package) clean_pkg = package.replace("\"", "") installer = self._installer(self._env, clean_pkg, condaexe, self._is_gui) @@ -1127,6 +1133,104 @@ class Install(): # pylint:disable=too-few-public-methods _INSTALL_FAILED = True +class ProgressBar(): + """ Simple progress bar using STDLib for intercepting Conda installs and keeping the + terminal from getting jumbled """ + def __init__(self): + self._width_desc = 21 + self._width_size = 9 + self._width_bar = 37 + self._width_pct = 4 + self._marker = "█" + + self._cursor_visible = True + self._current_pos = 0 + self._bars = [] + + @classmethod + def _display_cursor(cls, visible: bool) -> None: + """ Sends ANSI code to display or hide the cursor + + Parameters + ---------- + visible: bool + ``True`` to display the cursor. ``False`` to hide the cursor + """ + code = "\x1b[?25h" if visible else "\x1b[?25l" + print(code, end="\r") + + def _format_bar(self, description: str, size: str, percent: int) -> str: + """ Format the progress bar for display + + Parameters + ---------- + description: str + The description to display for the progress bar + size: str + The size of the download, including units + percent: int + The percentage progress of the bar + """ + size = size[:self._width_size].ljust(self._width_size) + bar_len = int(self._width_bar * (percent / 100)) + progress = f"{self._marker * bar_len}"[:self._width_bar].ljust(self._width_bar) + pct = f"{percent}%"[:self._width_pct].rjust(self._width_pct) + return f" {description}| {size} | {progress} | {pct}" + + def _move_cursor(self, position: int) -> str: + """ Generate ANSI code for moving the cursor to the given progress bar's position + + Parameters + ---------- + position: int + The progress bar position to move to + + Returns + ------- + str + The ansi code to move to the given position + """ + move = position - self._current_pos + retval = "\x1b[A" if move < 0 else "\x1b[B" if move > 0 else "" + retval *= abs(move) + return retval + + def __call__(self, description: str, size: str, percent: int) -> None: + """ Create or update a progress bar + + Parameters + ---------- + description: str + The description to display for the progress bar + size: str + The size of the download, including units + percent: int + The percentage progress of the bar + """ + if self._cursor_visible: + self._display_cursor(visible=False) + + desc = description[:self._width_desc].ljust(self._width_desc) + if desc not in self._bars: + self._bars.append(desc) + + position = self._bars.index(desc) + pbar = self._format_bar(desc, size, percent) + + output = 
f"{self._move_cursor(position)} {pbar}" + + print(output) + self._current_pos = position + 1 + + def close(self) -> None: + """ Reset all progress bars and re-enable the cursor """ + print(self._move_cursor(len(self._bars)), end="\r") + self._display_cursor(True) + self._cursor_visible = True + self._current_pos = 0 + self._bars = [] + + class Installer(): """ Parent class for package installers. @@ -1150,16 +1254,23 @@ class Installer(): def __init__(self, environment: Environment, package: str, - command: List[str], + command: list[str], is_gui: bool) -> None: logger.info("Installing %s", package) logger.debug("argv: %s", command) self._env = environment self._package = package self._command = command + self._is_conda = "conda" in command self._is_gui = is_gui - self._last_line_cr = False - self._seen_lines: Set[str] = set() + + self._progess_bar = ProgressBar() + self._re_conda = re.compile( + rb"(?P^\S+)\s+\|\s+(?P\d+\.?\d*\s\w+).*\|\s+(?P\d+%)") + self._re_pip_pkg = re.compile(rb"^\s*Downloading\s(?P\w+-.+?)-") + self._re_pip = re.compile(rb"(?P\d+\.?\d*)/(?P\d+\.?\d*\s\w+)") + self._pip_pkg = "" + self._seen_lines: set[str] = set() def __call__(self) -> int: """ Call the subclassed call function @@ -1174,9 +1285,11 @@ class Installer(): except Exception as err: # pylint:disable=broad-except logger.debug("Failed to install with %s. Falling back to subprocess. Error: %s", self.__class__.__name__, str(err)) + self._progess_bar.close() returncode = SubProcInstaller(self._env, self._package, self._command, self._is_gui)() logger.debug("Package: %s, returncode: %s", self._package, returncode) + self._progess_bar.close() return returncode def call(self) -> int: @@ -1189,19 +1302,57 @@ class Installer(): """ raise NotImplementedError() - def _non_gui_print(self, text: str, end: Optional[str] = None) -> None: + def _print_conda(self, text: bytes) -> None: + """ Output progress for Conda installs + + Parameters + ---------- + text: bytes + The text to print + """ + data = self._re_conda.match(text) + if not data: + return + lib = data.groupdict()["lib"].decode("utf-8", errors="replace") + size = data.groupdict()["tot"].decode("utf-8", errors="replace") + progress = int(data.groupdict()["prg"].decode("utf-8", errors="replace")[:-1]) + self._progess_bar(lib, size, progress) + + def _print_pip(self, text: bytes) -> None: + """ Output progress for Pip installs + + Parameters + ---------- + text: bytes + The text to print + """ + pkg = self._re_pip_pkg.match(text) + if pkg: + logger.debug("Collected pip package '%s'", pkg) + self._pip_pkg = pkg.groupdict()["lib"].decode("utf-8", errors="replace") + return + data = self._re_pip.search(text) + if not data: + return + done = float(data.groupdict()["done"].decode("utf-8", errors="replace")) + size = data.groupdict()["tot"].decode("utf-8", errors="replace") + progress = int(round(done / float(size.split()[0]) * 100, 0)) + self._progess_bar(self._pip_pkg, size, progress) + + def _non_gui_print(self, text: bytes) -> None: """ Print output to console if not running in the GUI Parameters ---------- - text: str + text: bytes The text to print - end: str, optional - The line ending to use. Default: ``None`` (new line) """ if self._is_gui: return - print(text, end=end) + if self._is_conda: + self._print_conda(text) + else: + self._print_pip(text) def _seen_line_log(self, text: str) -> None: """ Output gets spammed to the log file when conda is waiting/processing. 
Only log each @@ -1214,7 +1365,7 @@ class Installer(): """ if text in self._seen_lines: return - logger.verbose(text) # type:ignore + logger.debug(text) self._seen_lines.add(text) @@ -1243,22 +1394,13 @@ class PexpectInstaller(Installer): # pylint: disable=too-few-public-methods The return code of the package install process """ import pexpect # pylint:disable=import-outside-toplevel,import-error - proc = pexpect.spawn(" ".join(self._command), - encoding=self._env.encoding, codec_errors="replace", timeout=None) + proc = pexpect.spawn(" ".join(self._command), timeout=None) while True: try: - idx = proc.expect(["\r\n", "\r"]) - line = proc.before.rstrip() - if line and idx == 0: - if self._last_line_cr: - self._last_line_cr = False - # Output last line of progress bar and go to next line - self._non_gui_print(line) - self._seen_line_log(line) - elif line and idx == 1: - self._last_line_cr = True - logger.debug(line) - self._non_gui_print(line, end="\r") + proc.expect([b"\r\n", b"\r"]) + line: bytes = proc.before + self._seen_line_log(line.decode("utf-8", errors="replace").rstrip()) + self._non_gui_print(line) except pexpect.EOF: break proc.close() @@ -1284,7 +1426,7 @@ class WinPTYInstaller(Installer): # pylint: disable=too-few-public-methods def __init__(self, environment: Environment, package: str, - command: List[str], + command: list[str], is_gui: bool) -> None: super().__init__(environment, package, command, is_gui) self._cmd = which(command[0], path=os.environ.get('PATH', os.defpath)) @@ -1295,10 +1437,10 @@ class WinPTYInstaller(Installer): # pylint: disable=too-few-public-methods self._eof = False self._read_bytes = 1024 - self._lines: List[str] = [] + self._lines: list[str] = [] self._out = "" - def _read_from_pty(self, proc: Any, winpty_error: Any) -> None: + def _read_from_pty(self, proc: T.Any, winpty_error: T.Any) -> None: """ Read :attr:`_num_bytes` from WinPTY. If there is an error reading, recursively halve the number of bytes read until we get a succesful read. If we get down to 1 byte without a succesful read, assume we are at EOF. @@ -1350,25 +1492,6 @@ class WinPTYInstaller(Installer): # pylint: disable=too-few-public-methods self._out = self._lines[-1] self._lines = self._lines[:-1] - def _parse_lines(self) -> None: - """ Process the latest batch of lines that have been received from winPTY. """ - for line in self._lines: # Dump the output to log - line = line.rstrip() - is_cr = bool(self._pbar.search(line)) - if line and not is_cr: - if self._last_line_cr: - self._last_line_cr = False - if not self._env.is_installer: - # Go to next line - self._non_gui_print("") - self._seen_line_log(line) - elif line: - self._last_line_cr = True - logger.debug(line) - # NSIS only updates on line endings, so force new line for installer - self._non_gui_print(line, end=None if self._env.is_installer else "\r") - self._lines = [] - def call(self) -> int: """ Install a package using the PyWinPTY module @@ -1380,7 +1503,7 @@ class WinPTYInstaller(Installer): # pylint: disable=too-few-public-methods import winpty # pylint:disable=import-outside-toplevel,import-error # For some reason with WinPTY we need to pass in the full command. 
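# --- Illustrative sketch (not part of the patch) -----------------------------------
# The installers above hand raw console bytes to _non_gui_print, which routes them
# through the regular expressions added to Installer.__init__. The patterns below are
# copies of those expressions; the conda and pip output lines are mocked up.
import re

re_conda = re.compile(rb"(?P<lib>^\S+)\s+\|\s+(?P<tot>\d+\.?\d*\s\w+).*\|\s+(?P<prg>\d+%)")
re_pip_pkg = re.compile(rb"^\s*Downloading\s(?P<lib>\w+-.+?)-")
re_pip = re.compile(rb"(?P<done>\d+\.?\d*)/(?P<tot>\d+\.?\d*\s\w+)")

conda_line = b"cudatoolkit-11.2.2   | 631.7 MB  | ########3  |  83%"
hit = re_conda.match(conda_line)
assert hit is not None and hit.group("lib") == b"cudatoolkit-11.2.2"
assert hit.group("tot") == b"631.7 MB" and hit.group("prg") == b"83%"

pip_name = re_pip_pkg.match(b"  Downloading numpy-1.23.5-cp310-cp310-win_amd64.whl (14.6 MB)")
assert pip_name is not None and pip_name.group("lib") == b"numpy-1.23.5"

pip_line = re_pip.search(b"   ---- 3.2/14.6 MB 5.1 MB/s eta 0:00:03")
assert pip_line is not None and pip_line.group("done") == b"3.2"
assert pip_line.group("tot") == b"14.6 MB"  # roughly 22% complete
# ------------------------------------------------------------------------------------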
Probably a bug proc = winpty.PTY( - 80 if self._env.is_installer else 100, + 100, 24, backend=winpty.enums.Backend.WinPTY, # ConPTY hangs and has lots of Ansi Escapes agent_config=winpty.enums.AgentConfig.WINPTY_FLAG_PLAIN_OUTPUT) # Strip all Ansi @@ -1392,7 +1515,10 @@ class WinPTYInstaller(Installer): # pylint: disable=too-few-public-methods while True: self._read_from_pty(proc, winpty.WinptyError) self._out_to_lines() - self._parse_lines() + for line in self._lines: + self._seen_line_log(line.rstrip()) + self._non_gui_print(line.encode("utf-8", errors="replace")) + self._lines = [] if self._eof: returncode = proc.get_exitstatus() @@ -1422,7 +1548,7 @@ class SubProcInstaller(Installer): def __init__(self, environment: Environment, package: str, - command: List[str], + command: list[str], is_gui: bool) -> None: super().__init__(environment, package, command, is_gui) self._shell = self._env.os_version[0] == "Windows" and command[0] == "conda" @@ -1445,24 +1571,15 @@ class SubProcInstaller(Installer): bufsize=0, stdout=PIPE, stderr=STDOUT, shell=self._shell) as proc: while True: if proc.stdout is not None: - line = proc.stdout.readline().decode(self._env.encoding, errors="replace") + lines = proc.stdout.readline() returncode = proc.poll() - if line == "" and returncode is not None: + if lines == b"" and returncode is not None: break - is_cr = line.startswith("\r") - line = line.rstrip() + for line in lines.split(b"\r"): + self._seen_line_log(line.decode("utf-8", errors="replace").rstrip()) + self._non_gui_print(line) - if line and not is_cr: - if self._last_line_cr: - self._last_line_cr = False - # Go to next line - self._non_gui_print("") - self._seen_line_log(line) - elif line: - self._last_line_cr = True - logger.debug(line) - self._non_gui_print("", end="\r") return returncode @@ -1471,74 +1588,45 @@ class Tips(): @classmethod def docker_no_cuda(cls) -> None: """ Output Tips for Docker without Cuda """ - path = os.path.dirname(os.path.realpath(__file__)) logger.info( - "1. Install Docker\n" - "https://www.docker.com/community-edition\n\n" - "2. Build Docker Image For Faceswap\n" - "docker build -t deepfakes-cpu -f Dockerfile.cpu .\n\n" - "3. Mount faceswap volume and Run it\n" - "# without GUI\n" - "docker run -tid -p 8888:8888 \\ \n" - "\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n" - "\t-v %s:/srv \\ \n" - "\tdeepfakes-cpu\n\n" - "# with gui. tools.py gui working.\n" - "## enable local access to X11 server\n" - "xhost +local:\n" - "## create container\n" - "nvidia-docker run -tid -p 8888:8888 \\ \n" - "\t--hostname deepfakes-cpu --name deepfakes-cpu \\ \n" - "\t-v %s:/srv \\ \n" - "\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n" - "\t-e DISPLAY=unix$DISPLAY \\ \n" - "\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n" - "\t-e VIDEO_GID=`getent group video | cut -d: -f3` \\ \n" - "\t-e GID=`id -g` \\ \n" - "\t-e UID=`id -u` \\ \n" - "\tdeepfakes-cpu \n\n" - "4. Open a new terminal to run faceswap.py in /srv\n" - "docker exec -it deepfakes-cpu bash", path, path) - logger.info("That's all you need to do with a docker. Have fun.") + "1. Install Docker from: https://www.docker.com/get-started\n\n" + "2. Enter the Faceswap folder and build the Docker Image For Faceswap:\n" + " docker build -t faceswap-cpu -f Dockerfile.cpu .\n\n" + "3. Launch and enter the Faceswap container:\n" + " a. Headless:\n" + " docker run --rm -it -v ./:/srv faceswap-cpu\n\n" + " b. 
GUI:\n" + " xhost +local: && \\ \n" + " docker run --rm -it \\ \n" + " -v ./:/srv \\ \n" + " -v /tmp/.X11-unix:/tmp/.X11-unix \\ \n" + " -e DISPLAY=${DISPLAY} \\ \n" + " faceswap-cpu \n") + logger.info("That's all you need to do with docker. Have fun.") @classmethod def docker_cuda(cls) -> None: """ Output Tips for Docker with Cuda""" - path = os.path.dirname(os.path.realpath(__file__)) logger.info( - "1. Install Docker\n" - "https://www.docker.com/community-edition\n\n" - "2. Install latest CUDA\n" - "CUDA: https://developer.nvidia.com/cuda-downloads\n\n" - "3. Install Nvidia-Docker & Restart Docker Service\n" - "https://github.com/NVIDIA/nvidia-docker\n\n" - "4. Build Docker Image For Faceswap\n" - "docker build -t deepfakes-gpu -f Dockerfile.gpu .\n\n" - "5. Mount faceswap volume and Run it\n" - "# without gui \n" - "docker run -tid -p 8888:8888 \\ \n" - "\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n" - "\t-v %s:/srv \\ \n" - "\tdeepfakes-gpu\n\n" - "# with gui.\n" - "## enable local access to X11 server\n" - "xhost +local:\n" - "## enable nvidia device if working under bumblebee\n" - "echo ON > /proc/acpi/bbswitch\n" - "## create container\n" - "nvidia-docker run -tid -p 8888:8888 \\ \n" - "\t--hostname deepfakes-gpu --name deepfakes-gpu \\ \n" - "\t-v %s:/srv \\ \n" - "\t-v /tmp/.X11-unix:/tmp/.X11-unix \\ \n" - "\t-e DISPLAY=unix$DISPLAY \\ \n" - "\t-e AUDIO_GID=`getent group audio | cut -d: -f3` \\ \n" - "\t-e VIDEO_GID=`getent group video | cut -d: -f3` \\ \n" - "\t-e GID=`id -g` \\ \n" - "\t-e UID=`id -u` \\ \n" - "\tdeepfakes-gpu\n\n" - "6. Open a new terminal to interact with the project\n" - "docker exec deepfakes-gpu python /srv/faceswap.py gui\n", - path, path) + "1. Install Docker from: https://www.docker.com/get-started\n\n" + "2. Install latest CUDA 11 and cuDNN 8 from: https://developer.nvidia.com/cuda-" + "downloads\n\n" + "3. Install the the Nvidia Container Toolkit from https://docs.nvidia.com/datacenter/" + "cloud-native/container-toolkit/latest/install-guide\n\n" + "4. Restart Docker Service\n\n" + "5. Enter the Faceswap folder and build the Docker Image For Faceswap:\n" + " docker build -t faceswap-gpu -f Dockerfile.gpu .\n\n" + "6. Launch and enter the Faceswap container:\n" + " a. Headless:\n" + " docker run --runtime=nvidia --rm -it -v ./:/srv faceswap-gpu\n\n" + " b. GUI:\n" + " xhost +local: && \\ \n" + " docker run --runtime=nvidia --rm -it \\ \n" + " -v ./:/srv \\ \n" + " -v /tmp/.X11-unix:/tmp/.X11-unix \\ \n" + " -e DISPLAY=${DISPLAY} \\ \n" + " faceswap-gpu \n") + logger.info("That's all you need to do with docker. 
Have fun.") @classmethod def macos(cls) -> None: diff --git a/tests/lib/gpu_stats/_base_test.py b/tests/lib/gpu_stats/_base_test.py index 97485f51..225d11f4 100644 --- a/tests/lib/gpu_stats/_base_test.py +++ b/tests/lib/gpu_stats/_base_test.py @@ -1,7 +1,8 @@ #!/usr/bin python3 """ Pytest unit tests for :mod:`lib.gpu_stats._base` """ +import typing as T + from dataclasses import dataclass -from typing import cast from unittest.mock import MagicMock import pytest @@ -71,8 +72,8 @@ def test__gpu_stats_init_(gpu_stats_instance: _GPUStats) -> None: """ # Ensure that the object is initialized and shutdown correctly assert gpu_stats_instance._is_initialized is False - assert cast(MagicMock, gpu_stats_instance._initialize).call_count == 1 - assert cast(MagicMock, gpu_stats_instance._shutdown).call_count == 1 + assert T.cast(MagicMock, gpu_stats_instance._initialize).call_count == 1 + assert T.cast(MagicMock, gpu_stats_instance._shutdown).call_count == 1 # Ensure that the object correctly gets and stores the device count, active devices, # handles, driver, device names, and VRAM information diff --git a/tests/lib/gui/stats/event_reader_test.py b/tests/lib/gui/stats/event_reader_test.py index 618d1e22..21679055 100644 --- a/tests/lib/gui/stats/event_reader_test.py +++ b/tests/lib/gui/stats/event_reader_test.py @@ -1,13 +1,13 @@ #!/usr/bin python3 """ Pytest unit tests for :mod:`lib.gui.stats.event_reader` """ # pylint:disable=protected-access - +from __future__ import annotations import json import os +import typing as T from shutil import rmtree from time import time -from typing import cast, Iterator from unittest.mock import MagicMock import numpy as np @@ -20,6 +20,9 @@ from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module from lib.gui.analysis.event_reader import (_Cache, _CacheData, _EventParser, _LogFiles, EventData, TensorBoardLogs) +if T.TYPE_CHECKING: + from collections.abc import Iterator + def test__logfiles(tmp_path: str): """ Test the _LogFiles class operates correctly @@ -627,9 +630,9 @@ class Test_EventParser: # pylint:disable=invalid-name monkeypatch.setattr("lib.utils._FS_BACKEND", "cpu") event_parse = event_parser_instance - event_parse._parse_outputs = cast(MagicMock, mocker.MagicMock()) # type:ignore - event_parse._process_event = cast(MagicMock, mocker.MagicMock()) # type:ignore - event_parse._cache.cache_data = cast(MagicMock, mocker.MagicMock()) # type:ignore + event_parse._parse_outputs = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + event_parse._process_event = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + event_parse._cache.cache_data = T.cast(MagicMock, mocker.MagicMock()) # type:ignore # keras model monkeypatch.setattr(event_parse, diff --git a/tests/lib/model/optimizers_test.py b/tests/lib/model/optimizers_test.py index 3a34a605..34d53358 100644 --- a/tests/lib/model/optimizers_test.py +++ b/tests/lib/model/optimizers_test.py @@ -70,13 +70,6 @@ def _test_optimizer(optimizer, target=0.75): assert_allclose(bias, 2.) 
-@pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()]) -def test_adam(dummy): # pylint:disable=unused-argument - """ Test for custom Adam optimizer """ - _test_optimizer(k_optimizers.Adam(), target=0.45) # pylint:disable=no-member - _test_optimizer(k_optimizers.Adam(decay=1e-3), target=0.45) # pylint:disable=no-member - - @pytest.mark.parametrize("dummy", [None], ids=[get_backend().upper()]) def test_adabelief(dummy): # pylint:disable=unused-argument """ Test for custom Adam optimizer """ diff --git a/tests/lib/sysinfo_test.py b/tests/lib/sysinfo_test.py index 5cdce977..215e8c9f 100644 --- a/tests/lib/sysinfo_test.py +++ b/tests/lib/sysinfo_test.py @@ -5,10 +5,10 @@ import locale import os import platform import sys +import typing as T from collections import namedtuple from io import StringIO -from typing import cast from unittest.mock import MagicMock import pytest @@ -258,8 +258,8 @@ def test__configs__parse_configs(configs_instance: _Configs, """ assert hasattr(configs_instance, "_parse_configs") assert isinstance(configs_instance._parse_configs([]), str) - configs_instance._parse_ini = cast(MagicMock, mocker.MagicMock()) # type:ignore - configs_instance._parse_json = cast(MagicMock, mocker.MagicMock()) # type:ignore + configs_instance._parse_ini = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + configs_instance._parse_json = T.cast(MagicMock, mocker.MagicMock()) # type:ignore configs_instance._parse_configs(config_files=["test.ini", ".faceswap"]) assert configs_instance._parse_ini.called assert configs_instance._parse_json.called diff --git a/tests/lib/utils_test.py b/tests/lib/utils_test.py index 092c3a5f..34f9be5f 100644 --- a/tests/lib/utils_test.py +++ b/tests/lib/utils_test.py @@ -1,14 +1,15 @@ #!/usr/bin python3 """ Pytest unit tests for :mod:`lib.utils` """ import os +import platform import time +import typing as T import warnings import zipfile from io import StringIO from socket import timeout as socket_timeout, error as socket_error from shutil import rmtree -from typing import Any, cast, List, Tuple, Union from unittest.mock import MagicMock from urllib import error as urlliberror @@ -160,7 +161,7 @@ _PARAMS = [("/path/to/file.txt", ["/", "path", "to", "file.txt"]), # Absolute @pytest.mark.parametrize("path,result", _PARAMS, ids=[f'"{p[0]}"' for p in _PARAMS]) -def test_full_path_split(path: str, result: List[str]) -> None: +def test_full_path_split(path: str, result: list[str]) -> None: """ Test the :func:`~lib.utils.full_path_split` function works correctly Parameters @@ -188,7 +189,7 @@ _PARAMS = [("camelCase", ["camel", "Case"]), @pytest.mark.parametrize("text, result", _PARAMS, ids=[f'"{p[0]}"' for p in _PARAMS]) -def test_camel_case_split(text: str, result: List[str]) -> None: +def test_camel_case_split(text: str, result: list[str]) -> None: """ Test the :func:`~lib.utils.camel_case_spli` function works correctly Parameters @@ -207,7 +208,7 @@ def test_camel_case_split(text: str, result: List[str]) -> None: def test_get_tf_version() -> None: """ Test the :func:`~lib.utils.get_tf_version` function version returns correctly in range """ tf_version = get_tf_version() - assert (2, 2) <= tf_version < (2, 11) + assert (2, 10) <= tf_version < (2, 11) def test_get_dpi() -> None: @@ -235,7 +236,7 @@ _SECPARAMS = [((1, ), 1), # 1 argument @pytest.mark.parametrize("args,result", _SECPARAMS, ids=[str(p[0]) for p in _SECPARAMS]) -def test_convert_to_secs(args: Tuple[int, ...], result: int) -> None: +def test_convert_to_secs(args: tuple[int, ...], result: 
int) -> None: """ Test the :func:`~lib.utils.convert_to_secs` function works correctly Parameters @@ -360,8 +361,8 @@ _EXPECTED = ((["test_model_file_v3.h5"], "test_model_file_v3", "test_model_file" @pytest.mark.parametrize("filename,results", zip(_INPUT, _EXPECTED), ids=[str(i) for i in _INPUT]) def test_get_model_model_filename_input( get_model_instance: GetModel, # pylint:disable=unused-argument - filename: Union[str, List[str]], - results: Union[str, List[str]]) -> None: + filename: str | list[str], + results: str | list[str]) -> None: """ Test :class:`~lib.utils.GetModel` filename parsing works Parameters @@ -430,8 +431,8 @@ def test_get_model__get(mocker: pytest_mock.MockerFixture, For testing the function when a model exists and when it does not """ model = get_model_instance - model._download_model = cast(MagicMock, mocker.MagicMock()) # type:ignore - model._unzip_model = cast(MagicMock, mocker.MagicMock()) # type:ignore + model._download_model = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + model._unzip_model = T.cast(MagicMock, mocker.MagicMock()) # type:ignore os_remove = mocker.patch("os.remove") if model_exists: # Dummy in a model file @@ -459,8 +460,8 @@ _DLPARAMS = [(None, None), @pytest.mark.parametrize("error_type,error_args", _DLPARAMS, ids=[str(p[0]) for p in _DLPARAMS]) def test_get_model__download_model(mocker: pytest_mock.MockerFixture, get_model_instance: GetModel, - error_type: Any, - error_args: Tuple[Union[str, int], ...]) -> None: + error_type: T.Any, + error_args: tuple[str | int, ...]) -> None: """ Test :func:`~lib.utils.GetModel._download_model` executes its logic correctly Parameters @@ -476,7 +477,7 @@ def test_get_model__download_model(mocker: pytest_mock.MockerFixture, """ mock_urlopen = mocker.patch("urllib.request.urlopen") if not error_type: # Model download is successful - get_model_instance._write_zipfile = cast(MagicMock, mocker.MagicMock()) # type:ignore + get_model_instance._write_zipfile = T.cast(MagicMock, mocker.MagicMock()) # type:ignore get_model_instance._download_model() assert mock_urlopen.called assert get_model_instance._write_zipfile.called @@ -609,11 +610,13 @@ def test_debug_times(): assert len(debug_times._times["Test2"]) == 1 # Ensure that the summary method includes the correct min, mean, and max times for each step - assert min(debug_times._times["Test1"]) == pytest.approx(0.1, abs=1e-1) - assert min(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1) - assert max(debug_times._times["Test1"]) == pytest.approx(0.1, abs=1e-1) - assert max(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1) + # Github workflow for macos-latest can swing out a fair way + threshold = 2e-1 if platform.system() == "Darwin" else 1e-1 + assert min(debug_times._times["Test1"]) == pytest.approx(0.1, abs=threshold) + assert min(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold) + assert max(debug_times._times["Test1"]) == pytest.approx(0.1, abs=threshold) + assert max(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold) assert (sum(debug_times._times["Test1"]) / - len(debug_times._times["Test1"])) == pytest.approx(0.1, abs=1e-1) + len(debug_times._times["Test1"])) == pytest.approx(0.1, abs=threshold) assert (sum(debug_times._times["Test2"]) / - len(debug_times._times["Test2"]) == pytest.approx(0.2, abs=1e-1)) + len(debug_times._times["Test2"]) == pytest.approx(0.2, abs=threshold)) diff --git a/tests/simple_tests.py b/tests/simple_tests.py index 2b3be208..92237f68 100644 --- a/tests/simple_tests.py +++ 
b/tests/simple_tests.py @@ -108,11 +108,10 @@ def convert_args(in_path, out_path, model_path, writer, args=None): return conv_args.split() # Don't use pathes with spaces ;) -def sort_args(in_path, out_path, sortby="face", groupby="hist", method="rename"): +def sort_args(in_path, out_path, sortby="face", groupby="hist"): """ Sort command """ py_exe = sys.executable - _sort_args = (f"{py_exe} tools.py sort -i {in_path} -o {out_path} -s {sortby} -fp {method} " - f"-g {groupby} -k") + _sort_args = (f"{py_exe} tools.py sort -i {in_path} -o {out_path} -s {sortby} -g {groupby} -k") return _sort_args.split() @@ -183,7 +182,7 @@ def main(): "Sort faces.", sort_args( pathjoin(vid_base, "faces"), pathjoin(vid_base, "faces_sorted"), - sortby="face", method="rename" + sortby="face" ) ) diff --git a/tests/tools/alignments/media_test.py b/tests/tools/alignments/media_test.py index 52124a44..5f603ee3 100644 --- a/tests/tools/alignments/media_test.py +++ b/tests/tools/alignments/media_test.py @@ -1,8 +1,10 @@ #!/usr/bin python3 """ Pytest unit tests for :mod:`tools.alignments.media` """ +from __future__ import annotations import os +import typing as T + from operator import itemgetter -from typing import cast, Dict, Generator, List, Tuple from unittest.mock import MagicMock import cv2 @@ -19,6 +21,9 @@ from lib.utils import FaceswapError # noqa:E402 from tools.alignments.media import (AlignmentData, Faces, ExtractedFaces, # noqa:E402 Frames, MediaLoader) +if T.TYPE_CHECKING: + from collections.abc import Generator + class TestAlignmentData: """ Test for :class:`~tools.alignments.media.AlignmentData` """ @@ -224,8 +229,8 @@ class TestMediaLoader: """ media_loader = media_loader_instance expected = np.random.rand(256, 256, 3) - media_loader.load_video_frame = cast(MagicMock, # type:ignore - mocker.MagicMock(return_value=expected)) + media_loader.load_video_frame = T.cast(MagicMock, # type:ignore + mocker.MagicMock(return_value=expected)) read_image_patch = mocker.patch("tools.alignments.media.read_image", return_value=expected) filename = "test.png" output = media_loader.load_image(filename) @@ -263,7 +268,7 @@ class TestMediaLoader: vid_cap = mocker.MagicMock(cv2.VideoCapture) vid_cap.read.side_effect = ((1, expected), ) - media_loader._vid_reader = cast(MagicMock, vid_cap) # type:ignore + media_loader._vid_reader = T.cast(MagicMock, vid_cap) # type:ignore output = media_loader.load_video_frame(filename) vid_cap.set.assert_called_once() np.testing.assert_equal(output, expected) @@ -440,9 +445,9 @@ class TestFaces: src_filename = "test_0001.png" src_face_idx = 0 paths = [os.path.join(faces.folder, fname) for fname in os.listdir(faces.folder)] - data = dict(source=dict(source_filename=src_filename, - face_index=src_face_idx)) - seen: Dict[str, List[int]] = {} + data = {"source": {"source_filename": src_filename, + "face_index": src_face_idx}} + seen: dict[str, list[int]] = {} # New item is_dupe = faces._handle_duplicate(paths[0], data, seen) # type:ignore @@ -477,7 +482,7 @@ class TestFaces: faces = faces_instance read_image_meta_mock = mocker.patch("tools.alignments.media.read_image_meta_batch") img_sources = [os.path.join(faces.folder, fname) for fname in os.listdir(faces.folder)] - meta_data = dict(itxt=dict(source=(dict(source_filename="data.png")))) + meta_data = {"itxt": {"source": ({"source_filename": "data.png"})}} expected = [(fname, meta_data["itxt"]) for fname in os.listdir(faces.folder)] read_image_meta_mock.side_effect = [[(src, meta_data) for src in img_sources]] @@ -527,16 +532,16 @@ class 
TestFaces: The class instance for testing """ faces = faces_instance - data = [(f"file{idx}.png", dict(source=dict(source_filename=f"src{idx}.png", - face_index=0))) + data = [(f"file{idx}.png", {"source": {"source_filename": f"src{idx}.png", + "face_index": 0}}) for idx in range(4)] faces.file_list_sorted = data # type: ignore expected = {"src0.png": [0], "src1.png": [0], "src2.png": [0], "src3.png": [0]} result = faces.load_items() assert result == expected - data = [(f"file{idx}.png", dict(source=dict(source_filename=f"src{idx // 2}.png", - face_index=0 if idx % 2 == 0 else 1))) + data = [(f"file{idx}.png", {"source": {"source_filename": f"src{idx // 2}.png", + "face_index": 0 if idx % 2 == 0 else 1}}) for idx in range(4)] faces.file_list_sorted = data # type: ignore expected = {"src0.png": [0, 1], "src1.png": [0, 1]} @@ -556,7 +561,7 @@ class TestFaces: Fixture for mocking various logic calls """ faces = faces_instance - data: List[Tuple[str, dict]] = [("file4.png", {}), ("file3.png", {}), + data: list[tuple[str, dict]] = [("file4.png", {}), ("file3.png", {}), ("file1.png", {}), ("file2.png", {})] expected = sorted(data) process_folder_mock = mocker.patch("tools.alignments.media.Faces.process_folder", @@ -605,8 +610,8 @@ class TestFrames: folder : str Dummy media folder """ - expected = [dict(frame_fullname="a.png", frame_name="a", frame_extension=".png"), - dict(frame_fullname="b.png", frame_name="b", frame_extension=".png")] + expected = [{"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"}, + {"frame_fullname": "b.png", "frame_name": "b", "frame_extension": ".png"}] frames = Frames(folder, None) returned = sorted(list(frames.process_frames()), key=itemgetter("frame_fullname")) @@ -620,12 +625,12 @@ class TestFrames: folder : str Dummy media folder """ - expected = [dict(frame_fullname="images_000001.png", - frame_name="images_000001", - frame_extension=".png"), - dict(frame_fullname="images_000002.png", - frame_name="images_000002", - frame_extension=".png")] + expected = [{"frame_fullname": "images_000001.png", + "frame_name": "images_000001", + "frame_extension": ".png"}, + {"frame_fullname": "images_000002.png", + "frame_name": "images_000002", + "frame_extension": ".png"}] frames = Frames(folder, None) returned = list(frames.process_video()) @@ -657,14 +662,14 @@ class TestFrames: Fixture for mocking process_folder call """ frames = Frames(folder, None) - data = [dict(frame_fullname="c.png", frame_name="c", frame_extension=".png"), - dict(frame_fullname="d.png", frame_name="d", frame_extension=".png"), - dict(frame_fullname="b.jpg", frame_name="b", frame_extension=".jpg"), - dict(frame_fullname="a.png", frame_name="a", frame_extension=".png")] - expected = [dict(frame_fullname="a.png", frame_name="a", frame_extension=".png"), - dict(frame_fullname="b.jpg", frame_name="b", frame_extension=".jpg"), - dict(frame_fullname="c.png", frame_name="c", frame_extension=".png"), - dict(frame_fullname="d.png", frame_name="d", frame_extension=".png")] + data = [{"frame_fullname": "c.png", "frame_name": "c", "frame_extension": ".png"}, + {"frame_fullname": "d.png", "frame_name": "d", "frame_extension": ".png"}, + {"frame_fullname": "b.jpg", "frame_name": "b", "frame_extension": ".jpg"}, + {"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"}] + expected = [{"frame_fullname": "a.png", "frame_name": "a", "frame_extension": ".png"}, + {"frame_fullname": "b.jpg", "frame_name": "b", "frame_extension": ".jpg"}, + {"frame_fullname": "c.png", "frame_name": "c", 
"frame_extension": ".png"}, + {"frame_fullname": "d.png", "frame_name": "d", "frame_extension": ".png"}] process_folder_mock = mocker.patch("tools.alignments.media.Frames.process_folder", side_effect=[data]) result = frames.sorted_items() @@ -796,7 +801,7 @@ class TestExtractedFaces: Fixture for mocking get_faces method """ faces = extracted_faces_instance - faces.get_faces = cast(MagicMock, mocker.MagicMock()) # type:ignore + faces.get_faces = T.cast(MagicMock, mocker.MagicMock()) # type:ignore frame = "test_frame" img = None @@ -837,7 +842,7 @@ class TestExtractedFaces: The expected output for the given ROI box """ faces = extracted_faces_instance - faces.get_faces = cast(MagicMock, mocker.MagicMock()) # type:ignore + faces.get_faces = T.cast(MagicMock, mocker.MagicMock()) # type:ignore frame = "test_frame" faces.get_roi_size_for_frame(frame) diff --git a/tests/tools/preview/viewer_test.py b/tests/tools/preview/viewer_test.py index 05fb072e..18fd84d4 100644 --- a/tests/tools/preview/viewer_test.py +++ b/tests/tools/preview/viewer_test.py @@ -1,9 +1,11 @@ #!/usr/bin python3 """ Pytest unit tests for :mod:`tools.preview.viewer` """ +from __future__ import annotations import tkinter as tk +import typing as T + from tkinter import ttk -from typing import cast, TYPE_CHECKING from unittest.mock import MagicMock import pytest @@ -18,7 +20,7 @@ log_setup("DEBUG", "pytest_viewer.log", "PyTest, False") from lib.utils import get_backend # pylint:disable=wrong-import-position # noqa from tools.preview.viewer import _Faces, FacesDisplay, ImagesCanvas # pylint:disable=wrong-import-position # noqa -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.aligned_face import CenteringType @@ -104,7 +106,7 @@ class TestFacesDisplay(): """ Test :class:`~tools.preview.viewer.FacesDisplay` set_centering method """ f_display = self.get_faces_display_instance() assert f_display._centering is None - centering: "CenteringType" = "legacy" + centering: CenteringType = "legacy" f_display.set_centering(centering) assert f_display._centering == centering @@ -133,9 +135,9 @@ class TestFacesDisplay(): Mocker for checking _build_faces_image method called """ f_display = self.get_faces_display_instance(columns, face_size) - f_display._build_faces_image = cast(MagicMock, mocker.MagicMock()) # type:ignore - f_display._get_scale_size = cast(MagicMock, # type:ignore - mocker.MagicMock(return_value=(128, 128))) + f_display._build_faces_image = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._get_scale_size = T.cast(MagicMock, # type:ignore + mocker.MagicMock(return_value=(128, 128))) f_display._faces_source = np.zeros((face_size, face_size, 3), dtype=np.uint8) f_display._faces_dest = np.zeros((face_size, face_size, 3), dtype=np.uint8) @@ -186,12 +188,12 @@ class TestFacesDisplay(): header_size = 32 f_display = self.get_faces_display_instance(columns, face_size) - f_display._faces_from_frames = cast(MagicMock, mocker.MagicMock()) # type:ignore - f_display._header_text = cast( # type:ignore + f_display._faces_from_frames = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._header_text = T.cast( # type:ignore MagicMock, mocker.MagicMock(return_value=np.random.rand(header_size, face_size * columns, 3))) - f_display._draw_rect = cast(MagicMock, # type:ignore - mocker.MagicMock(side_effect=lambda x: x)) + f_display._draw_rect = T.cast(MagicMock, # type:ignore + mocker.MagicMock(side_effect=lambda x: x)) # Test full update f_display.update_source = True @@ -235,8 +237,8 @@ class TestFacesDisplay(): 
f_display = self.get_faces_display_instance(columns, face_size) f_display.source = [mocker.MagicMock() for _ in range(3)] f_display.destination = [np.random.rand(face_size, face_size, 3) for _ in range(3)] - f_display._crop_source_faces = cast(MagicMock, mocker.MagicMock()) # type:ignore - f_display._crop_destination_faces = cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._crop_source_faces = T.cast(MagicMock, mocker.MagicMock()) # type:ignore + f_display._crop_destination_faces = T.cast(MagicMock, mocker.MagicMock()) # type:ignore # Both src + dst f_display.update_source = True @@ -451,7 +453,7 @@ class TestImagesCanvas: Mocker for dummying in tk calls """ event_mock = mocker.MagicMock(spec=tk.Event, width=100, height=200) - images_canvas_instance.reload = cast(MagicMock, mocker.MagicMock()) # type:ignore + images_canvas_instance.reload = T.cast(MagicMock, mocker.MagicMock()) # type:ignore images_canvas_instance._resize(event_mock) diff --git a/tests/utils.py b/tests/utils.py index 26379587..b357dc13 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,6 @@ """ Utils imported from Keras as their location changes between Tensorflow Keras and standard Keras. Also ensures testing consistency """ import inspect -import sys import numpy as np @@ -101,12 +100,6 @@ def has_arg(func, name, accept_all=False): bool Whether `func` accepts a `name` keyword argument. """ - if sys.version_info < (3, 3): - arg_spec = inspect.getfullargspec(func) - if accept_all and arg_spec.varkw is not None: - return True - return (name in arg_spec.args or - name in arg_spec.kwonlyargs) signature = inspect.signature(func) parameter = signature.parameters.get(name) if parameter is None: diff --git a/tools.py b/tools.py index 326e15be..c47798c7 100755 --- a/tools.py +++ b/tools.py @@ -9,17 +9,13 @@ from importlib import import_module # Importing the various tools from lib.cli.args import FullHelpArgumentParser - # LOCALES _LANG = gettext.translation("tools", localedir="locales", fallback=True) _ = _LANG.gettext - # Python version check -if sys.version_info[0] < 3: - raise Exception("This program requires at least python3.7") -if sys.version_info[0] == 3 and sys.version_info[1] < 7: - raise Exception("This program requires at least python3.7") +if sys.version_info < (3, 10): + raise ValueError("This program requires at least python 3.10") def bad_args(*args): # pylint:disable=unused-argument diff --git a/tools/alignments/alignments.py b/tools/alignments/alignments.py index 610e8857..6ffc3f84 100644 --- a/tools/alignments/alignments.py +++ b/tools/alignments/alignments.py @@ -3,10 +3,10 @@ import logging import os import sys +import typing as T from argparse import Namespace from multiprocessing import Process -from typing import Any, cast, List, Dict, Optional from lib.utils import _video_extensions, FaceswapError from .media import AlignmentData @@ -66,7 +66,7 @@ class Alignments(): # pylint:disable=too-few-public-methods logger.debug("Running in batch mode") return batch_mode - def _get_alignments_locations(self) -> Dict[str, List[Optional[str]]]: + def _get_alignments_locations(self) -> dict[str, list[str | None]]: """ Obtain the full path to alignments files in a parent (batch) location These are jobs that only require an alignments file as input, so frames and face locations @@ -92,12 +92,12 @@ class Alignments(): # pylint:disable=too-few-public-methods sys.exit(1) logger.info("Batch mode selected. 
Processing alignments: %s", alignments) - retval = dict(alignments_file=alignments, - faces_dir=[None for _ in range(len(alignments))], - frames_dir=[None for _ in range(len(alignments))]) + retval = {"alignments_file": alignments, + "faces_dir": [None for _ in range(len(alignments))], + "frames_dir": [None for _ in range(len(alignments))]} return retval - def _get_frames_locations(self) -> Dict[str, List[Optional[str]]]: + def _get_frames_locations(self) -> dict[str, list[str | None]]: """ Obtain the full path to frame locations along with corresponding alignments file locations contained within the parent (batch) location @@ -138,7 +138,7 @@ class Alignments(): # pylint:disable=too-few-public-methods sys.exit(1) if self._args.job not in self._requires_faces: # faces not required for frames input - faces: list[Optional[str]] = [None for _ in range(len(frames))] + faces: list[str | None] = [None for _ in range(len(frames))] else: if not self._args.faces_dir: logger.error("Please provide a 'faces_dir' location for '%s' job", self._args.job) @@ -149,11 +149,11 @@ class Alignments(): # pylint:disable=too-few-public-methods logger.info("Batch mode selected. Processing frames: %s", [os.path.basename(frame) for frame in frames]) - return dict(alignments_file=cast(List[Optional[str]], alignments), - frames_dir=cast(List[Optional[str]], frames), - faces_dir=faces) + return {"alignments_file": T.cast(list[str | None], alignments), + "frames_dir": T.cast(list[str | None], frames), + "faces_dir": faces} - def _get_locations(self) -> Dict[str, List[Optional[str]]]: + def _get_locations(self) -> dict[str, list[str | None]]: """ Obtain the full path to any frame, face and alignments input locations for the selected job when running in batch mode. If not running in batch mode, then the original passed in values are returned in lists @@ -166,9 +166,9 @@ class Alignments(): # pylint:disable=too-few-public-methods """ job: str = self._args.job if not self._batch_mode: # handle with given arguments - retval = dict(alignments_file=[self._args.alignments_file], - faces_dir=[self._args.faces_dir], - frames_dir=[self._args.frames_dir]) + retval = {"alignments_file": [self._args.alignments_file], + "faces_dir": [self._args.faces_dir], + "frames_dir": [self._args.frames_dir]} elif job in self._requires_alignments: # Jobs only requiring an alignments file location retval = self._get_alignments_locations() @@ -185,9 +185,9 @@ class Alignments(): # pylint:disable=too-few-public-methods logger.error("No folders found in '%s'", self._args.faces_dir) sys.exit(1) - retval = dict(faces_dir=faces, - frames_dir=[None for _ in range(len(faces))], - alignments_file=[None for _ in range(len(faces))]) + retval = {"faces_dir": faces, + "frames_dir": [None for _ in range(len(faces))], + "alignments_file": [None for _ in range(len(faces))]} logger.info("Batch mode selected. Processing faces: %s", [os.path.basename(folder) for folder in faces]) else: @@ -306,7 +306,7 @@ class _Alignments(): # pylint:disable=too-few-public-methods Launches the selected alignments job. 
""" if self._args.job in ("missing-alignments", "missing-frames", "multi-faces", "no-faces"): - job: Any = Check + job: T.Any = Check else: job = globals()[self._args.job.title().replace("-", "")] job = job(self.alignments, self._args) diff --git a/tools/alignments/cli.py b/tools/alignments/cli.py index 85698706..d41b4df4 100644 --- a/tools/alignments/cli.py +++ b/tools/alignments/cli.py @@ -2,8 +2,7 @@ """ Command Line Arguments for tools """ import sys import gettext - -from typing import Any, List, Dict +import typing as T from lib.cli.args import FaceSwapArgs from lib.cli.actions import DirOrFileFullPaths, DirFullPaths, FileFullPaths, Radio, Slider @@ -33,7 +32,7 @@ class AlignmentsArgs(FaceSwapArgs): "an alignments file against its corresponding faceset/frame source.") @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Collect the argparse argument options. Returns diff --git a/tools/alignments/jobs.py b/tools/alignments/jobs.py index 36d1d685..2dfdc9a7 100644 --- a/tools/alignments/jobs.py +++ b/tools/alignments/jobs.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 """ Tools for manipulating the alignments serialized file """ - +from __future__ import annotations import logging import os import sys +import typing as T + from datetime import datetime -from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING, Optional, Union import numpy as np from scipy import signal @@ -15,12 +16,8 @@ from tqdm import tqdm from .media import Faces, Frames from .jobs_faces import FaceToFile -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator from argparse import Namespace from lib.align.alignments import PNGHeaderDict from .media import AlignmentData @@ -38,11 +35,11 @@ class Check(): arguments: :class:`argparse.Namespace` The command line arguments that have called this job """ - def __init__(self, alignments: "AlignmentData", arguments: "Namespace") -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._alignments = alignments self._job = arguments.job - self._type: Optional[Literal["faces", "frames"]] = None + self._type: T.Literal["faces", "frames"] | None = None self._is_video = False # Set when getting items self._output = arguments.output self._source_dir = self._get_source_dir(arguments) @@ -52,7 +49,7 @@ class Check(): self.output_message = "" logger.debug("Initialized %s", self.__class__.__name__) - def _get_source_dir(self, arguments: "Namespace") -> str: + def _get_source_dir(self, arguments: Namespace) -> str: """ Set the correct source folder Parameters @@ -81,7 +78,7 @@ class Check(): logger.debug("type: '%s', source_dir: '%s'", self._type, source_dir) return source_dir - def _get_items(self) -> Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]]: + def _get_items(self) -> list[dict[str, str]] | list[tuple[str, PNGHeaderDict]]: """ Set the correct items to process Returns @@ -92,10 +89,10 @@ class Check(): the dictionaries will contain the keys 'frame_fullname', 'frame_name', 'extension'. 
""" assert self._type is not None - items: Union[Frames, Faces] = globals()[self._type.title()](self._source_dir) + items: Frames | Faces = globals()[self._type.title()](self._source_dir) self._is_video = items.is_video - return cast(Union[List[Dict[str, str]], List[Tuple[str, "PNGHeaderDict"]]], - items.file_list_sorted) + return T.cast(list[dict[str, str]] | list[tuple[str, "PNGHeaderDict"]], + items.file_list_sorted) def process(self) -> None: """ Process the frames check against the alignments file """ @@ -104,7 +101,7 @@ class Check(): items_output = self._compile_output() if self._type == "faces": - filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._items) + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._items) check_update = FaceToFile(self._alignments, [val[1] for val in filelist]) if check_update(): self._alignments.save() @@ -122,7 +119,7 @@ class Check(): "supported for 'multi-faces'") sys.exit(1) - def _compile_output(self) -> Union[List[str], List[Tuple[str, int]]]: + def _compile_output(self) -> list[str] | list[tuple[str, int]]: """ Compile list of frames that meet criteria Returns @@ -144,7 +141,7 @@ class Check(): The frame name of any frames which have no faces """ self.output_message = "Frames with no faces" - for frame in tqdm(cast(List[Dict[str, str]], self._items), + for frame in tqdm(T.cast(list[dict[str, str]], self._items), desc=self.output_message, leave=False): logger.trace(frame) # type:ignore @@ -153,8 +150,8 @@ class Check(): logger.debug("Returning: '%s'", frame_name) yield frame_name - def _get_multi_faces(self) -> Union[Generator[str, None, None], - Generator[Tuple[str, int], None, None]]: + def _get_multi_faces(self) -> (Generator[str, None, None] | + Generator[tuple[str, int], None, None]): """ yield each frame or face that has multiple faces matched in alignments file Yields @@ -175,7 +172,7 @@ class Check(): The frame name of any frames which have multiple faces """ self.output_message = "Frames with multiple faces" - for item in tqdm(cast(List[Dict[str, str]], self._items), + for item in tqdm(T.cast(list[dict[str, str]], self._items), desc=self.output_message, leave=False): filename = item["frame_fullname"] @@ -184,7 +181,7 @@ class Check(): logger.trace("Returning: '%s'", filename) # type:ignore yield filename - def _get_multi_faces_faces(self) -> Generator[Tuple[str, int], None, None]: + def _get_multi_faces_faces(self) -> Generator[tuple[str, int], None, None]: """ Return Faces when there are multiple faces in a frame Yields @@ -193,7 +190,7 @@ class Check(): The frame name and the face id of any frames which have multiple faces """ self.output_message = "Multiple faces in frame" - for item in tqdm(cast(List[Tuple[str, "PNGHeaderDict"]], self._items), + for item in tqdm(T.cast(list[tuple[str, "PNGHeaderDict"]], self._items), desc=self.output_message, leave=False): src = item[1]["source"] @@ -213,7 +210,7 @@ class Check(): """ self.output_message = "Frames missing from alignments file" exclude_filetypes = set(["yaml", "yml", "p", "json", "txt"]) - for frame in tqdm(cast(Dict[str, str], self._items), + for frame in tqdm(T.cast(dict[str, str], self._items), desc=self.output_message, leave=False): frame_name = frame["frame_fullname"] @@ -231,13 +228,13 @@ class Check(): The frame name of any frames in alignments with no matching file """ self.output_message = "Missing frames that are in alignments file" - frames = set(item["frame_fullname"] for item in cast(List[Dict[str, str]], self._items)) + frames = set(item["frame_fullname"] for 
item in T.cast(list[dict[str, str]], self._items)) for frame in tqdm(self._alignments.data.keys(), desc=self.output_message, leave=False): if frame not in frames: logger.debug("Returning: '%s'", frame) yield frame - def _output_results(self, items_output: Union[List[str], List[Tuple[str, int]]]) -> None: + def _output_results(self, items_output: list[str] | list[tuple[str, int]]) -> None: """ Output the results in the requested format Parameters @@ -261,7 +258,7 @@ class Check(): # Strip the index for printed/file output final_output = [item[0] for item in items_output] else: - final_output = cast(List[str], items_output) + final_output = T.cast(list[str], items_output) output_message = "-----------------------------------------------\r\n" output_message += f" {self.output_message} ({len(final_output)})\r\n" output_message += "-----------------------------------------------\r\n" @@ -316,7 +313,7 @@ class Check(): with open(output_file, "w", encoding="utf8") as f_output: f_output.write(output_message) - def _move_file(self, items_output: Union[List[str], List[Tuple[str, int]]]) -> None: + def _move_file(self, items_output: list[str] | list[tuple[str, int]]) -> None: """ Move the identified frames to a new sub folder Parameters @@ -335,7 +332,7 @@ class Check(): logger.debug("Move function: %s", move) move(output_folder, items_output) - def _move_frames(self, output_folder: str, items_output: List[str]) -> None: + def _move_frames(self, output_folder: str, items_output: list[str]) -> None: """ Move frames into single sub folder Parameters @@ -352,7 +349,7 @@ class Check(): logger.debug("Moving: '%s' to '%s'", src, dst) os.rename(src, dst) - def _move_faces(self, output_folder: str, items_output: List[Tuple[str, int]]) -> None: + def _move_faces(self, output_folder: str, items_output: list[tuple[str, int]]) -> None: """ Make additional sub folders for each face that appears Enables easier manual sorting Parameters @@ -384,7 +381,7 @@ class Sort(): arguments: :class:`argparse.Namespace` The :mod:`argparse` arguments as passed in from :mod:`tools.py` """ - def __init__(self, alignments: "AlignmentData", arguments: "Namespace") -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._alignments = alignments logger.debug("Initialized %s", self.__class__.__name__) @@ -435,13 +432,13 @@ class Spatial(): # pylint:disable=too-few-public-methods --------- https://www.kaggle.com/selfishgene/animating-and-smoothing-3d-facial-keypoints/notebook """ - def __init__(self, alignments: "AlignmentData", arguments: "Namespace") -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self.arguments = arguments self._alignments = alignments - self._mappings: Dict[int, str] = {} - self._normalized: Dict[str, np.ndarray] = {} - self._shapes_model: Optional[decomposition.PCA] = None + self._mappings: dict[int, str] = {} + self._normalized: dict[str, np.ndarray] = {} + self._shapes_model: decomposition.PCA | None = None logger.debug("Initialized %s", self.__class__.__name__) def process(self) -> None: @@ -464,7 +461,7 @@ class Spatial(): # pylint:disable=too-few-public-methods # Define shape normalization utility functions @staticmethod def _normalize_shapes(shapes_im_coords: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray, 
np.ndarray]: """ Normalize a 2D or 3D shape Parameters diff --git a/tools/alignments/jobs_faces.py b/tools/alignments/jobs_faces.py index 6be619f9..06fcf856 100644 --- a/tools/alignments/jobs_faces.py +++ b/tools/alignments/jobs_faces.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 """ Tools for manipulating the alignments using extracted Faces as a source """ +from __future__ import annotations import logging import os -import sys +import typing as T + from argparse import Namespace from operator import itemgetter -from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING import numpy as np from tqdm import tqdm @@ -16,12 +17,7 @@ from scripts.fsmedia import Alignments from .media import Faces -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from .media import AlignmentData from lib.align.alignments import (AlignmentDict, AlignmentFileDict, PNGHeaderDict, PNGHeaderAlignmentsDict) @@ -50,9 +46,9 @@ class FromFaces(): # pylint:disable=too-few-public-methods """ Run the job to read faces from a folder to create alignments file(s). """ logger.info("[CREATE ALIGNMENTS FROM FACES]") # Tidy up cli output - all_versions: Dict[str, List[float]] = {} - d_align: Dict[str, Dict[str, List[Tuple[int, "AlignmentFileDict", str, dict]]]] = {} - filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + all_versions: dict[str, list[float]] = {} + d_align: dict[str, dict[str, list[tuple[int, AlignmentFileDict, str, dict]]]] = {} + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) for filename, meta in tqdm(filelist, desc="Generating Alignments", total=len(filelist), @@ -93,7 +89,7 @@ class FromFaces(): # pylint:disable=too-few-public-methods logger.trace("Extracted alignments file filename: '%s'", retval) # type:ignore return retval - def _extract_alignment(self, metadata: dict) -> Tuple[str, int, "AlignmentFileDict"]: + def _extract_alignment(self, metadata: dict) -> tuple[str, int, AlignmentFileDict]: """ Extract alignment data from a PNG image's itxt header. Formats the landmarks into a numpy array and adds in mask centering information if it is @@ -123,11 +119,11 @@ class FromFaces(): # pylint:disable=too-few-public-methods return frame_name, face_index, alignment def _sort_alignments(self, - alignments: Dict[str, Dict[str, List[Tuple[int, - "AlignmentFileDict", + alignments: dict[str, dict[str, list[tuple[int, + AlignmentFileDict, str, dict]]]] - ) -> Dict[str, Dict[str, "AlignmentDict"]]: + ) -> dict[str, dict[str, AlignmentDict]]: """ Sort the faces into face index order as they appeared in the original alignments file. 
If the face index stored in the png header does not match it's position in the alignments @@ -147,11 +143,11 @@ class FromFaces(): # pylint:disable=too-few-public-methods The alignments file dictionaries sorted into the correct face order, ready for saving """ logger.info("Sorting and checking faces...") - aln_sorted: Dict[str, Dict[str, "AlignmentDict"]] = {} + aln_sorted: dict[str, dict[str, AlignmentDict]] = {} for fname, frames in alignments.items(): - this_file: Dict[str, "AlignmentDict"] = {} + this_file: dict[str, AlignmentDict] = {} for frame in tqdm(sorted(frames), desc=f"Sorting {fname}", leave=False): - this_file[frame] = dict(video_meta={}, faces=[]) + this_file[frame] = {"video_meta": {}, "faces": []} for real_idx, (f_id, almt, f_path, f_src) in enumerate(sorted(frames[frame], key=itemgetter(0))): if real_idx != f_id: @@ -165,7 +161,7 @@ class FromFaces(): # pylint:disable=too-few-public-methods def _update_png_header(cls, face_path: str, new_index: int, - alignment: "AlignmentFileDict", + alignment: AlignmentFileDict, source_info: dict) -> None: """ Update the PNG header for faces where the stored index does not correspond with the alignments file. This can occur when frames with multiple faces have had some faces deleted @@ -194,12 +190,12 @@ class FromFaces(): # pylint:disable=too-few-public-methods source_info["face_index"] = new_index source_info["original_filename"] = new_filename - meta = dict(alignments=face.to_png_meta(), source=source_info) + meta = {"alignments": face.to_png_meta(), "source": source_info} update_existing_metadata(face_path, meta) def _save_alignments(self, - all_alignments: Dict[str, Dict[str, "AlignmentDict"]], - versions: Dict[str, float]) -> None: + all_alignments: dict[str, dict[str, AlignmentDict]], + versions: dict[str, float]) -> None: """ Save the newely generated alignments file(s). 
If an alignments file already exists in the source faces folder, back it up rather than @@ -240,9 +236,9 @@ class Rename(): # pylint:disable=too-few-public-methods Default: ``None`` """ def __init__(self, - alignments: "AlignmentData", - arguments: Optional[Namespace], - faces: Optional[Faces] = None) -> None: + alignments: AlignmentData, + arguments: Namespace | None, + faces: Faces | None = None) -> None: logger.debug("Initializing %s: (arguments: %s, faces: %s)", self.__class__.__name__, arguments, faces) self._alignments = alignments @@ -261,7 +257,7 @@ class Rename(): # pylint:disable=too-few-public-methods def process(self) -> None: """ Process the face renaming """ logger.info("[RENAME FACES]") # Tidy up cli output - filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) rename_mappings = sorted([(face[0], face[1]["source"]["original_filename"]) for face in filelist if face[0] != face[1]["source"]["original_filename"]], @@ -269,12 +265,12 @@ class Rename(): # pylint:disable=too-few-public-methods rename_count = self._rename_faces(rename_mappings) logger.info("%s faces renamed", rename_count) - filelist = cast(List[Tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) + filelist = T.cast(list[tuple[str, "PNGHeaderDict"]], self._faces.file_list_sorted) copyback = FaceToFile(self._alignments, [val[1] for val in filelist]) if copyback(): self._alignments.save() - def _rename_faces(self, filename_mappings: List[Tuple[str, str]]) -> int: + def _rename_faces(self, filename_mappings: list[tuple[str, str]]) -> int: """ Rename faces back to their original name as exists in the alignments file. If the source and destination filename are the same then skip that file. @@ -333,7 +329,7 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods arguments: :class:`argparse.Namespace` The command line arguments that have called this job """ - def __init__(self, alignments: "AlignmentData", arguments: Namespace) -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._alignments = alignments @@ -350,7 +346,7 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods "faces from your alignments file. Process aborted.") return - items = cast(Dict[str, List[int]], self._items.items) + items = T.cast(dict[str, list[int]], self._items.items) pre_face_count = self._alignments.faces_count self._alignments.filter_faces(items, filter_out=False) del_count = pre_face_count - self._alignments.faces_count @@ -375,9 +371,9 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods to like this and has a tendency to throw permission errors, so this remains single threaded for now. 
""" - items = cast(Dict[str, List[int]], self._items.items) + items = T.cast(dict[str, list[int]], self._items.items) srcs = [(x[0], x[1]["source"]) - for x in cast(List[Tuple[str, "PNGHeaderDict"]], self._items.file_list_sorted)] + for x in T.cast(list[tuple[str, "PNGHeaderDict"]], self._items.file_list_sorted)] to_update = [ # Items whose face index has changed x for x in srcs if x[1]["face_index"] != items[x[1]["source_filename"]].index(x[1]["face_index"])] @@ -399,13 +395,13 @@ class RemoveFaces(): # pylint:disable=too-few-public-methods face = DetectedFace() face.from_alignment(self._alignments.get_faces_in_frame(frame)[new_index]) - meta = dict(alignments=face.to_png_meta(), - source=dict(alignments_version=file_info["alignments_version"], - original_filename=orig_filename, - face_index=new_index, - source_filename=frame, - source_is_video=file_info["source_is_video"], - source_frame_dims=file_info.get("source_frame_dims"))) + meta = {"alignments": face.to_png_meta(), + "source": {"alignments_version": file_info["alignments_version"], + "original_filename": orig_filename, + "face_index": new_index, + "source_filename": frame, + "source_is_video": file_info["source_is_video"], + "source_frame_dims": file_info.get("source_frame_dims")}} update_existing_metadata(fullpath, meta) logger.info("%s Extracted face(s) had their header information updated", len(to_update)) @@ -422,18 +418,18 @@ class FaceToFile(): # pylint:disable=too-few-public-methods face_data: list List of :class:`PNGHeaderDict` objects """ - def __init__(self, alignments: "AlignmentData", face_data: List["PNGHeaderDict"]) -> None: + def __init__(self, alignments: AlignmentData, face_data: list[PNGHeaderDict]) -> None: logger.debug("Initializing %s: alignments: %s, face_data: %s", self.__class__.__name__, alignments, len(face_data)) self._alignments = alignments self._face_alignments = face_data - self._updatable_keys: List[Literal["identity", "mask"]] = ["identity", "mask"] - self._counts: Dict[str, int] = {} + self._updatable_keys: list[T.Literal["identity", "mask"]] = ["identity", "mask"] + self._counts: dict[str, int] = {} logger.debug("Initialized %s", self.__class__.__name__) def _check_and_update(self, - alignment: "PNGHeaderAlignmentsDict", - face: "AlignmentFileDict") -> None: + alignment: PNGHeaderAlignmentsDict, + face: AlignmentFileDict) -> None: """ Check whether the key requires updating and update it. 
alignment: dict diff --git a/tools/alignments/jobs_frames.py b/tools/alignments/jobs_frames.py index 8df193e3..62ce578d 100644 --- a/tools/alignments/jobs_frames.py +++ b/tools/alignments/jobs_frames.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 """ Tools for manipulating the alignments using Frames as a source """ +from __future__ import annotations import logging import os import sys +import typing as T + from datetime import datetime -from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 import numpy as np @@ -16,12 +18,7 @@ from lib.image import encode_image, generate_thumbnail, ImagesSaver from plugins.extract.pipeline import Extractor, ExtractMedia from .media import ExtractedFaces, Frames -if sys.version_info < (3, 8): - from typing_extensions import get_args, Literal -else: - from typing import get_args, Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from argparse import Namespace from .media import AlignmentData @@ -39,19 +36,19 @@ class Draw(): # pylint:disable=too-few-public-methods arguments: :class:`argparse.Namespace` The command line arguments that have called this job """ - def __init__(self, alignments: "AlignmentData", arguments: "Namespace") -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._alignments = alignments self._frames = Frames(arguments.frames_dir) self._output_folder = self._set_output() - self._mesh_areas = dict(mouth=(48, 68), - right_eyebrow=(17, 22), - left_eyebrow=(22, 27), - right_eye=(36, 42), - left_eye=(42, 48), - nose=(27, 36), - jaw=(0, 17), - chin=(8, 11)) + self._mesh_areas = {"mouth": (48, 68), + "right_eyebrow": (17, 22), + "left_eyebrow": (22, 27), + "right_eye": (36, 42), + "left_eye": (42, 48), + "nose": (27, 36), + "jaw": (0, 17), + "chin": (8, 11)} logger.debug("Initialized %s", self.__class__.__name__) def _set_output(self) -> str: @@ -145,7 +142,7 @@ class Draw(): # pylint:disable=too-few-public-methods index: int The face index for the given face """ - for area in get_args(Literal["face", "head"]): + for area in T.get_args(T.Literal["face", "head"]): face.load_aligned(image, centering=area, force=True) color = (0, 255, 0) if area == "face" else (0, 0, 255) top_left = face.aligned.original_roi[0] @@ -184,12 +181,12 @@ class Extract(): # pylint:disable=too-few-public-methods arguments: :class:`argparse.Namespace` The :mod:`argparse` arguments as passed in from :mod:`tools.py` """ - def __init__(self, alignments: "AlignmentData", arguments: "Namespace") -> None: + def __init__(self, alignments: AlignmentData, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: %s)", self.__class__.__name__, arguments) self._arguments = arguments self._alignments = alignments self._is_legacy = self._alignments.version == 1.0 # pylint:disable=protected-access - self._mask_pipeline: Optional[Extractor] = None + self._mask_pipeline: Extractor | None = None self._faces_dir = arguments.faces_dir self._min_size = self._get_min_size(arguments.size, arguments.min_size) @@ -197,7 +194,7 @@ class Extract(): # pylint:disable=too-few-public-methods self._extracted_faces = ExtractedFaces(self._frames, self._alignments, size=arguments.size) - self._saver: Optional[ImagesSaver] = None + self._saver: ImagesSaver | None = None logger.debug("Initialized %s", self.__class__.__name__) @classmethod @@ -223,7 +220,7 @@ class Extract(): # pylint:disable=too-few-public-methods extract_size, min_size, 
retval) return retval - def _get_count(self) -> Optional[int]: + def _get_count(self) -> int | None: """ If the alignments file has been run through the manual tool, then it will hold video meta information, meaning that the count of frames in the alignment file can be relied on to be accurate. @@ -237,8 +234,7 @@ class Extract(): # pylint:disable=too-few-public-methods meta = self._alignments.video_meta_data has_meta = all(val is not None for val in meta.values()) if has_meta: - retval: Optional[int] = len(cast(Dict[str, Union[List[int], List[float]]], - meta["pts_time"])) + retval: int | None = len(T.cast(dict[str, list[int] | list[float]], meta["pts_time"])) else: retval = None logger.debug("Frame count from alignments file: (has_meta: %s, %s", has_meta, retval) @@ -320,7 +316,7 @@ class Extract(): # pylint:disable=too-few-public-methods self._alignments.save() logger.info("%s face(s) extracted", extracted_faces) - def _set_skip_list(self) -> Optional[List[int]]: + def _set_skip_list(self) -> list[int] | None: """ Set the indices for frames that should be skipped based on the `extract_every_n` command line option. @@ -335,7 +331,7 @@ class Extract(): # pylint:disable=too-few-public-methods logger.debug("Not skipping any frames") return None skip_list = [] - for idx, item in enumerate(cast(List[Dict[str, str]], self._frames.file_list_sorted)): + for idx, item in enumerate(T.cast(list[dict[str, str]], self._frames.file_list_sorted)): if idx % skip_num != 0: logger.trace("Adding image '%s' to skip list due to " # type:ignore "extract_every_n = %s", item["frame_fullname"], skip_num) @@ -370,14 +366,14 @@ class Extract(): # pylint:disable=too-few-public-methods for idx, face in enumerate(faces): output = f"{frame_name}_{idx}.png" - meta: PNGHeaderDict = dict( - alignments=face.to_png_meta(), - source=dict(alignments_version=self._alignments.version, - original_filename=output, - face_index=idx, - source_filename=filename, - source_is_video=self._frames.is_video, - source_frame_dims=cast(Tuple[int, int], image.shape[:2]))) + meta: PNGHeaderDict = { + "alignments": face.to_png_meta(), + "source": {"alignments_version": self._alignments.version, + "original_filename": output, + "face_index": idx, + "source_filename": filename, + "source_is_video": self._frames.is_video, + "source_frame_dims": T.cast(tuple[int, int], image.shape[:2])}} assert face.aligned.face is not None self._saver.save(output, encode_image(face.aligned.face, ".png", metadata=meta)) if self._min_size == 0 and self._is_legacy: @@ -387,7 +383,7 @@ class Extract(): # pylint:disable=too-few-public-methods self._saver.close() return face_count - def _select_valid_faces(self, frame: str, image: np.ndarray) -> List[DetectedFace]: + def _select_valid_faces(self, frame: str, image: np.ndarray) -> list[DetectedFace]: """ Return the aligned faces from a frame that meet the selection criteria, Parameters @@ -416,7 +412,7 @@ class Extract(): # pylint:disable=too-few-public-methods def _process_legacy(self, filename: str, image: np.ndarray, - detected_faces: List[DetectedFace]) -> List[DetectedFace]: + detected_faces: list[DetectedFace]) -> list[DetectedFace]: """ Process legacy face extractions to new extraction method. 
Updates stored masks to new extract size diff --git a/tools/alignments/media.py b/tools/alignments/media.py index ee471e2f..ee6bcc21 100644 --- a/tools/alignments/media.py +++ b/tools/alignments/media.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 """ Media items (Alignments, Faces, Frames) for alignments tool """ - +from __future__ import annotations import logging from operator import itemgetter import os import sys -from typing import cast, Generator, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +import typing as T import cv2 from tqdm import tqdm @@ -19,7 +19,8 @@ from lib.image import (count_frames, generate_thumbnail, ImagesLoader, png_write_meta, read_image, read_image_meta_batch) from lib.utils import _image_extensions, _video_extensions, FaceswapError -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Generator import numpy as np from lib.align.alignments import AlignmentFileDict, PNGHeaderDict @@ -44,7 +45,7 @@ class AlignmentData(Alignments): logger.debug("Initialized %s", self.__class__.__name__) @staticmethod - def check_file_exists(alignments_file: str) -> Tuple[str, str]: + def check_file_exists(alignments_file: str) -> tuple[str, str]: """ Check if the alignments file exists, and returns a tuple of the folder and filename. Parameters @@ -85,7 +86,7 @@ class MediaLoader(): analyzing a video file. If the count is not passed in, it will be calculated. Default: ``None`` """ - def __init__(self, folder: str, count: Optional[int] = None): + def __init__(self, folder: str, count: int | None = None): logger.debug("Initializing %s: (folder: '%s')", self.__class__.__name__, folder) logger.info("[%s DATA]", self.__class__.__name__.upper()) self._count = count @@ -112,7 +113,7 @@ class MediaLoader(): self._count = len(self.file_list_sorted) return self._count - def check_input_folder(self) -> Optional[cv2.VideoCapture]: + def check_input_folder(self) -> cv2.VideoCapture | None: """ Ensure that the frames or faces folder exists and is valid. 
If frames folder contains a video file return imageio reader object @@ -151,22 +152,20 @@ class MediaLoader(): logger.trace("Filename has valid extension: '%s': %s", filename, retval) # type: ignore return retval - def sorted_items(self) -> Union[List[Dict[str, str]], - List[Tuple[str, "PNGHeaderDict"]]]: + def sorted_items(self) -> list[dict[str, str]] | list[tuple[str, PNGHeaderDict]]: """ Override for specific folder processing """ raise NotImplementedError() - def process_folder(self) -> Union[Generator[Dict[str, str], None, None], - Generator[Tuple[str, "PNGHeaderDict"], None, None]]: + def process_folder(self) -> (Generator[dict[str, str], None, None] | + Generator[tuple[str, PNGHeaderDict], None, None]): """ Override for specific folder processing """ raise NotImplementedError() - def load_items(self) -> Union[Dict[str, List[int]], - Dict[str, Tuple[str, str]]]: + def load_items(self) -> dict[str, list[int]] | dict[str, tuple[str, str]]: """ Override for specific item loading """ raise NotImplementedError() - def load_image(self, filename: str) -> "np.ndarray": + def load_image(self, filename: str) -> np.ndarray: """ Load an image Parameters @@ -187,7 +186,7 @@ class MediaLoader(): image = read_image(src, raise_error=True) return image - def load_video_frame(self, filename: str) -> "np.ndarray": + def load_video_frame(self, filename: str) -> np.ndarray: """ Load a requested frame from video Parameters @@ -212,8 +211,8 @@ class MediaLoader(): # image = self._vid_reader.get_next_data()[:, :, ::-1] return image - def stream(self, skip_list: Optional[List[int]] = None - ) -> Generator[Tuple[str, "np.ndarray"], None, None]: + def stream(self, skip_list: list[int] | None = None + ) -> Generator[tuple[str, np.ndarray], None, None]: """ Load the images in :attr:`folder` in the order they are received from :class:`lib.image.ImagesLoader` in a background thread. @@ -239,8 +238,8 @@ class MediaLoader(): @staticmethod def save_image(output_folder: str, filename: str, - image: "np.ndarray", - metadata: Optional["PNGHeaderDict"] = None) -> None: + image: np.ndarray, + metadata: PNGHeaderDict | None = None) -> None: """ Save an image """ output_file = os.path.join(output_folder, filename) output_file = os.path.splitext(output_file)[0] + ".png" @@ -267,11 +266,11 @@ class Faces(MediaLoader): - When the remove-faces job is being run, when the process will only load faces that exist in the alignments file. Default: ``None`` """ - def __init__(self, folder: str, alignments: Optional[Alignments] = None) -> None: + def __init__(self, folder: str, alignments: Alignments | None = None) -> None: self._alignments = alignments super().__init__(folder) - def _handle_legacy(self, fullpath: str, log: bool = False) -> "PNGHeaderDict": + def _handle_legacy(self, fullpath: str, log: bool = False) -> PNGHeaderDict: """Handle facesets that are legacy (i.e. do not contain alignment information in the header data) @@ -311,8 +310,8 @@ class Faces(MediaLoader): def _handle_duplicate(self, fullpath: str, - header_dict: "PNGHeaderDict", - seen: Dict[str, List[int]]) -> bool: + header_dict: PNGHeaderDict, + seen: dict[str, list[int]]) -> bool: """ Check whether the given face has already been seen for the source frame and face index from an existing face. Can happen when filenames have changed due to sorting etc. 
and users have done multiple extractions/copies and placed all of the faces in the same folder @@ -323,7 +322,7 @@ class Faces(MediaLoader): The full path to the face image that is being checked header_dict : class:`~lib.align.alignments.PNGHeaderDict` The PNG header dictionary for the given face - seen : Dict[str, List[int]] + seen : dict[str, list[int]] Dictionary of original source filename and face indices that have already been seen and will be updated with the face processing now @@ -346,7 +345,7 @@ class Faces(MediaLoader): seen.setdefault(src_filename, []).append(face_index) return False - def process_folder(self) -> Generator[Tuple[str, "PNGHeaderDict"], None, None]: + def process_folder(self) -> Generator[tuple[str, PNGHeaderDict], None, None]: """ Iterate through the faces folder pulling out various information for each face. Yields @@ -358,7 +357,7 @@ class Faces(MediaLoader): logger.info("Loading file list from %s", self.folder) filter_count = 0 dupe_count = 0 - seen: Dict[str, List[int]] = {} + seen: dict[str, list[int]] = {} if self._alignments is not None and self._alignments.version < 2.1: # Legacy updating filelist = [os.path.join(self.folder, face) @@ -378,7 +377,7 @@ class Faces(MediaLoader): sub_dict = self._handle_legacy(fullpath, not log_once) log_once = True else: - sub_dict = cast("PNGHeaderDict", metadata["itxt"]) + sub_dict = T.cast("PNGHeaderDict", metadata["itxt"]) if self._handle_duplicate(fullpath, sub_dict, seen): dupe_count += 1 @@ -401,7 +400,7 @@ class Faces(MediaLoader): "'%s' from where they can be safely deleted", dupe_count, os.path.join(self.folder, "_duplicates")) - def load_items(self) -> Dict[str, List[int]]: + def load_items(self) -> dict[str, list[int]]: """ Load the face names into dictionary. Returns @@ -409,14 +408,14 @@ class Faces(MediaLoader): dict The source filename as key with list of face indices for the frame as value """ - faces: Dict[str, List[int]] = {} - for face in cast(List[Tuple[str, "PNGHeaderDict"]], self.file_list_sorted): + faces: dict[str, list[int]] = {} + for face in T.cast(list[tuple[str, "PNGHeaderDict"]], self.file_list_sorted): src = face[1]["source"] faces.setdefault(src["source_filename"], []).append(src["face_index"]) logger.trace(faces) # type: ignore return faces - def sorted_items(self) -> List[Tuple[str, "PNGHeaderDict"]]: + def sorted_items(self) -> list[tuple[str, PNGHeaderDict]]: """ Return the items sorted by the saved file name. 
Returns @@ -432,7 +431,7 @@ class Faces(MediaLoader): class Frames(MediaLoader): """ Object to hold the frames that are to be checked against """ - def process_folder(self) -> Generator[Dict[str, str], None, None]: + def process_folder(self) -> Generator[dict[str, str], None, None]: """ Iterate through the frames folder pulling the base filename Yields @@ -444,7 +443,7 @@ class Frames(MediaLoader): for item in iterator(): yield item - def process_frames(self) -> Generator[Dict[str, str], None, None]: + def process_frames(self) -> Generator[dict[str, str], None, None]: """ Process exported Frames Yields @@ -465,7 +464,7 @@ class Frames(MediaLoader): logger.trace(retval) # type: ignore yield retval - def process_video(self) -> Generator[Dict[str, str], None, None]: + def process_video(self) -> Generator[dict[str, str], None, None]: """Dummy in frames for video Yields @@ -485,7 +484,7 @@ class Frames(MediaLoader): logger.trace(retval) # type: ignore yield retval - def load_items(self) -> Dict[str, Tuple[str, str]]: + def load_items(self) -> dict[str, tuple[str, str]]: """ Load the frame info into dictionary Returns @@ -493,14 +492,14 @@ class Frames(MediaLoader): dict Fullname as key, tuple of frame name and extension as value """ - frames: Dict[str, Tuple[str, str]] = {} - for frame in cast(List[Dict[str, str]], self.file_list_sorted): + frames: dict[str, tuple[str, str]] = {} + for frame in T.cast(list[dict[str, str]], self.file_list_sorted): frames[frame["frame_fullname"]] = (frame["frame_name"], frame["frame_extension"]) logger.trace(frames) # type: ignore return frames - def sorted_items(self) -> List[Dict[str, str]]: + def sorted_items(self) -> list[dict[str, str]]: """ Return the items sorted by filename Returns @@ -532,11 +531,11 @@ class ExtractedFaces(): self.padding = int(size * 0.1875) self.alignments = alignments self.frames = frames - self.current_frame: Optional[str] = None - self.faces: List[DetectedFace] = [] + self.current_frame: str | None = None + self.faces: list[DetectedFace] = [] logger.trace("Initialized %s", self.__class__.__name__) # type: ignore - def get_faces(self, frame: str, image: Optional["np.ndarray"] = None) -> None: + def get_faces(self, frame: str, image: np.ndarray | None = None) -> None: """ Obtain faces and transformed landmarks for each face in a given frame with its alignments @@ -561,8 +560,8 @@ class ExtractedFaces(): self.current_frame = frame def extract_one_face(self, - alignment: "AlignmentFileDict", - image: "np.ndarray") -> DetectedFace: + alignment: AlignmentFileDict, + image: np.ndarray) -> DetectedFace: """ Extract one face from image Parameters @@ -588,7 +587,7 @@ class ExtractedFaces(): def get_faces_in_frame(self, frame: str, update: bool = False, - image: Optional["np.ndarray"] = None) -> List[DetectedFace]: + image: np.ndarray | None = None) -> list[DetectedFace]: """ Return the faces for the selected frame Parameters @@ -613,7 +612,7 @@ class ExtractedFaces(): self.get_faces(frame, image=image) return self.faces - def get_roi_size_for_frame(self, frame: str) -> List[int]: + def get_roi_size_for_frame(self, frame: str) -> list[int]: """ Return the size of the original extract box for the selected frame. 
Parameters diff --git a/tools/mask/mask.py b/tools/mask/mask.py index 97ae49b2..3153142f 100644 --- a/tools/mask/mask.py +++ b/tools/mask/mask.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 """ Tool to generate masks and previews of masks for existing alignments file """ +from __future__ import annotations import logging import os import sys +import typing as T + from argparse import Namespace from multiprocessing import Process -from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 import numpy as np @@ -18,7 +20,7 @@ from lib.multithreading import MultiThread from lib.utils import get_folder, _video_extensions from plugins.extract.pipeline import Extractor, ExtractMedia -if TYPE_CHECKING: +if T.TYPE_CHECKING: from lib.align.aligned_face import CenteringType from lib.align.alignments import AlignmentFileDict, PNGHeaderDict from lib.queue_manager import EventQueue @@ -45,7 +47,7 @@ class Mask(): # pylint:disable=too-few-public-methods self._args = arguments self._input_locations = self._get_input_locations() - def _get_input_locations(self) -> List[str]: + def _get_input_locations(self) -> list[str]: """ Obtain the full path to input locations. Will be a list of locations if batch mode is selected, or containing a single location if batch mode is not selected. @@ -141,18 +143,18 @@ class _Mask(): # pylint:disable=too-few-public-methods self._update_type = arguments.processing self._input_is_faces = arguments.input_type == "faces" self._mask_type = arguments.masker - self._output = dict(opts=dict(blur_kernel=arguments.blur_kernel, - threshold=arguments.threshold), - type=arguments.output_type, - full_frame=arguments.full_frame, - suffix=self._get_output_suffix(arguments)) - self._counts = dict(face=0, skip=0, update=0) + self._output = {"opts": {"blur_kernel": arguments.blur_kernel, + "threshold": arguments.threshold}, + "type": arguments.output_type, + "full_frame": arguments.full_frame, + "suffix": self._get_output_suffix(arguments)} + self._counts = {"face": 0, "skip": 0, "update": 0} self._check_input(arguments.input) self._saver = self._set_saver(arguments) loader = FacesLoader if self._input_is_faces else ImagesLoader self._loader = loader(arguments.input) - self._faces_saver: Optional[ImagesSaver] = None + self._faces_saver: ImagesSaver | None = None self._alignments = self._get_alignments(arguments) self._extractor = self._get_extractor(arguments.exclude_gpus) @@ -178,7 +180,7 @@ class _Mask(): # pylint:disable=too-few-public-methods sys.exit(0) logger.debug("input '%s' is valid", mask_input) - def _set_saver(self, arguments: Namespace) -> Optional[ImagesSaver]: + def _set_saver(self, arguments: Namespace) -> ImagesSaver | None: """ set the saver in a background thread Parameters @@ -204,7 +206,7 @@ class _Mask(): # pylint:disable=too-few-public-methods logger.debug(saver) return saver - def _get_alignments(self, arguments: Namespace) -> Optional[Alignments]: + def _get_alignments(self, arguments: Namespace) -> Alignments | None: """ Obtain the alignments from either the given alignments location or the default location. 
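The tools/mask/mask.py hunks follow the typing pattern applied throughout this patch: a "from __future__ import annotations" header, "import typing as T", PEP 604 unions (ImagesSaver | None instead of Optional[ImagesSaver]), and annotation-only imports moved under "if T.TYPE_CHECKING:". A minimal sketch of that combination under Python 3.10 follows; the summarise_args helper and its logic are invented for illustration and are not faceswap code.

from __future__ import annotations  # PEP 563: annotations stay as strings and are never evaluated at runtime

import typing as T

if T.TYPE_CHECKING:  # False at runtime, so this import exists only for type checkers
    from argparse import Namespace


def summarise_args(arguments: Namespace | None = None) -> dict[str, list[str]]:
    """ Hypothetical helper: the pre-patch spelling would have been
    Optional[Namespace] and Dict[str, List[str]]. """
    summary: dict[str, list[str]] = {}
    if arguments is None:
        return summary
    # T.cast is a no-op at runtime; it only narrows the type for the checker
    keys = T.cast(list[str], sorted(vars(arguments)))
    summary["options"] = keys
    return summary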
@@ -242,7 +244,7 @@ class _Mask(): # pylint:disable=too-few-public-methods return Alignments(folder, filename=filename) - def _get_extractor(self, exclude_gpus: List[int]) -> Optional[Extractor]: + def _get_extractor(self, exclude_gpus: list[int]) -> Extractor | None: """ Obtain a Mask extractor plugin and launch it Parameters ---------- @@ -303,7 +305,7 @@ class _Mask(): # pylint:disable=too-few-public-methods def _process_face(self, filename: str, image: np.ndarray, - metadata: "PNGHeaderDict") -> Optional["ExtractMedia"]: + metadata: PNGHeaderDict) -> ExtractMedia | None: """ Process a single face when masking from face images filename: str @@ -324,7 +326,7 @@ class _Mask(): # pylint:disable=too-few-public-methods if self._alignments is None: # mask from PNG header lookup_index = 0 - alignments = [cast("AlignmentFileDict", metadata["alignments"])] + alignments = [T.cast("AlignmentFileDict", metadata["alignments"])] else: # mask from Alignments file lookup_index = face_index alignments = self._alignments.get_faces_in_frame(frame_name) @@ -350,7 +352,7 @@ class _Mask(): # pylint:disable=too-few-public-methods self._counts["update"] += 1 return media - def _input_faces(self, *args: Union[tuple, Tuple["EventQueue"]]) -> None: + def _input_faces(self, *args: tuple | tuple[EventQueue]) -> None: """ Input pre-aligned faces to the Extractor plugin inside a thread Parameters @@ -362,7 +364,7 @@ class _Mask(): # pylint:disable=too-few-public-methods log_once = False logger.debug("args: %s", args) if self._update_type != "output": - queue = cast("EventQueue", args[0]) + queue = T.cast("EventQueue", args[0]) for filename, image, metadata in tqdm(self._loader.load(), total=self._loader.count): if not metadata: # Legacy faces. Update the headers if self._alignments is None: @@ -394,7 +396,7 @@ class _Mask(): # pylint:disable=too-few-public-methods if self._update_type != "output": queue.put("EOF") - def _input_frames(self, *args: Union[tuple, Tuple["EventQueue"]]) -> None: + def _input_frames(self, *args: tuple | tuple[EventQueue]) -> None: """ Input frames to the Extractor plugin inside a thread Parameters @@ -406,7 +408,7 @@ class _Mask(): # pylint:disable=too-few-public-methods assert self._alignments is not None logger.debug("args: %s", args) if self._update_type != "output": - queue = cast("EventQueue", args[0]) + queue = T.cast("EventQueue", args[0]) for filename, image in tqdm(self._loader.load(), total=self._loader.count): frame = os.path.basename(filename) if not self._alignments.frame_exists(frame): @@ -438,7 +440,7 @@ class _Mask(): # pylint:disable=too-few-public-methods if self._update_type != "output": queue.put("EOF") - def _check_for_missing(self, frame: str, idx: int, alignment: "AlignmentFileDict") -> bool: + def _check_for_missing(self, frame: str, idx: int, alignment: AlignmentFileDict) -> bool: """ Check if the alignment is missing the requested mask_type Parameters @@ -482,7 +484,7 @@ class _Mask(): # pylint:disable=too-few-public-methods return sfx @classmethod - def _get_detected_face(cls, alignment: "AlignmentFileDict") -> DetectedFace: + def _get_detected_face(cls, alignment: AlignmentFileDict) -> DetectedFace: """ Convert an alignment dict item to a detected_face object Parameters @@ -554,8 +556,8 @@ class _Mask(): # pylint:disable=too-few-public-methods if self._alignments is not None: self._alignments.update_face(frame_name, face_index, face.to_alignment()) - metadata: "PNGHeaderDict" = dict(alignments=face.to_png_meta(), - source=extractor_output.frame_metadata) + 
metadata: PNGHeaderDict = {"alignments": face.to_png_meta(), + "source": extractor_output.frame_metadata} self._faces_saver.save(extractor_output.filename, encode_image(extractor_output.image, ".png", metadata=metadata)) @@ -645,9 +647,9 @@ class _Mask(): # pylint:disable=too-few-public-methods size=detected_face.image.shape[0], is_aligned=True).face else: - centering: "CenteringType" = ("legacy" if self._alignments is not None and - self._alignments.version == 1.0 - else mask.stored_centering) + centering: CenteringType = ("legacy" if self._alignments is not None and + self._alignments.version == 1.0 + else mask.stored_centering) detected_face.load_aligned(detected_face.image, centering=centering, force=True) face = detected_face.aligned.face assert face is not None diff --git a/tools/model/cli.py b/tools/model/cli.py index 939d6334..21117df5 100644 --- a/tools/model/cli.py +++ b/tools/model/cli.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ Command Line Arguments for tools """ import gettext -from typing import Any, List, Dict +import typing as T from lib.cli.args import FaceSwapArgs from lib.cli.actions import DirFullPaths, Radio @@ -22,7 +22,7 @@ class ModelArgs(FaceSwapArgs): return _("A tool for performing actions on Faceswap trained model files") @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Put the arguments in a list so that they are accessible from both argparse and gui """ argument_list = [] argument_list.append(dict( diff --git a/tools/model/model.py b/tools/model/model.py index f8ec6da2..9bdac285 100644 --- a/tools/model/model.py +++ b/tools/model/model.py @@ -118,7 +118,7 @@ class Inference(): # pylint:disable=too-few-public-methods self._format = arguments.format self._input_file, self._output_file = self._get_output_file(arguments.model_dir) - def _get_output_file(self, model_dir: str) -> T.Tuple[str, str]: + def _get_output_file(self, model_dir: str) -> tuple[str, str]: """ Obtain the full path for the output model file/folder Parameters @@ -183,7 +183,7 @@ class NaNScan(): # pylint:disable=too-few-public-methods return os.path.join(model_dir, model_file) def _parse_weights(self, - layer: T.Union[keras.models.Model, keras.layers.Layer]) -> dict: + layer: keras.models.Model | keras.layers.Layer) -> dict: """ Recursively pass through sub-models to scan layer weights""" weights = layer.get_weights() logger.debug("Processing weights for layer '%s', length: '%s'", diff --git a/tools/preview/cli.py b/tools/preview/cli.py index 7c324751..da327028 100644 --- a/tools/preview/cli.py +++ b/tools/preview/cli.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ Command Line Arguments for tools """ import gettext -from typing import Any, List, Dict +import typing as T from lib.cli.args import FaceSwapArgs from lib.cli.actions import DirOrFileFullPaths, DirFullPaths, FileFullPaths @@ -29,7 +29,7 @@ class PreviewArgs(FaceSwapArgs): return _("Preview tool\nAllows you to configure your convert settings with a live preview") @staticmethod - def get_argument_list() -> List[Dict[str, Any]]: + def get_argument_list() -> list[dict[str, T.Any]]: """ Put the arguments in a list so that they are accessible from both argparse and gui Returns diff --git a/tools/preview/control_panels.py b/tools/preview/control_panels.py index 9b214bfe..f7379cca 100644 --- a/tools/preview/control_panels.py +++ b/tools/preview/control_panels.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 """ Manages the widgets that hold the bottom 'control' area of the 
preview tool """ +from __future__ import annotations import gettext import logging +import typing as T + import tkinter as tk from tkinter import ttk from configparser import ConfigParser -from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union from lib.gui.custom_widgets import Tooltip from lib.gui.control_helper import ControlPanel, ControlPanelOption @@ -14,7 +16,8 @@ from lib.gui.utils import get_images from plugins.plugin_loader import PluginLoader from plugins.convert._config import Config -if TYPE_CHECKING: +if T.TYPE_CHECKING: + from collections.abc import Callable from .preview import Preview logger = logging.getLogger(__name__) @@ -34,10 +37,8 @@ class ConfigTools(): """ def __init__(self) -> None: self._config = Config(None) - self.tk_vars: Dict[str, Dict[str, Union[tk.BooleanVar, - tk.StringVar, - tk.IntVar, - tk.DoubleVar]]] = {} + self.tk_vars: dict[str, dict[str, tk.BooleanVar | tk.StringVar | tk.IntVar | tk.DoubleVar] + ] = {} self._config_dicts = self._get_config_dicts() # Holds currently saved config @property @@ -46,18 +47,18 @@ class ConfigTools(): return self._config @property - def config_dicts(self) -> Dict[str, Any]: + def config_dicts(self) -> dict[str, T.Any]: """ dict: The convert configuration options in dictionary form.""" return self._config_dicts @property - def sections(self) -> List[str]: + def sections(self) -> list[str]: """ list: The sorted section names that exist within the convert Configuration options. """ return sorted(set(plugin.split(".")[0] for plugin in self._config.config.sections() if plugin.split(".")[0] != "writer")) @property - def plugins_dict(self) -> Dict[str, List[str]]: + def plugins_dict(self) -> dict[str, list[str]]: """ dict: Dictionary of configuration option sections as key with a list of containing plugins as the value """ return {section: sorted([plugin.split(".")[1] for plugin in self._config.config.sections() @@ -81,7 +82,7 @@ class ConfigTools(): section, item, old_value, new_value) self._config.config[section][item] = new_value - def _get_config_dicts(self) -> Dict[str, Dict[str, Any]]: + def _get_config_dicts(self) -> dict[str, dict[str, T.Any]]: """ Obtain a custom configuration dictionary for convert configuration items in use by the preview tool formatted for control helper. @@ -91,7 +92,7 @@ class ConfigTools(): Each configuration section as keys, with the values as a dict of option: :class:`lib.gui.control_helper.ControlOption` pairs. """ logger.debug("Formatting Config for GUI") - config_dicts: Dict[str, Dict[str, Any]] = {} + config_dicts: dict[str, dict[str, T.Any]] = {} for section in self._config.config.sections(): if section.startswith("writer."): continue @@ -114,7 +115,7 @@ class ConfigTools(): logger.debug("Formatted Config for GUI: %s", config_dicts) return config_dicts - def reset_config_to_saved(self, section: Optional[str] = None) -> None: + def reset_config_to_saved(self, section: str | None = None) -> None: """ Reset the GUI parameters to their saved values within the configuration file. Parameters @@ -135,7 +136,7 @@ class ConfigTools(): logger.debug("Setting %s - %s to saved value %s", config_section, item, val) logger.debug("Reset to saved config: %s", section) - def reset_config_to_default(self, section: Optional[str] = None) -> None: + def reset_config_to_default(self, section: str | None = None) -> None: """ Reset the GUI parameters to their default configuration values. 
Parameters @@ -157,7 +158,7 @@ class ConfigTools(): config_section, item, default) logger.debug("Reset to default: %s", section) - def save_config(self, section: Optional[str] = None) -> None: + def save_config(self, section: str | None = None) -> None: """ Save the configuration ``.ini`` file with the currently stored values. Notes @@ -258,18 +259,18 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors parent: tkinter object The parent tkinter object that holds the Action Frame """ - def __init__(self, app: 'Preview', parent: ttk.Frame) -> None: + def __init__(self, app: Preview, parent: ttk.Frame) -> None: logger.debug("Initializing %s: (app: %s, parent: %s)", self.__class__.__name__, app, parent) self._app = app super().__init__(parent) self.pack(side=tk.LEFT, anchor=tk.N, fill=tk.Y) - self._tk_vars: Dict[str, tk.StringVar] = {} + self._tk_vars: dict[str, tk.StringVar] = {} - self._options = dict( - color=app._patch.converter.cli_arguments.color_adjustment.replace("-", "_"), - mask_type=app._patch.converter.cli_arguments.mask_type.replace("-", "_")) + self._options = { + "color": app._patch.converter.cli_arguments.color_adjustment.replace("-", "_"), + "mask_type": app._patch.converter.cli_arguments.mask_type.replace("-", "_")} defaults = {opt: self._format_to_display(val) for opt, val in self._options.items()} self._busy_bar = self._build_frame(defaults, @@ -279,7 +280,7 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors app._samples.predictor.has_predicted_mask) @property - def convert_args(self) -> Dict[str, Any]: + def convert_args(self) -> dict[str, T.Any]: """ dict: Currently selected Command line arguments from the :class:`ActionFrame`. """ return {opt if opt != "color" else "color_adjustment": self._format_from_display(self._tk_vars[opt].get()) @@ -323,10 +324,10 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors return var.replace("_", " ").replace("-", " ").title() def _build_frame(self, - defaults: Dict[str, Any], + defaults: dict[str, T.Any], refresh_callback: Callable[[], None], patch_callback: Callable[[], None], - available_masks: List[str], + available_masks: list[str], has_predicted_mask: bool) -> BusyProgressBar: """ Build the :class:`ActionFrame`. @@ -366,8 +367,8 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors def _add_cli_choices(self, parent: ttk.Frame, - defaults: Dict[str, Any], - available_masks: List[str], + defaults: dict[str, T.Any], + available_masks: list[str], has_predicted_mask: bool) -> None: """ Create :class:`lib.gui.control_helper.ControlPanel` object for the command line options. @@ -382,13 +383,13 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors Whether the model was trained with a mask """ cp_options = self._get_control_panel_options(defaults, available_masks, has_predicted_mask) - panel_kwargs = dict(blank_nones=False, label_width=10, style="CPanel") + panel_kwargs = {"blank_nones": False, "label_width": 10, "style": "CPanel"} ControlPanel(parent, cp_options, header_text=None, **panel_kwargs) def _get_control_panel_options(self, - defaults: Dict[str, Any], - available_masks: List[str], - has_predicted_mask: bool) -> List[ControlPanelOption]: + defaults: dict[str, T.Any], + available_masks: list[str], + has_predicted_mask: bool) -> list[ControlPanelOption]: """ Create :class:`lib.gui.control_helper.ControlPanelOption` objects for the command line options. 
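The hunks around this point also replace dict(key=value) constructor calls with literal {"key": value} syntax. A short sketch of the two equivalent forms, echoing the panel_kwargs assignment that appears in a later hunk of this file:

# Constructor form: keys are keyword arguments, so they must be valid identifiers
panel_kwargs_old = dict(columns=2, option_columns=2, blank_nones=False, style="CPanel")

# Literal form used throughout the patch: keys are plain strings and the mapping
# is built by a dict display rather than a name lookup plus a call
panel_kwargs_new = {"columns": 2,
                    "option_columns": 2,
                    "blank_nones": False,
                    "style": "CPanel"}

assert panel_kwargs_old == panel_kwargs_new  # identical contents either way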
@@ -404,7 +405,7 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors list The list of `lib.gui.control_helper.ControlPanelOption` objects for the Action Frame """ - cp_options: List[ControlPanelOption] = [] + cp_options: list[ControlPanelOption] = [] for opt in self._options: if opt == "mask_type": choices = self._create_mask_choices(defaults, available_masks, has_predicted_mask) @@ -422,9 +423,9 @@ class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors return cp_options def _create_mask_choices(self, - defaults: Dict[str, Any], - available_masks: List[str], - has_predicted_mask: bool) -> List[str]: + defaults: dict[str, T.Any], + available_masks: list[str], + has_predicted_mask: bool) -> list[str]: """ Set the mask choices and default mask based on available masks. Parameters @@ -537,7 +538,7 @@ class OptionsBook(ttk.Notebook): # pylint:disable=too-many-ancestors self.pack(side=tk.RIGHT, anchor=tk.N, fill=tk.BOTH, expand=True) self.config_tools = config_tools - self._tabs: Dict[str, Dict[str, Union[ttk.Notebook, ConfigFrame]]] = {} + self._tabs: dict[str, dict[str, ttk.Notebook | ConfigFrame]] = {} self._build_tabs() self._build_sub_tabs() self._add_patch_callback(patch_callback) @@ -560,7 +561,7 @@ class OptionsBook(ttk.Notebook): # pylint:disable=too-many-ancestors tab = ConfigFrame(self, config_key, config_dict) self._tabs[section][plugin] = tab text = plugin.replace("_", " ").title() - cast(ttk.Notebook, self._tabs[section]["tab"]).add(tab, text=text) + T.cast(ttk.Notebook, self._tabs[section]["tab"]).add(tab, text=text) def _add_patch_callback(self, patch_callback: Callable[[], None]) -> None: """ Add callback to re-patch images on configuration option change. @@ -591,7 +592,7 @@ class ConfigFrame(ttk.Frame): # pylint: disable=too-many-ancestors def __init__(self, parent: OptionsBook, config_key: str, - options: Dict[str, Any]): + options: dict[str, T.Any]): logger.debug("Initializing %s", self.__class__.__name__) super().__init__(parent) self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) @@ -616,7 +617,7 @@ class ConfigFrame(ttk.Frame): # pylint: disable=too-many-ancestors The section/plugin key for these configuration options """ logger.debug("Add Config Frame") - panel_kwargs = dict(columns=2, option_columns=2, blank_nones=False, style="CPanel") + panel_kwargs = {"columns": 2, "option_columns": 2, "blank_nones": False, "style": "CPanel"} frame = ttk.Frame(self) frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True) cp_options = [opt for key, opt in self._options.items() if key != "helptext"] diff --git a/tools/preview/preview.py b/tools/preview/preview.py index f6ab1663..cdacc1ae 100644 --- a/tools/preview/preview.py +++ b/tools/preview/preview.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 """ Tool to preview swaps and tweak configuration prior to running a convert """ - +from __future__ import annotations import gettext import logging import random import tkinter as tk +import typing as T + from tkinter import ttk -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import os import sys - from threading import Event, Lock, Thread import numpy as np @@ -29,13 +29,7 @@ from plugins.extract.pipeline import ExtractMedia from .control_panels import ActionFrame, ConfigTools, OptionsBook from .viewer import FacesDisplay, ImagesCanvas - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from argparse import Namespace from lib.queue_manager import 
EventQueue from .control_panels import BusyProgressBar @@ -62,7 +56,7 @@ class Preview(tk.Tk): # pylint:disable=too-few-public-methods """ _w: str - def __init__(self, arguments: "Namespace") -> None: + def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s: (arguments: '%s'", self.__class__.__name__, arguments) super().__init__() self._config_tools = ConfigTools() @@ -73,9 +67,9 @@ class Preview(tk.Tk): # pylint:disable=too-few-public-methods self._patch = Patch(self, arguments) self._initialize_tkinter() - self._image_canvas: Optional[ImagesCanvas] = None - self._opts_book: Optional[OptionsBook] = None - self._cli_frame: Optional[ActionFrame] = None # cli frame holds cli options + self._image_canvas: ImagesCanvas | None = None + self._opts_book: OptionsBook | None = None + self._cli_frame: ActionFrame | None = None # cli frame holds cli options logger.debug("Initialized %s", self.__class__.__name__) @property @@ -102,7 +96,7 @@ class Preview(tk.Tk): # pylint:disable=too-few-public-methods return self._lock @property - def progress_bar(self) -> "BusyProgressBar": + def progress_bar(self) -> BusyProgressBar: """ :class:`~tools.preview.control_panels.BusyProgressBar`: The progress bar that indicates a swap/patch thread is running """ assert self._cli_frame is not None @@ -278,13 +272,13 @@ class Samples(): The number of samples to take from the input video/images """ - def __init__(self, app: Preview, arguments: "Namespace", sample_size: int) -> None: + def __init__(self, app: Preview, arguments: Namespace, sample_size: int) -> None: logger.debug("Initializing %s: (app: %s, arguments: '%s', sample_size: %s)", self.__class__.__name__, app, arguments, sample_size) self._sample_size = sample_size self._app = app - self._input_images: List[ConvertItem] = [] - self._predicted_images: List[Tuple[ConvertItem, np.ndarray]] = [] + self._input_images: list[ConvertItem] = [] + self._predicted_images: list[tuple[ConvertItem, np.ndarray]] = [] self._images = Images(arguments) self._alignments = Alignments(arguments, @@ -310,7 +304,7 @@ class Samples(): logger.debug("Initialized %s", self.__class__.__name__) @property - def available_masks(self) -> List[str]: + def available_masks(self) -> list[str]: """ list: The mask names that are available for every face in the alignments file """ retval = [key for key, val in self.alignments.mask_summary.items() @@ -323,7 +317,7 @@ class Samples(): return self._sample_size @property - def predicted_images(self) -> List[Tuple[ConvertItem, np.ndarray]]: + def predicted_images(self) -> list[tuple[ConvertItem, np.ndarray]]: """ list: The predicted faces output from the Faceswap model """ return self._predicted_images @@ -338,13 +332,13 @@ class Samples(): return self._predictor @property - def _random_choice(self) -> List[int]: + def _random_choice(self) -> list[int]: """ list: Random indices from the :attr:`_indices` group """ retval = [random.choice(indices) for indices in self._indices] logger.debug(retval) return retval - def _get_filelist(self) -> List[str]: + def _get_filelist(self) -> list[str]: """ Get a list of files for the input, filtering out those frames which do not contain faces. @@ -372,7 +366,7 @@ class Samples(): raise FaceswapError(msg) from err return retval - def _get_indices(self) -> List[List[int]]: + def _get_indices(self) -> list[list[int]]: """ Get indices for each sample group. 
Obtain :attr:`self.sample_size` evenly sized groups of indices @@ -450,9 +444,8 @@ class Samples(): idx = 0 while idx < self._sample_size: logger.debug("Predicting face %s of %s", idx + 1, self._sample_size) - items: Union[Literal["EOF"], - List[Tuple[ConvertItem, - np.ndarray]]] = self._predictor.out_queue.get() + items: (T.Literal["EOF"] | + list[tuple[ConvertItem, np.ndarray]]) = self._predictor.out_queue.get() if items == "EOF": logger.debug("Received EOF") break @@ -481,12 +474,12 @@ class Patch(): # pylint:disable=too-few-public-methods converter_arguments: dict The currently selected converter command line arguments for the patch queue """ - def __init__(self, app: Preview, arguments: "Namespace") -> None: + def __init__(self, app: Preview, arguments: Namespace) -> None: logger.debug("Initializing %s: (app: %s, arguments: '%s')", self.__class__.__name__, app, arguments) self._app = app self._queue_patch_in = queue_manager.get_queue("preview_patch_in") - self.converter_arguments: Optional[Dict[str, Any]] = None # Updated converter args dict + self.converter_arguments: dict[str, T.Any] | None = None # Updated converter args configfile = arguments.configfile if hasattr(arguments, "configfile") else None self._converter = Converter(output_size=app._samples.predictor.output_size, @@ -513,8 +506,8 @@ class Patch(): # pylint:disable=too-few-public-methods return self._converter @staticmethod - def _generate_converter_arguments(arguments: "Namespace", - available_masks: List[str]) -> "Namespace": + def _generate_converter_arguments(arguments: Namespace, + available_masks: list[str]) -> Namespace: """ Add the default converter arguments to the initial arguments. Ensure the mask selection is available. @@ -550,7 +543,7 @@ class Patch(): # pylint:disable=too-few-public-methods return arguments def _process(self, - patch_queue_in: "EventQueue", + patch_queue_in: EventQueue, trigger_event: Event, samples: Samples) -> None: """ The face patching process. @@ -601,7 +594,7 @@ class Patch(): # pylint:disable=too-few-public-methods logger.debug("Updated Converter cli arguments") @staticmethod - def _feed_swapped_faces(patch_queue_in: "EventQueue", samples: Samples) -> None: + def _feed_swapped_faces(patch_queue_in: EventQueue, samples: Samples) -> None: """ Feed swapped faces to the converter's in-queue. Parameters @@ -620,9 +613,9 @@ class Patch(): # pylint:disable=too-few-public-methods patch_queue_in.put("EOF") def _patch_faces(self, - queue_in: "EventQueue", - queue_out: "EventQueue", - sample_size: int) -> List[np.ndarray]: + queue_in: EventQueue, + queue_out: EventQueue, + sample_size: int) -> list[np.ndarray]: """ Patch faces. Run the convert process on the swapped faces and return the patched faces. 
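The prediction loop shown above in tools/preview/preview.py annotates the queue payload as T.Literal["EOF"] | list[tuple[ConvertItem, np.ndarray]], i.e. either a sentinel string or a batch of results. Below is a self-contained sketch of that consumption pattern; it uses a plain queue.Queue with dummy batch items rather than faceswap's EventQueue and ConvertItem, so the names are stand-ins only.

from __future__ import annotations

import queue
import typing as T

Batch: T.TypeAlias = list[tuple[str, int]]  # stand-in for (ConvertItem, ndarray) pairs


def drain(out_queue: queue.Queue[T.Literal["EOF"] | Batch]) -> Batch:
    """ Collect batches until the "EOF" sentinel arrives. """
    collected: Batch = []
    while True:
        items: T.Literal["EOF"] | Batch = out_queue.get()
        if items == "EOF":  # the sentinel marks the end of the stream
            break
        collected.extend(items)
    return collected


q: queue.Queue[T.Literal["EOF"] | Batch] = queue.Queue()
q.put([("frame_0001.png", 0), ("frame_0002.png", 1)])
q.put("EOF")
print(drain(q))  # [('frame_0001.png', 0), ('frame_0002.png', 1)]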
diff --git a/tools/preview/viewer.py b/tools/preview/viewer.py index b1876030..7abe11b9 100644 --- a/tools/preview/viewer.py +++ b/tools/preview/viewer.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 """ Manages the widgets that hold the top 'viewer' area of the preview tool """ +from __future__ import annotations import logging import os import tkinter as tk -from tkinter import ttk +import typing as T +from tkinter import ttk from dataclasses import dataclass, field -from typing import cast, List, Optional, Tuple, TYPE_CHECKING import cv2 import numpy as np @@ -17,7 +18,7 @@ from lib.align.aligned_face import CenteringType from scripts.convert import ConvertItem -if TYPE_CHECKING: +if T.TYPE_CHECKING: from .preview import Preview logger = logging.getLogger(__name__) @@ -26,10 +27,10 @@ logger = logging.getLogger(__name__) @dataclass class _Faces: """ Dataclass for holding faces """ - filenames: List[str] = field(default_factory=list) - matrix: List[np.ndarray] = field(default_factory=list) - src: List[np.ndarray] = field(default_factory=list) - dst: List[np.ndarray] = field(default_factory=list) + filenames: list[str] = field(default_factory=list) + matrix: list[np.ndarray] = field(default_factory=list) + src: list[np.ndarray] = field(default_factory=list) + dst: list[np.ndarray] = field(default_factory=list) class FacesDisplay(): @@ -55,7 +56,7 @@ class FacesDisplay(): The list of :class:`numpy.ndarray` swapped and patched preview images for bottom row of display """ - def __init__(self, app: 'Preview', size: int, padding: int) -> None: + def __init__(self, app: Preview, size: int, padding: int) -> None: logger.trace("Initializing %s: (app: %s, size: %s, padding: %s)", # type: ignore self.__class__.__name__, app, size, padding) self._size = size @@ -64,21 +65,21 @@ class FacesDisplay(): self._padding = padding self._faces = _Faces() - self._centering: Optional[CenteringType] = None + self._centering: CenteringType | None = None self._faces_source: np.ndarray = np.array([]) self._faces_dest: np.ndarray = np.array([]) - self._tk_image: Optional[ImageTk.PhotoImage] = None + self._tk_image: ImageTk.PhotoImage | None = None # Set from Samples self.update_source = False - self.source: List[ConvertItem] = [] # Source images, filenames + detected faces + self.source: list[ConvertItem] = [] # Source images, filenames + detected faces # Set from Patch - self.destination: List[np.ndarray] = [] # Swapped + patched images + self.destination: list[np.ndarray] = [] # Swapped + patched images logger.trace("Initialized %s", self.__class__.__name__) # type: ignore @property - def tk_image(self) -> Optional[ImageTk.PhotoImage]: + def tk_image(self) -> ImageTk.PhotoImage | None: """ :class:`PIL.ImageTk.PhotoImage`: The compiled preview display in tkinter display format """ return self._tk_image @@ -99,7 +100,7 @@ class FacesDisplay(): """ self._centering = centering - def set_display_dimensions(self, dimensions: Tuple[int, int]) -> None: + def set_display_dimensions(self, dimensions: tuple[int, int]) -> None: """ Adjust the size of the frame that will hold the preview samples. Parameters @@ -121,7 +122,7 @@ class FacesDisplay(): self._tk_image = ImageTk.PhotoImage(pilimg) logger.trace("Updated tk image") # type: ignore - def _get_scale_size(self, image: np.ndarray) -> Tuple[int, int]: + def _get_scale_size(self, image: np.ndarray) -> tuple[int, int]: """ Get the size that the full preview image should be resized to fit in the display window. 
@@ -180,7 +181,7 @@ class FacesDisplay(): src_img = item.inbound.image detected_face.load_aligned(src_img, size=self._size, - centering=cast(CenteringType, self._centering)) + centering=T.cast(CenteringType, self._centering)) matrix = detected_face.aligned.matrix self._faces.filenames.append(os.path.splitext(item.inbound.filename)[0]) self._faces.matrix.append(matrix) @@ -265,7 +266,7 @@ class ImagesCanvas(ttk.Frame): # pylint:disable=too-many-ancestors parent: tkinter object The parent tkinter object that holds the canvas """ - def __init__(self, app: 'Preview', parent: ttk.PanedWindow) -> None: + def __init__(self, app: Preview, parent: ttk.PanedWindow) -> None: logger.debug("Initializing %s: (app: %s, parent: %s)", self.__class__.__name__, app, parent) super().__init__(parent) diff --git a/tools/sort/sort.py b/tools/sort/sort.py index 71510045..17d14176 100644 --- a/tools/sort/sort.py +++ b/tools/sort/sort.py @@ -2,14 +2,14 @@ """ A tool that allows for sorting and grouping images in different ways. """ +from __future__ import annotations import logging import os import sys - +import typing as T from argparse import Namespace from shutil import copyfile, rmtree -from typing import Dict, List, Optional, TYPE_CHECKING from tqdm import tqdm @@ -20,7 +20,7 @@ from lib.utils import deprecation_warning from .sort_methods import SortBlur, SortColor, SortFace, SortHistogram, SortMultiMethod from .sort_methods_aligned import SortDistance, SortFaceCNN, SortPitch, SortSize, SortYaw, SortRoll -if TYPE_CHECKING: +if T.TYPE_CHECKING: from .sort_methods import SortMethod logger = logging.getLogger(__name__) @@ -65,7 +65,7 @@ class Sort(): # pylint:disable=too-few-public-methods self._args.sort_method = "color-black" if sort_ == "black-pixels" else sort_ self._args.group_method = "color-black" if group_ == "black-pixels" else group_ - def _get_input_locations(self) -> List[str]: + def _get_input_locations(self) -> list[str]: """ Obtain the full path to input locations. Will be a list of locations if batch mode is selected, or a list containing a single location if batch mode is not selected.
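The `if T.TYPE_CHECKING:` guard above works hand in hand with `from __future__ import annotations`: the guarded import is only seen by static type checkers, while at runtime annotations remain uninterpreted strings, so unquoted forward references such as `-> SortMethod` no longer need the `"SortMethod"` spelling. A minimal sketch of the pattern, assuming a hypothetical sibling module `sort_methods`:

from __future__ import annotations

import typing as T

if T.TYPE_CHECKING:
    # Executed only by static type checkers (TYPE_CHECKING is False at runtime),
    # so this import cannot cause a runtime circular-import problem.
    from sort_methods import SortMethod  # hypothetical module for illustration


class Sorter:
    """ Illustrative holder that annotates against a type it never imports at runtime. """

    def __init__(self) -> None:
        self._method: SortMethod | None = None  # unquoted forward reference is fine

    def set_method(self, method: SortMethod) -> None:
        self._method = method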
@@ -127,27 +127,27 @@ class _Sort(): # pylint:disable=too-few-public-methods """ Sorts folders of faces based on input criteria """ def __init__(self, arguments: Namespace) -> None: logger.debug("Initializing %s: arguments: %s", self.__class__.__name__, arguments) - self._processes = dict(blur=SortBlur, - blur_fft=SortBlur, - distance=SortDistance, - yaw=SortYaw, - pitch=SortPitch, - roll=SortRoll, - size=SortSize, - face=SortFace, - face_cnn=SortFaceCNN, - face_cnn_dissim=SortFaceCNN, - hist=SortHistogram, - hist_dissim=SortHistogram, - color_black=SortColor, - color_gray=SortColor, - color_luma=SortColor, - color_green=SortColor, - color_orange=SortColor) + self._processes = {"blur": SortBlur, + "blur_fft": SortBlur, + "distance": SortDistance, + "yaw": SortYaw, + "pitch": SortPitch, + "roll": SortRoll, + "size": SortSize, + "face": SortFace, + "face_cnn": SortFaceCNN, + "face_cnn_dissim": SortFaceCNN, + "hist": SortHistogram, + "hist_dissim": SortHistogram, + "color_black": SortColor, + "color_gray": SortColor, + "color_luma": SortColor, + "color_green": SortColor, + "color_orange": SortColor} self._args = self._parse_arguments(arguments) - self._changes: Dict[str, str] = {} - self.serializer: Optional[Serializer] = None + self._changes: dict[str, str] = {} + self.serializer: Serializer | None = None if arguments.log_changes: self.serializer = get_serializer_from_filename(arguments.log_file_path) @@ -220,7 +220,7 @@ class _Sort(): # pylint:disable=too-few-public-methods logger.debug("Cleaned arguments: %s", arguments) return arguments - def _get_sorter(self) -> "SortMethod": + def _get_sorter(self) -> SortMethod: """ Obtain a sorter/grouper combo for the selected sort/group by options Returns diff --git a/tools/sort/sort_methods.py b/tools/sort/sort_methods.py index a2f7c2f1..507db4ae 100644 --- a/tools/sort/sort_methods.py +++ b/tools/sort/sort_methods.py @@ -4,11 +4,13 @@ All sorting methods inherit from :class:`SortMethod` and control functions for scorting one item, sorting a full list of scores and binning based on those sorted scores. """ +from __future__ import annotations import logging import operator import sys +import typing as T -from typing import Any, cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union +from collections.abc import Generator import cv2 import numpy as np @@ -19,21 +21,16 @@ from lib.image import FacesLoader, ImagesLoader, read_image_meta_batch, update_e from lib.utils import FaceswapError from plugins.extract.recognition.vgg_face2 import Cluster, Recognition as VGGFace -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if TYPE_CHECKING: +if T.TYPE_CHECKING: from argparse import Namespace from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderSourceDict logger = logging.getLogger(__name__) -ImgMetaType = Generator[Tuple[str, - Optional[np.ndarray], - Optional["PNGHeaderAlignmentsDict"]], None, None] +ImgMetaType: T.TypeAlias = Generator[tuple[str, + np.ndarray | None, + T.Union["PNGHeaderAlignmentsDict", None]], None, None] class InfoLoader(): @@ -50,14 +47,14 @@ class InfoLoader(): """ def __init__(self, input_dir: str, - info_type: Literal["face", "meta", "all"]) -> None: + info_type: T.Literal["face", "meta", "all"]) -> None: logger.debug("Initializing: %s (input_dir: %s, info_type: %s)", self.__class__.__name__, input_dir, info_type) self._info_type = info_type self._iterator = None self._description = "Reading image statistics..." 
self._loader = ImagesLoader(input_dir) if info_type == "face" else FacesLoader(input_dir) - self._cached_source_data: Dict[str, "PNGHeaderSourceDict"] = {} + self._cached_source_data: dict[str, PNGHeaderSourceDict] = {} if self._loader.count == 0: logger.error("No images to process in location: '%s'", input_dir) sys.exit(1) @@ -103,7 +100,7 @@ class InfoLoader(): def _get_alignments(self, filename: str, - metadata: Dict[str, Any]) -> Optional["PNGHeaderAlignmentsDict"]: + metadata: dict[str, T.Any]) -> PNGHeaderAlignmentsDict | None: """ Obtain the alignments from a PNG Header. The other image metadata is cached locally in case a sort method needs to write back to the @@ -182,7 +179,7 @@ class InfoLoader(): leave=False): yield filename, image, None - def update_png_header(self, filename: str, alignments: "PNGHeaderAlignmentsDict") -> None: + def update_png_header(self, filename: str, alignments: PNGHeaderAlignmentsDict) -> None: """ Update the PNG header of the given file with the given alignments. NB: Header information can only be updated if the face is already on at least alignment @@ -201,7 +198,7 @@ class InfoLoader(): return self._cached_source_data[filename]["alignments_version"] = 2.3 if vers == 2.2 else vers - header = dict(alignments=alignments, source=self._cached_source_data[filename]) + header = {"alignments": alignments, "source": self._cached_source_data[filename]} update_existing_metadata(filename, header) @@ -221,8 +218,8 @@ class SortMethod(): Default: ``False`` """ def __init__(self, - arguments: "Namespace", - loader_type: Literal["face", "meta", "all"] = "meta", + arguments: Namespace, + loader_type: T.Literal["face", "meta", "all"] = "meta", is_group: bool = False) -> None: logger.debug("Initializing %s: loader_type: '%s' is_group: %s, arguments: %s", self.__class__.__name__, loader_type, is_group, arguments) @@ -231,22 +228,22 @@ class SortMethod(): self._method = arguments.group_method if self._is_group else arguments.sort_method self._num_bins: int = arguments.num_bins - self._bin_names: List[str] = [] + self._bin_names: list[str] = [] self._loader_type = loader_type self._iterator = self._get_file_iterator(arguments.input_dir) - self._result: List[Tuple[str, Union[float, np.ndarray]]] = [] - self._binned: List[List[str]] = [] + self._result: list[tuple[str, float | np.ndarray]] = [] + self._binned: list[list[str]] = [] logger.debug("Initialized %s", self.__class__.__name__) @property - def loader_type(self) -> Literal["face", "meta", "all"]: + def loader_type(self) -> T.Literal["face", "meta", "all"]: """ ["face", "meta", "all"]: The loader that this sorter uses """ return self._loader_type @property - def binned(self) -> List[List[str]]: + def binned(self) -> list[list[str]]: """ list: List of bins (list) containing the filenames belonging to the bin. The binning process is called when this property is first accessed""" if not self._binned: @@ -255,7 +252,7 @@ class SortMethod(): return self._binned @property - def sorted_filelist(self) -> List[str]: + def sorted_filelist(self) -> list[str]: """ list: List of sorted filenames for given sorter in a single list. 
The sort process is called when this property is first accessed """ if not self._result: @@ -267,7 +264,7 @@ class SortMethod(): return retval @property - def bin_names(self) -> List[str]: + def bin_names(self) -> list[str]: """ list: The name of each created bin, if they exist, otherwise an empty list """ return self._bin_names @@ -305,7 +302,7 @@ class SortMethod(): [r[0] if isinstance(r, (tuple, list)) else r for r in self._result]) @classmethod - def _get_unique_labels(cls, numbers: np.ndarray) -> List[str]: + def _get_unique_labels(cls, numbers: np.ndarray) -> list[str]: """ For a list of threshold values for displaying in the bin name, get the lowest number of decimal figures (down to int) required to have a unique set of folder names and return the formatted numbers. @@ -338,7 +335,7 @@ class SortMethod(): logger.debug("rounded values: %s, formatted labels: %s", rounded, retval) return retval - def _binning_linear_threshold(self, units: str = "", multiplier: int = 1) -> List[List[str]]: + def _binning_linear_threshold(self, units: str = "", multiplier: int = 1) -> list[list[str]]: """ Standard linear binning method for binning by threshold. The minimum and maximum result from :attr:`_result` are taken, A range is created between @@ -367,7 +364,7 @@ class SortMethod(): f"{labels[idx]}{units}_to_{labels[idx + 1]}{units}" for idx in range(self._num_bins)] - bins: List[List[str]] = [[] for _ in range(self._num_bins)] + bins: list[list[str]] = [[] for _ in range(self._num_bins)] for filename, result in self._result: bin_idx = next(bin_id for bin_id, thresh in enumerate(thresholds) if result <= thresh) - 1 @@ -375,7 +372,7 @@ class SortMethod(): return bins - def _binning(self) -> List[List[str]]: + def _binning(self) -> list[list[str]]: """ Called when :attr:`binning` is first accessed. Checks if sorting has been done, if not triggers it, then does binning @@ -404,8 +401,8 @@ class SortMethod(): def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Override for sort method's specificic logic. This method should be executed to get a single score from a single image and add the result to :attr:`_result` @@ -420,7 +417,7 @@ class SortMethod(): """ raise NotImplementedError() - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Group into bins by their sorted score. Override for method specific binning techniques. Binning takes the results from :attr:`_result` compiled during :func:`_sort_filelist` and @@ -434,7 +431,7 @@ class SortMethod(): raise NotImplementedError() @classmethod - def _mask_face(cls, image: np.ndarray, alignments: "PNGHeaderAlignmentsDict") -> np.ndarray: + def _mask_face(cls, image: np.ndarray, alignments: PNGHeaderAlignmentsDict) -> np.ndarray: """ Function for applying the mask to an aligned face if both the face image and alignment data are available. 
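Apart from the annotation changes, the hunks above leave `_binning_linear_threshold` intact: it splits the min-to-max score range into `num_bins` equal-width intervals and drops each scored file into the interval its score falls in. The following is a simplified, standalone sketch of that idea only (illustrative names, no unit suffixes or label formatting, and an explicit clamp for the minimum score):

from __future__ import annotations

import numpy as np


def linear_bin(results: list[tuple[str, float]], num_bins: int) -> list[list[str]]:
    """ Split scored filenames into `num_bins` equal-width bins between min and max score. """
    scores = np.array([score for _, score in results])
    thresholds = np.linspace(scores.min(), scores.max(), num_bins + 1)
    bins: list[list[str]] = [[] for _ in range(num_bins)]
    for filename, score in results:
        # The first threshold the score does not exceed marks the bin's upper edge
        bin_idx = next(idx for idx, thresh in enumerate(thresholds) if score <= thresh) - 1
        bins[max(bin_idx, 0)].append(filename)  # clamp the minimum score into bin 0
    return bins


print(linear_bin([("a.png", 0.1), ("b.png", 0.5), ("c.png", 0.9)], num_bins=2))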
@@ -481,7 +478,7 @@ class SortMultiMethod(SortMethod): A sort method object used for sorting and binning the images """ def __init__(self, - arguments: "Namespace", + arguments: Namespace, sort_method: SortMethod, group_method: SortMethod) -> None: self._sorter = sort_method @@ -515,8 +512,8 @@ class SortMultiMethod(SortMethod): def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Score a single image for sort method: "distance", "yaw" "pitch" or "size" and add the result to :attr:`_result` @@ -542,7 +539,7 @@ class SortMultiMethod(SortMethod): self._bin_names = self._grouper.bin_names logger.debug("Sorted") - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Override standard binning, to bin by the group-by method and sort by the sorting method. @@ -555,9 +552,9 @@ class SortMultiMethod(SortMethod): List of bins of filenames """ sorted_ = self._result - output: List[List[str]] = [] + output: list[list[str]] = [] for bin_ in tqdm(self._binned, desc="Binning and sorting", file=sys.stdout, leave=False): - indices: Dict[int, str] = {} + indices: dict[int, str] = {} for filename in bin_: indices[sorted_.index(filename)] = filename output.append([indices[idx] for idx in sorted(indices)]) @@ -575,7 +572,7 @@ class SortBlur(SortMethod): Set to ``True`` if this class is going to be called exclusively for binning. Default: ``False`` """ - def __init__(self, arguments: "Namespace", is_group: bool = False) -> None: + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: super().__init__(arguments, loader_type="all", is_group=is_group) method = arguments.group_method if self._is_group else arguments.sort_method self._use_fft = method == "blur_fft" @@ -608,7 +605,7 @@ class SortBlur(SortMethod): def estimate_blur_fft(self, image: np.ndarray, - alignments: Optional["PNGHeaderAlignmentsDict"] = None) -> float: + alignments: PNGHeaderAlignmentsDict | None = None) -> float: """ Estimate the amount of blur a fft filtered image has. Parameters @@ -649,8 +646,8 @@ class SortBlur(SortMethod): def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Score a single image for blur or blur-fft and add the result to :attr:`_result` Parameters @@ -677,7 +674,7 @@ class SortBlur(SortMethod): logger.info("Sorting...") self._result = sorted(self._result, key=operator.itemgetter(1), reverse=True) - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Create bins to split linearly from the lowest to the highest sample value Returns @@ -699,7 +696,7 @@ class SortColor(SortMethod): Set to ``True`` if this class is going to be called exclusively for binning. 
Default: ``False`` """ - def __init__(self, arguments: "Namespace", is_group: bool = False) -> None: + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: super().__init__(arguments, loader_type="face", is_group=is_group) self._desired_channel = {'gray': 0, 'luma': 0, 'orange': 1, 'green': 2} @@ -728,7 +725,7 @@ class SortColor(SortMethod): path = np.einsum_path(operation, image[..., :3], conversion, optimize='optimal')[0] return np.einsum(operation, image[..., :3], conversion, optimize=path).astype('float32') - def _near_split(self, bin_range: int) -> List[int]: + def _near_split(self, bin_range: int) -> list[int]: """ Obtain the split for the given number of bins for the given range Parameters @@ -750,14 +747,14 @@ class SortColor(SortMethod): uplimit += sep return bins - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Group into bins by percentage of black pixels """ # TODO. Only grouped by black pixels. Check color logger.info("Grouping by percentage of %s...", self._method) # Starting the binning process - bins: List[List[str]] = [[] for _ in range(self._num_bins)] + bins: list[list[str]] = [[] for _ in range(self._num_bins)] # Get edges of bins from 0 to 100 bins_edges = self._near_split(100) # Get the proper bin number for each img order @@ -772,8 +769,8 @@ class SortColor(SortMethod): def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Score a single image for color Parameters @@ -835,18 +832,18 @@ class SortFace(SortMethod): Set to ``True`` if this class is going to be called exclusively for binning. Default: ``False`` """ - def __init__(self, arguments: "Namespace", is_group: bool = False) -> None: + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: super().__init__(arguments, loader_type="all", is_group=is_group) self._vgg_face = VGGFace(exclude_gpus=arguments.exclude_gpus) self._vgg_face.init_model() threshold = arguments.threshold self._output_update_info = True - self._threshold: Optional[float] = 0.25 if threshold < 0 else threshold + self._threshold: float | None = 0.25 if threshold < 0 else threshold def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Processing logic for sort by face method. Reads header information from the PNG file to look for VGGFace2 embedding. If it does not @@ -912,7 +909,7 @@ class SortFace(SortMethod): indices = Cluster(np.array(preds), "ward", threshold=self._threshold)() self._result = [(self._result[idx][0], float(score)) for idx, score in indices] - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Group into bins by their sorted score The bin ID has been output in the 2nd column of :attr:`_result` so use that for binnin @@ -924,7 +921,7 @@ class SortFace(SortMethod): """ num_bins = len(set(int(i[1]) for i in self._result)) logger.info("Grouping by %s...", self.__class__.__name__.replace("Sort", "")) - bins: List[List[str]] = [[] for _ in range(num_bins)] + bins: list[list[str]] = [[] for _ in range(num_bins)] for filename, bin_id in self._result: bins[int(bin_id)].append(filename) @@ -943,7 +940,7 @@ class SortHistogram(SortMethod): Set to ``True`` if this class is going to be called exclusively for binning. 
Default: ``False`` """ - def __init__(self, arguments: "Namespace", is_group: bool = False) -> None: + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: super().__init__(arguments, loader_type="all", is_group=is_group) method = arguments.group_method if self._is_group else arguments.sort_method self._is_dissim = method == "hist-dissim" @@ -951,7 +948,7 @@ class SortHistogram(SortMethod): def _calc_histogram(self, image: np.ndarray, - alignments: Optional["PNGHeaderAlignmentsDict"]) -> np.ndarray: + alignments: PNGHeaderAlignmentsDict | None) -> np.ndarray: if alignments: image = self._mask_face(image, alignments) return cv2.calcHist([image], [0], None, [256], [0, 256]) @@ -994,7 +991,7 @@ class SortHistogram(SortMethod): self._result[i + 1]) @classmethod - def _get_avg_score(cls, image: np.ndarray, references: List[np.ndarray]) -> float: + def _get_avg_score(cls, image: np.ndarray, references: list[np.ndarray]) -> float: """ Return the average histogram score between a face and reference images Parameters @@ -1015,22 +1012,22 @@ class SortHistogram(SortMethod): scores.append(score) return sum(scores) / len(scores) - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Group into bins by histogram """ msg = "dissimilarity" if self._is_dissim else "similarity" logger.info("Grouping by %s...", msg) # Groups are of the form: group_num -> reference histogram - reference_groups: Dict[int, List[np.ndarray]] = {} + reference_groups: dict[int, list[np.ndarray]] = {} # Bins array, where index is the group number and value is # an array containing the file paths to the images in that group - bins: List[List[str]] = [] + bins: list[list[str]] = [] threshold = self._threshold img_list_len = len(self._result) - reference_groups[0] = [cast(np.ndarray, self._result[0][1])] + reference_groups[0] = [T.cast(np.ndarray, self._result[0][1])] bins.append([self._result[0][0]]) for i in tqdm(range(1, img_list_len), @@ -1045,7 +1042,7 @@ class SortHistogram(SortMethod): current_key, current_score = key, score if current_score < threshold: - reference_groups[cast(int, current_key)].append(self._result[i][1]) + reference_groups[T.cast(int, current_key)].append(self._result[i][1]) bins[current_key].append(self._result[i][0]) else: reference_groups[len(reference_groups)] = [self._result[i][1]] @@ -1055,8 +1052,8 @@ class SortHistogram(SortMethod): def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Collect the histogram for the given face Parameters diff --git a/tools/sort/sort_methods_aligned.py b/tools/sort/sort_methods_aligned.py index d597def8..ecf6dd4a 100644 --- a/tools/sort/sort_methods_aligned.py +++ b/tools/sort/sort_methods_aligned.py @@ -2,11 +2,11 @@ """ Sorting methods that use the properties of a :class:`lib.align.AlignedFace` object to obtain their sorting metrics. 
""" +from __future__ import annotations import logging import operator import sys - -from typing import Dict, List, Optional, TYPE_CHECKING, Union +import typing as T import numpy as np from tqdm import tqdm @@ -15,7 +15,7 @@ from lib.align import AlignedFace from lib.utils import FaceswapError from .sort_methods import SortMethod -if TYPE_CHECKING: +if T.TYPE_CHECKING: from argparse import Namespace from lib.align.alignments import PNGHeaderAlignmentsDict @@ -36,7 +36,7 @@ class SortAlignedMetric(SortMethod): # pylint:disable=too-few-public-methods Set to ``True`` if this class is going to be called exclusively for binning. Default: ``False`` """ - def _get_metric(self, aligned_face: AlignedFace) -> Union[np.ndarray, float]: + def _get_metric(self, aligned_face: AlignedFace) -> np.ndarray | float: """ Obtain the correct metric for the given sort method" Parameters @@ -58,8 +58,8 @@ class SortAlignedMetric(SortMethod): # pylint:disable=too-few-public-methods def score_image(self, filename: str, - image: Optional[np.ndarray], - alignments: Optional["PNGHeaderAlignmentsDict"]) -> None: + image: np.ndarray | None, + alignments: PNGHeaderAlignmentsDict | None) -> None: """ Score a single image for sort method: "distance", "yaw", "pitch" or "size" and add the result to :attr:`_result` @@ -110,7 +110,7 @@ class SortDistance(SortAlignedMetric): logger.info("Sorting...") self._result = sorted(self._result, key=operator.itemgetter(1), reverse=False) - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Create bins to split linearly from the lowest to the highest sample value Returns @@ -138,7 +138,7 @@ class SortPitch(SortAlignedMetric): """ return aligned_face.pose.pitch - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Create bins from 0 degrees to 180 degrees based on number of bins Allocate item to bin when it is in range of one of the pre-allocated bins @@ -148,7 +148,7 @@ class SortPitch(SortAlignedMetric): list List of bins of filenames """ - thresholds = (np.linspace(90, -90, self._num_bins + 1)) + thresholds = np.linspace(90, -90, self._num_bins + 1) # Start bin names from 0 for more intuitive experience names = np.flip(thresholds.astype("int")) + 90 @@ -157,7 +157,7 @@ class SortPitch(SortAlignedMetric): f"degs_to_{int(names[idx + 1])}degs" for idx in range(self._num_bins)] - bins: List[List[str]] = [[] for _ in range(self._num_bins)] + bins: list[list[str]] = [[] for _ in range(self._num_bins)] for filename, result in self._result: result = np.clip(result, -90.0, 90.0) bin_idx = next(bin_id for bin_id, thresh in enumerate(thresholds) @@ -223,7 +223,7 @@ class SortSize(SortAlignedMetric): size = ((roi[1][0] - roi[0][0]) ** 2 + (roi[1][1] - roi[0][1]) ** 2) ** 0.5 return size - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Create bins to split linearly from the lowest to the highest sample value Allocate item to bin when it is in range of one of the pre-allocated bins @@ -247,7 +247,7 @@ class SortFaceCNN(SortAlignedMetric): Set to ``True`` if this class is going to be called exclusively for binning. 
Default: ``False`` """ - def __init__(self, arguments: "Namespace", is_group: bool = False) -> None: + def __init__(self, arguments: Namespace, is_group: bool = False) -> None: super().__init__(arguments, is_group=is_group) self._is_dissim = self._method == "face-cnn-dissim" self._threshold: float = 7.2 if arguments.threshold < 1.0 else arguments.threshold @@ -308,7 +308,7 @@ class SortFaceCNN(SortAlignedMetric): logger.info("Sorting...") self._result = sorted(self._result, key=operator.itemgetter(2), reverse=True) - def binning(self) -> List[List[str]]: + def binning(self) -> list[list[str]]: """ Group into bins by CNN face similarity Returns @@ -320,11 +320,11 @@ class SortFaceCNN(SortAlignedMetric): logger.info("Grouping by face-cnn %s...", msg) # Groups are of the form: group_num -> reference faces - reference_groups: Dict[int, List[np.ndarray]] = {} + reference_groups: dict[int, list[np.ndarray]] = {} # Bins array, where index is the group number and value is # an array containing the file paths to the images in that group. - bins: List[List[str]] = [] + bins: list[list[str]] = [] # Comparison threshold used to decide how similar # faces have to be to be grouped together. @@ -362,7 +362,7 @@ class SortFaceCNN(SortAlignedMetric): return bins @classmethod - def _get_avg_score(cls, face: np.ndarray, references: List[np.ndarray]) -> float: + def _get_avg_score(cls, face: np.ndarray, references: list[np.ndarray]) -> float: """ Return the average CNN similarity score between a face and reference images Parameters