mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 19:05:02 -04:00
Bugfix: Correct loss labels when graphing
This commit is contained in:
parent
1d3c59c351
commit
9ddc838e68
8 changed files with 70 additions and 42 deletions
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import typing as T
|
||||
import zlib
|
||||
|
||||
|
@ -14,6 +15,7 @@ from tensorflow.core.util import event_pb2 # pylint:disable=no-name-in-module
|
|||
from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
|
||||
errors_impl as tf_errors)
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
from lib.serializer import get_serializer
|
||||
|
||||
if T.TYPE_CHECKING:
|
||||
|
@ -46,7 +48,7 @@ class _LogFiles():
|
|||
The folder that contains the Tensorboard log files
|
||||
"""
|
||||
def __init__(self, logs_folder: str) -> None:
|
||||
logger.debug("Initializing: %s: (logs_folder: '%s')", self.__class__.__name__, logs_folder)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._logs_folder = logs_folder
|
||||
self._filenames = self._get_log_filenames()
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
@ -215,7 +217,7 @@ class _CacheData():
|
|||
class _Cache():
|
||||
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing: %s", self.__class__.__name__)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._data: dict[int, _CacheData] = {}
|
||||
self._carry_over: dict[int, EventData] = {}
|
||||
self._loss_labels: list[str] = []
|
||||
|
@ -471,8 +473,7 @@ class TensorBoardLogs():
|
|||
``True`` if the events are being read whilst Faceswap is training otherwise ``False``
|
||||
"""
|
||||
def __init__(self, logs_folder: str, is_training: bool) -> None:
|
||||
logger.debug("Initializing: %s: (logs_folder: %s, is_training: %s)",
|
||||
self.__class__.__name__, logs_folder, is_training)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._is_training = False
|
||||
self._training_iterator = None
|
||||
|
||||
|
@ -631,12 +632,12 @@ class _EventParser(): # pylint:disable=too-few-public-methods
|
|||
otherwise ``False``
|
||||
"""
|
||||
def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None:
|
||||
logger.debug("Initializing: %s: (iterator: %s, cache: %s, live_data: %s)",
|
||||
self.__class__.__name__, iterator, cache, live_data)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._live_data = live_data
|
||||
self._cache = cache
|
||||
self._iterator = self._get_latest_live(iterator) if live_data else iterator
|
||||
self._loss_labels: list[str] = []
|
||||
self._num_strip = re.compile(r"_\d+$")
|
||||
logger.debug("Initialized: %s", self.__class__.__name__)
|
||||
|
||||
@classmethod
|
||||
|
@ -728,7 +729,7 @@ class _EventParser(): # pylint:disable=too-few-public-methods
|
|||
if layer["name"] == layer_name)["config"]
|
||||
layer_outputs = self._get_outputs(output_config)
|
||||
for output in layer_outputs: # Drill into sub-model to get the actual output names
|
||||
loss_name = output[0][0]
|
||||
loss_name = self._num_strip.sub("", output[0][0]) # strip trailing numbers
|
||||
if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output
|
||||
new_name = f"{loss_name.replace('_both', '')}_{side}"
|
||||
logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
|
||||
|
|
|
@ -17,6 +17,7 @@ from threading import Event
|
|||
|
||||
import numpy as np
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
from lib.serializer import get_serializer
|
||||
|
||||
from .event_reader import TensorBoardLogs
|
||||
|
@ -30,7 +31,7 @@ class GlobalSession():
|
|||
:attr:`lib.gui.analysis.Session`
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._state: dict[str, T.Any] = {}
|
||||
self._model_dir = ""
|
||||
self._model_name = ""
|
||||
|
@ -289,7 +290,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
|
|||
The loaded or currently training session
|
||||
"""
|
||||
def __init__(self, session: GlobalSession) -> None:
|
||||
logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._session = session
|
||||
self._state = session._state
|
||||
|
||||
|
@ -539,11 +540,7 @@ class Calculations():
|
|||
avg_samples: int = 500,
|
||||
smooth_amount: float = 0.90,
|
||||
flatten_outliers: bool = False) -> None:
|
||||
logger.debug("Initializing %s: (session_id: %s, display: %s, loss_keys: %s, "
|
||||
"selections: %s, avg_samples: %s, smooth_amount: %s, flatten_outliers: %s)",
|
||||
self.__class__.__name__, session_id, display, loss_keys, selections,
|
||||
avg_samples, smooth_amount, flatten_outliers)
|
||||
|
||||
logger.debug(parse_class_init(locals()))
|
||||
warnings.simplefilter("ignore", np.RankWarning)
|
||||
|
||||
self._session_id = session_id
|
||||
|
@ -872,6 +869,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
|
|||
Adapted from: https://stackoverflow.com/questions/42869495
|
||||
"""
|
||||
def __init__(self, data: np.ndarray, amount: float) -> None:
|
||||
logger.debug(parse_class_init(locals()))
|
||||
assert data.ndim == 1
|
||||
amount = min(max(amount, 0.001), 0.999)
|
||||
|
||||
|
@ -880,6 +878,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
|
|||
self._dtype = "float32" if data.dtype == np.float32 else "float64"
|
||||
self._row_size = self._get_max_row_size()
|
||||
self._out = np.empty_like(data, dtype=self._dtype)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def __call__(self) -> np.ndarray:
|
||||
""" Perform the exponential moving average calculation.
|
||||
|
|
|
@ -10,6 +10,8 @@ import gettext
|
|||
import tkinter as tk
|
||||
from tkinter import ttk
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
|
||||
from .display_analysis import Analysis
|
||||
from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
|
||||
from .utils import get_config
|
||||
|
@ -31,7 +33,7 @@ class DisplayNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors
|
|||
"""
|
||||
|
||||
def __init__(self, parent):
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
super().__init__(parent)
|
||||
parent.add(self)
|
||||
tk_vars = get_config().tk_vars
|
||||
|
|
|
@ -8,6 +8,8 @@ import os
|
|||
import tkinter as tk
|
||||
from tkinter import ttk
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
|
||||
from .custom_widgets import Tooltip
|
||||
from .display_page import DisplayPage
|
||||
from .popup_session import SessionPopUp
|
||||
|
@ -36,8 +38,7 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
|
|||
The help text to display for the summary statistics page
|
||||
"""
|
||||
def __init__(self, parent, tab_name, helptext):
|
||||
logger.debug("Initializing: %s: (parent, %s, tab_name: '%s', helptext: '%s')",
|
||||
self.__class__.__name__, parent, tab_name, helptext)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
super().__init__(parent, tab_name, helptext)
|
||||
self._summary = None
|
||||
|
||||
|
@ -62,10 +63,10 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
|
|||
dict
|
||||
The dictionary of variable names to tkinter variables
|
||||
"""
|
||||
return dict(selected_id=tk.StringVar(),
|
||||
refresh_graph=get_config().tk_vars.refresh_graph,
|
||||
is_training=get_config().tk_vars.is_training,
|
||||
analysis_folder=get_config().tk_vars.analysis_folder)
|
||||
return {"selected_id": tk.StringVar(),
|
||||
"refresh_graph": get_config().tk_vars.refresh_graph,
|
||||
"is_training": get_config().tk_vars.is_training,
|
||||
"analysis_folder": get_config().tk_vars.analysis_folder}
|
||||
|
||||
def on_tab_select(self):
|
||||
""" Callback for when the analysis tab is selected.
|
||||
|
@ -299,7 +300,7 @@ class _Options(): # pylint:disable=too-few-public-methods
|
|||
The Analysis Display Tab that holds the options buttons
|
||||
"""
|
||||
def __init__(self, parent):
|
||||
logger.debug("Initializing: %s (parent: %s)", self.__class__.__name__, parent)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._parent = parent
|
||||
self._buttons = self._add_buttons()
|
||||
self._add_training_callback()
|
||||
|
@ -380,8 +381,7 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
The help text to display for the summary statistics page
|
||||
"""
|
||||
def __init__(self, parent, selected_id, helptext):
|
||||
logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')",
|
||||
self.__class__.__name__, parent, selected_id, helptext)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
super().__init__(parent)
|
||||
self._selected_id = selected_id
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import typing as T
|
|||
|
||||
from tkinter import ttk
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
from lib.training.preview_tk import PreviewTk
|
||||
|
||||
from .display_graph import TrainingGraph
|
||||
|
@ -28,8 +29,7 @@ _ = _LANG.gettext
|
|||
class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
||||
""" Tab to display output preview images for extract and convert """
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
logger.debug("Initializing %s (args: %s, kwargs: %s)",
|
||||
self.__class__.__name__, args, kwargs)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._preview = get_images().preview_extract
|
||||
super().__init__(*args, **kwargs)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
@ -83,8 +83,7 @@ class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
||||
""" Training preview image(s) """
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
logger.debug("Initializing %s (args: %s, kwargs: %s)",
|
||||
self.__class__.__name__, args, kwargs)
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._preview = get_images().preview_train
|
||||
self._display: PreviewTk | None = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -172,9 +171,11 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
helptext: str,
|
||||
wait_time: int,
|
||||
command: str | None = None) -> None:
|
||||
logger.debug(parse_class_init(locals()))
|
||||
self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
|
||||
tuple[tk.BooleanVar, str]] = {}
|
||||
super().__init__(parent, tab_name, helptext, wait_time, command)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def set_vars(self) -> None:
|
||||
""" Add graphing specific variables to the default variables.
|
||||
|
@ -212,7 +213,8 @@ class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors
|
|||
|
||||
Pull latest data and run the tab's update code when the tab is selected.
|
||||
"""
|
||||
logger.debug("Callback received for '%s' tab", self.tabname)
|
||||
logger.debug("Callback received for '%s' tab (display_item: %s)",
|
||||
self.tabname, self.display_item)
|
||||
if self.display_item is not None:
|
||||
get_config().tk_vars.refresh_graph.set(True)
|
||||
self._update_page()
|
||||
|
|
|
@ -18,6 +18,8 @@ from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg,
|
|||
NavigationToolbar2Tk)
|
||||
from matplotlib.backend_bases import NavigationToolbar2
|
||||
|
||||
from lib.logger import parse_class_init
|
||||
|
||||
from .custom_widgets import Tooltip
|
||||
from .utils import get_config, get_images, LongRunningTask
|
||||
|
||||
|
@ -40,7 +42,6 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
The data label for the y-axis
|
||||
"""
|
||||
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
|
||||
logger.debug("Initializing %s", self.__class__.__name__)
|
||||
super().__init__(parent)
|
||||
matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI
|
||||
style.use("ggplot")
|
||||
|
@ -58,7 +59,6 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
|
||||
self._initiate_graph()
|
||||
self._update_plot(initiate=True)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@property
|
||||
def calcs(self):
|
||||
|
@ -335,10 +335,12 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
"""
|
||||
|
||||
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
|
||||
logger.debug(parse_class_init(locals()))
|
||||
super().__init__(parent, data, ylabel)
|
||||
self._thread: LongRunningTask | None = None # Thread for LongRunningTask
|
||||
self._displayed_keys: list[str] = []
|
||||
self._add_callback()
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def _add_callback(self) -> None:
|
||||
""" Add the variable trace to update graph on refresh button press or save iteration. """
|
||||
|
@ -427,8 +429,10 @@ class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors
|
|||
Should be one of ``"log"`` or ``"linear"``
|
||||
"""
|
||||
def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None:
|
||||
logger.debug(parse_class_init(locals()))
|
||||
super().__init__(parent, data, ylabel)
|
||||
self._scale = scale
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
def build(self) -> None:
|
||||
""" Build the session graph """
|
||||
|
@ -494,7 +498,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
window: ttk.Frame,
|
||||
*,
|
||||
pack_toolbar: bool = True) -> None:
|
||||
|
||||
logger.debug(parse_class_init(locals()))
|
||||
# Avoid using self.window (prefer self.canvas.get_tk_widget().master),
|
||||
# so that Tool implementations can reuse the methods.
|
||||
|
||||
|
@ -528,6 +532,7 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances
|
|||
NavigationToolbar2.__init__(self, canvas) # pylint:disable=non-parent-init-called
|
||||
if pack_toolbar:
|
||||
self.pack(side=tk.BOTTOM, fill=tk.X)
|
||||
logger.debug("Initialized %s", self.__class__.__name__)
|
||||
|
||||
@staticmethod
|
||||
def _Button(frame: ttk.Frame, # pylint:disable=arguments-differ,arguments-renamed
|
||||
|
|
|
@ -20,9 +20,7 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
""" Parent frame holder for each tab.
|
||||
Defines uniform structure for each tab to inherit from """
|
||||
def __init__(self, parent, tab_name, helptext):
|
||||
logger.debug("Initializing %s: (tab_name: '%s', helptext: %s)",
|
||||
self.__class__.__name__, tab_name, helptext)
|
||||
ttk.Frame.__init__(self, parent)
|
||||
super().__init__(parent)
|
||||
|
||||
self._parent = parent
|
||||
self.running_task = parent.running_task
|
||||
|
@ -42,8 +40,6 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
|
|||
self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW)
|
||||
parent.add(self, text=self.tabname.title())
|
||||
|
||||
logger.debug("Initialized %s", self.__class__.__name__,)
|
||||
|
||||
@property
|
||||
def _tab_is_active(self):
|
||||
""" bool: ``True`` if the tab currently has focus otherwise ``False`` """
|
||||
|
@ -167,9 +163,7 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
|
|||
""" Parent Context Sensitive Display Tab """
|
||||
|
||||
def __init__(self, parent, tab_name, helptext, wait_time, command=None):
|
||||
logger.debug("%s: OptionalPage args: (wait_time: %s, command: %s)",
|
||||
self.__class__.__name__, wait_time, command)
|
||||
DisplayPage.__init__(self, parent, tab_name, helptext)
|
||||
super().__init__(parent, tab_name, helptext)
|
||||
|
||||
self._waittime = wait_time
|
||||
self.command = command
|
||||
|
|
|
@ -13,6 +13,8 @@ import traceback
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
|
||||
def _patched_format(self, record):
|
||||
|
@ -544,6 +546,28 @@ def crash_log() -> str:
|
|||
return filename
|
||||
|
||||
|
||||
def _process_value(value: T.Any) -> T.Any:
|
||||
""" Process the values from a local dict and return in a loggable format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value: Any
|
||||
The dictionary value
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
The original or ammended value
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return f'"{value}"'
|
||||
if isinstance(value, np.ndarray) and np.prod(value.shape) > 10:
|
||||
return f'[type: "{type(value).__name__}" shape: {value.shape}, dtype: "{value.dtype}"]'
|
||||
if isinstance(value, (list, tuple, set)) and len(value) > 10:
|
||||
return f'[type: "{type(value).__name__}" len: {len(value)}'
|
||||
return value
|
||||
|
||||
|
||||
def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
|
||||
""" Parse a locals dict from a class and return in a format suitable for logging
|
||||
Parameters
|
||||
|
@ -555,10 +579,11 @@ def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
|
|||
str
|
||||
The locals information suitable for logging
|
||||
"""
|
||||
delimit = {k: f"'{v}'" if isinstance(v, str) else v
|
||||
delimit = {k: _process_value(v)
|
||||
for k, v in locals_dict.items() if k != "self"}
|
||||
dsp = ", ".join(f"{k}: {v}" for k, v in delimit.items())
|
||||
return f"Initializing {locals_dict['self'].__class__.__name__} ({dsp})"
|
||||
dsp = f" ({dsp})" if dsp else ""
|
||||
return f"Initializing {locals_dict['self'].__class__.__name__}{dsp}"
|
||||
|
||||
|
||||
_OLD_FACTORY = logging.getLogRecordFactory()
|
||||
|
|
Loading…
Add table
Reference in a new issue