1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 19:05:02 -04:00

Bugfix: Correct loss labels when graphing

This commit is contained in:
torzdf 2024-03-20 17:08:39 +00:00
parent 1d3c59c351
commit 9ddc838e68
8 changed files with 70 additions and 42 deletions

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os import os
import re
import typing as T import typing as T
import zlib import zlib
@ -14,6 +15,7 @@ from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module
from tensorflow.python.framework import ( # pylint:disable=no-name-in-module from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
errors_impl as tf_errors) errors_impl as tf_errors)
from lib.logger import parse_class_init
from lib.serializer import get_serializer from lib.serializer import get_serializer
if T.TYPE_CHECKING: if T.TYPE_CHECKING:
@ -46,7 +48,7 @@ class _LogFiles():
The folder that contains the Tensorboard log files The folder that contains the Tensorboard log files
""" """
def __init__(self, logs_folder: str) -> None: def __init__(self, logs_folder: str) -> None:
logger.debug("Initializing: %s: (logs_folder: '%s')", self.__class__.__name__, logs_folder) logger.debug(parse_class_init(locals()))
self._logs_folder = logs_folder self._logs_folder = logs_folder
self._filenames = self._get_log_filenames() self._filenames = self._get_log_filenames()
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@ -215,7 +217,7 @@ class _CacheData():
class _Cache(): class _Cache():
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """ """ Holds parsed Tensorflow log event data in a compressed cache in memory. """
def __init__(self) -> None: def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__) logger.debug(parse_class_init(locals()))
self._data: dict[int, _CacheData] = {} self._data: dict[int, _CacheData] = {}
self._carry_over: dict[int, EventData] = {} self._carry_over: dict[int, EventData] = {}
self._loss_labels: list[str] = [] self._loss_labels: list[str] = []
@ -471,8 +473,7 @@ class TensorBoardLogs():
``True`` if the events are being read whilst Faceswap is training otherwise ``False`` ``True`` if the events are being read whilst Faceswap is training otherwise ``False``
""" """
def __init__(self, logs_folder: str, is_training: bool) -> None: def __init__(self, logs_folder: str, is_training: bool) -> None:
logger.debug("Initializing: %s: (logs_folder: %s, is_training: %s)", logger.debug(parse_class_init(locals()))
self.__class__.__name__, logs_folder, is_training)
self._is_training = False self._is_training = False
self._training_iterator = None self._training_iterator = None
@ -631,12 +632,12 @@ class _EventParser(): # pylint:disable=too-few-public-methods
otherwise ``False`` otherwise ``False``
""" """
def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None: def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None:
logger.debug("Initializing: %s: (iterator: %s, cache: %s, live_data: %s)", logger.debug(parse_class_init(locals()))
self.__class__.__name__, iterator, cache, live_data)
self._live_data = live_data self._live_data = live_data
self._cache = cache self._cache = cache
self._iterator = self._get_latest_live(iterator) if live_data else iterator self._iterator = self._get_latest_live(iterator) if live_data else iterator
self._loss_labels: list[str] = [] self._loss_labels: list[str] = []
self._num_strip = re.compile(r"_\d+$")
logger.debug("Initialized: %s", self.__class__.__name__) logger.debug("Initialized: %s", self.__class__.__name__)
@classmethod @classmethod
@ -728,7 +729,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
if layer["name"] == layer_name)["config"] if layer["name"] == layer_name)["config"]
layer_outputs = self._get_outputs(output_config) layer_outputs = self._get_outputs(output_config)
for output in layer_outputs: # Drill into sub-model to get the actual output names for output in layer_outputs: # Drill into sub-model to get the actual output names
loss_name = output[0][0] loss_name = self._num_strip.sub("", output[0][0]) # strip trailing numbers
if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output
new_name = f"{loss_name.replace('_both', '')}_{side}" new_name = f"{loss_name.replace('_both', '')}_{side}"
logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name) logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)

View file

@ -17,6 +17,7 @@ from threading import Event
import numpy as np import numpy as np
from lib.logger import parse_class_init
from lib.serializer import get_serializer from lib.serializer import get_serializer
from .event_reader import TensorBoardLogs from .event_reader import TensorBoardLogs
@ -30,7 +31,7 @@ class GlobalSession():
:attr:`lib.gui.analysis.Session` :attr:`lib.gui.analysis.Session`
""" """
def __init__(self) -> None: def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__) logger.debug(parse_class_init(locals()))
self._state: dict[str, T.Any] = {} self._state: dict[str, T.Any] = {}
self._model_dir = "" self._model_dir = ""
self._model_name = "" self._model_name = ""
@ -289,7 +290,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
The loaded or currently training session The loaded or currently training session
""" """
def __init__(self, session: GlobalSession) -> None: def __init__(self, session: GlobalSession) -> None:
logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session) logger.debug(parse_class_init(locals()))
self._session = session self._session = session
self._state = session._state self._state = session._state
@ -539,11 +540,7 @@ class Calculations():
avg_samples: int = 500, avg_samples: int = 500,
smooth_amount: float = 0.90, smooth_amount: float = 0.90,
flatten_outliers: bool = False) -> None: flatten_outliers: bool = False) -> None:
logger.debug("Initializing %s: (session_id: %s, display: %s, loss_keys: %s, " logger.debug(parse_class_init(locals()))
"selections: %s, avg_samples: %s, smooth_amount: %s, flatten_outliers: %s)",
self.__class__.__name__, session_id, display, loss_keys, selections,
avg_samples, smooth_amount, flatten_outliers)
warnings.simplefilter("ignore", np.RankWarning) warnings.simplefilter("ignore", np.RankWarning)
self._session_id = session_id self._session_id = session_id
@ -872,6 +869,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
Adapted from: https://stackoverflow.com/questions/42869495 Adapted from: https://stackoverflow.com/questions/42869495
""" """
def __init__(self, data: np.ndarray, amount: float) -> None: def __init__(self, data: np.ndarray, amount: float) -> None:
logger.debug(parse_class_init(locals()))
assert data.ndim == 1 assert data.ndim == 1
amount = min(max(amount, 0.001), 0.999) amount = min(max(amount, 0.001), 0.999)
@ -880,6 +878,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
self._dtype = "float32" if data.dtype == np.float32 else "float64" self._dtype = "float32" if data.dtype == np.float32 else "float64"
self._row_size = self._get_max_row_size() self._row_size = self._get_max_row_size()
self._out = np.empty_like(data, dtype=self._dtype) self._out = np.empty_like(data, dtype=self._dtype)
logger.debug("Initialized %s", self.__class__.__name__)
def __call__(self) -> np.ndarray: def __call__(self) -> np.ndarray:
""" Perform the exponential moving average calculation. """ Perform the exponential moving average calculation.

View file

@ -10,6 +10,8 @@ import gettext
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from lib.logger import parse_class_init
from .display_analysis import Analysis from .display_analysis import Analysis
from .display_command import GraphDisplay, PreviewExtract, PreviewTrain from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
from .utils import get_config from .utils import get_config
@ -31,7 +33,7 @@ class DisplayNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors
""" """
def __init__(self, parent): def __init__(self, parent):
logger.debug("Initializing %s", self.__class__.__name__) logger.debug(parse_class_init(locals()))
super().__init__(parent) super().__init__(parent)
parent.add(self) parent.add(self)
tk_vars = get_config().tk_vars tk_vars = get_config().tk_vars

View file

@ -8,6 +8,8 @@ import os
import tkinter as tk import tkinter as tk
from tkinter import ttk from tkinter import ttk
from lib.logger import parse_class_init
from .custom_widgets import Tooltip from .custom_widgets import Tooltip
from .display_page import DisplayPage from .display_page import DisplayPage
from .popup_session import SessionPopUp from .popup_session import SessionPopUp
@ -36,8 +38,7 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
The help text to display for the summary statistics page The help text to display for the summary statistics page
""" """
def __init__(self, parent, tab_name, helptext): def __init__(self, parent, tab_name, helptext):
logger.debug("Initializing: %s: (parent, %s, tab_name: '%s', helptext: '%s')", logger.debug(parse_class_init(locals()))
self.__class__.__name__, parent, tab_name, helptext)
super().__init__(parent, tab_name, helptext) super().__init__(parent, tab_name, helptext)
self._summary = None self._summary = None
@ -62,10 +63,10 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
dict dict
The dictionary of variable names to tkinter variables The dictionary of variable names to tkinter variables
""" """
return dict(selected_id=tk.StringVar(), return {"selected_id": tk.StringVar(),
refresh_graph=get_config().tk_vars.refresh_graph, "refresh_graph": get_config().tk_vars.refresh_graph,
is_training=get_config().tk_vars.is_training, "is_training": get_config().tk_vars.is_training,
analysis_folder=get_config().tk_vars.analysis_folder) "analysis_folder": get_config().tk_vars.analysis_folder}
def on_tab_select(self): def on_tab_select(self):
""" Callback for when the analysis tab is selected. """ Callback for when the analysis tab is selected.
@ -299,7 +300,7 @@ class _Options(): # pylint:disable=too-few-public-methods
The Analysis Display Tab that holds the options buttons The Analysis Display Tab that holds the options buttons
""" """
def __init__(self, parent): def __init__(self, parent):
logger.debug("Initializing: %s (parent: %s)", self.__class__.__name__, parent) logger.debug(parse_class_init(locals()))
self._parent = parent self._parent = parent
self._buttons = self._add_buttons() self._buttons = self._add_buttons()
self._add_training_callback() self._add_training_callback()
@ -380,8 +381,7 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
The help text to display for the summary statistics page The help text to display for the summary statistics page
""" """
def __init__(self, parent, selected_id, helptext): def __init__(self, parent, selected_id, helptext):
logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')", logger.debug(parse_class_init(locals()))
self.__class__.__name__, parent, selected_id, helptext)
super().__init__(parent) super().__init__(parent)
self._selected_id = selected_id self._selected_id = selected_id

View file

@ -9,6 +9,7 @@ import typing as T
from tkinter import ttk from tkinter import ttk
from lib.logger import parse_class_init
from lib.training.preview_tk import PreviewTk from lib.training.preview_tk import PreviewTk
from .display_graph import TrainingGraph from .display_graph import TrainingGraph
@ -28,8 +29,7 @@ _ = _LANG.gettext
class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Tab to display output preview images for extract and convert """ """ Tab to display output preview images for extract and convert """
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
logger.debug("Initializing %s (args: %s, kwargs: %s)", logger.debug(parse_class_init(locals()))
self.__class__.__name__, args, kwargs)
self._preview = get_images().preview_extract self._preview = get_images().preview_extract
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__) logger.debug("Initialized %s", self.__class__.__name__)
@ -83,8 +83,7 @@ class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Training preview image(s) """ """ Training preview image(s) """
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
logger.debug("Initializing %s (args: %s, kwargs: %s)", logger.debug(parse_class_init(locals()))
self.__class__.__name__, args, kwargs)
self._preview = get_images().preview_train self._preview = get_images().preview_train
self._display: PreviewTk | None = None self._display: PreviewTk | None = None
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -172,9 +171,11 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
helptext: str, helptext: str,
wait_time: int, wait_time: int,
command: str | None = None) -> None: command: str | None = None) -> None:
logger.debug(parse_class_init(locals()))
self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"], self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
tuple[tk.BooleanVar, str]] = {} tuple[tk.BooleanVar, str]] = {}
super().__init__(parent, tab_name, helptext, wait_time, command) super().__init__(parent, tab_name, helptext, wait_time, command)
logger.debug("Initialized %s", self.__class__.__name__)
def set_vars(self) -> None: def set_vars(self) -> None:
""" Add graphing specific variables to the default variables. """ Add graphing specific variables to the default variables.
@ -212,7 +213,8 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
Pull latest data and run the tab's update code when the tab is selected. Pull latest data and run the tab's update code when the tab is selected.
""" """
logger.debug("Callback received for '%s' tab", self.tabname) logger.debug("Callback received for '%s' tab (display_item: %s)",
self.tabname, self.display_item)
if self.display_item is not None: if self.display_item is not None:
get_config().tk_vars.refresh_graph.set(True) get_config().tk_vars.refresh_graph.set(True)
self._update_page() self._update_page()

View file

@ -18,6 +18,8 @@ from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
NavigationToolbar2Tk) NavigationToolbar2Tk)
from matplotlib.backend_bases import NavigationToolbar2 from matplotlib.backend_bases import NavigationToolbar2
from lib.logger import parse_class_init
from .custom_widgets import Tooltip from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask from .utils import get_config, get_images, LongRunningTask
@ -40,7 +42,6 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
The data label for the y-axis The data label for the y-axis
""" """
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None: def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(parent) super().__init__(parent)
matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI
style.use("ggplot") style.use("ggplot")
@ -58,7 +59,6 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
self._initiate_graph() self._initiate_graph()
self._update_plot(initiate=True) self._update_plot(initiate=True)
logger.debug("Initialized %s", self.__class__.__name__)
@property @property
def calcs(self): def calcs(self):
@ -335,10 +335,12 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
""" """
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None: def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
logger.debug(parse_class_init(locals()))
super().__init__(parent, data, ylabel) super().__init__(parent, data, ylabel)
self._thread: LongRunningTask | None = None # Thread for LongRunningTask self._thread: LongRunningTask | None = None # Thread for LongRunningTask
self._displayed_keys: list[str] = [] self._displayed_keys: list[str] = []
self._add_callback() self._add_callback()
logger.debug("Initialized %s", self.__class__.__name__)
def _add_callback(self) -> None: def _add_callback(self) -> None:
""" Add the variable trace to update graph on refresh button press or save iteration. """ """ Add the variable trace to update graph on refresh button press or save iteration. """
@ -427,8 +429,10 @@ class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors
Should be one of ``"log"`` or ``"linear"`` Should be one of ``"log"`` or ``"linear"``
""" """
def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None: def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None:
logger.debug(parse_class_init(locals()))
super().__init__(parent, data, ylabel) super().__init__(parent, data, ylabel)
self._scale = scale self._scale = scale
logger.debug("Initialized %s", self.__class__.__name__)
def build(self) -> None: def build(self) -> None:
""" Build the session graph """ """ Build the session graph """
@ -494,7 +498,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
window: ttk.Frame, window: ttk.Frame,
*, *,
pack_toolbar: bool = True) -> None: pack_toolbar: bool = True) -> None:
logger.debug(parse_class_init(locals()))
# Avoid using self.window (prefer self.canvas.get_tk_widget().master), # Avoid using self.window (prefer self.canvas.get_tk_widget().master),
# so that Tool implementations can reuse the methods. # so that Tool implementations can reuse the methods.
@ -528,6 +532,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
NavigationToolbar2.__init__(self, canvas) # pylint:disable=non-parent-init-called NavigationToolbar2.__init__(self, canvas) # pylint:disable=non-parent-init-called
if pack_toolbar: if pack_toolbar:
self.pack(side=tk.BOTTOM, fill=tk.X) self.pack(side=tk.BOTTOM, fill=tk.X)
logger.debug("Initialized %s", self.__class__.__name__)
@staticmethod @staticmethod
def _Button(frame: ttk.Frame, # pylint:disable=arguments-differ,arguments-renamed def _Button(frame: ttk.Frame, # pylint:disable=arguments-differ,arguments-renamed

View file

@ -20,9 +20,7 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
""" Parent frame holder for each tab. """ Parent frame holder for each tab.
Defines uniform structure for each tab to inherit from """ Defines uniform structure for each tab to inherit from """
def __init__(self, parent, tab_name, helptext): def __init__(self, parent, tab_name, helptext):
logger.debug("Initializing %s: (tab_name: '%s', helptext: %s)", super().__init__(parent)
self.__class__.__name__, tab_name, helptext)
ttk.Frame.__init__(self, parent)
self._parent = parent self._parent = parent
self.running_task = parent.running_task self.running_task = parent.running_task
@ -42,8 +40,6 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW) self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW)
parent.add(self, text=self.tabname.title()) parent.add(self, text=self.tabname.title())
logger.debug("Initialized %s", self.__class__.__name__,)
@property @property
def _tab_is_active(self): def _tab_is_active(self):
""" bool: ``True`` if the tab currently has focus otherwise ``False`` """ """ bool: ``True`` if the tab currently has focus otherwise ``False`` """
@ -167,9 +163,7 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
""" Parent Context Sensitive Display Tab """ """ Parent Context Sensitive Display Tab """
def __init__(self, parent, tab_name, helptext, wait_time, command=None): def __init__(self, parent, tab_name, helptext, wait_time, command=None):
logger.debug("%s: OptionalPage args: (wait_time: %s, command: %s)", super().__init__(parent, tab_name, helptext)
self.__class__.__name__, wait_time, command)
DisplayPage.__init__(self, parent, tab_name, helptext)
self._waittime = wait_time self._waittime = wait_time
self.command = command self.command = command

View file

@ -13,6 +13,8 @@ import traceback
from datetime import datetime from datetime import datetime
import numpy as np
# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib # TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
def _patched_format(self, record): def _patched_format(self, record):
@ -544,6 +546,28 @@ def crash_log() -> str:
return filename return filename
def _process_value(value: T.Any) -> T.Any:
""" Process the values from a local dict and return in a loggable format
Parameters
----------
value: Any
The dictionary value
Returns
-------
Any
The original or ammended value
"""
if isinstance(value, str):
return f'"{value}"'
if isinstance(value, np.ndarray) and np.prod(value.shape) > 10:
return f'[type: "{type(value).__name__}" shape: {value.shape}, dtype: "{value.dtype}"]'
if isinstance(value, (list, tuple, set)) and len(value) > 10:
return f'[type: "{type(value).__name__}" len: {len(value)}'
return value
def parse_class_init(locals_dict: dict[str, T.Any]) -> str: def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
""" Parse a locals dict from a class and return in a format suitable for logging """ Parse a locals dict from a class and return in a format suitable for logging
Parameters Parameters
@ -555,10 +579,11 @@ def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
str str
The locals information suitable for logging The locals information suitable for logging
""" """
delimit = {k: f"'{v}'" if isinstance(v, str) else v delimit = {k: _process_value(v)
for k, v in locals_dict.items() if k != "self"} for k, v in locals_dict.items() if k != "self"}
dsp = ", ".join(f"{k}: {v}" for k, v in delimit.items()) dsp = ", ".join(f"{k}: {v}" for k, v in delimit.items())
return f"Initializing {locals_dict['self'].__class__.__name__} ({dsp})" dsp = f" ({dsp})" if dsp else ""
return f"Initializing {locals_dict['self'].__class__.__name__}{dsp}"
_OLD_FACTORY = logging.getLogRecordFactory() _OLD_FACTORY = logging.getLogRecordFactory()