#!/usr/bin/env python3
""" Handles the loading and collation of events from Tensorflow event log files. """
from __future__ import annotations
import logging
import os
import typing as T
import zlib

from dataclasses import dataclass, field

import numpy as np
import tensorflow as tf
from tensorflow.core.util import event_pb2  # pylint:disable=no-name-in-module
from tensorflow.python.framework import (  # pylint:disable=no-name-in-module
    errors_impl as tf_errors)

from lib.serializer import get_serializer

if T.TYPE_CHECKING:
    from collections.abc import Generator, Iterator

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
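
# :class:`TensorBoardLogs` below is the public entry point: it locates the event files via
# :class:`_LogFiles`, parses them with :class:`_EventParser` and keeps the results compressed
# in memory via :class:`_Cache` and :class:`_CacheData`.
#
# Minimal usage sketch (the paths and loss label names are illustrative only):
#
#   tb_logs = TensorBoardLogs("/path/to/model_logs", is_training=False)
#   loss = tb_logs.get_loss(session_id=1)         # {1: {"<loss_label>": np.ndarray, ...}}
#   times = tb_logs.get_timestamps(session_id=1)  # {1: float64 np.ndarray of timestamps}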


@dataclass
class EventData:
    """ Holds data collected from Tensorflow Event Files

    Parameters
    ----------
    timestamp: float
        The timestamp of the event step (iteration)
    loss: list[float]
        The loss values collected for A and B sides for the event step
    """
    timestamp: float = 0.0
    loss: list[float] = field(default_factory=list)
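
    # Example of one fully collected step (values are illustrative): the wall time comes from
    # the step's "batch_total" summary and one loss value is appended per side output, e.g.
    # EventData(timestamp=1696000000.123, loss=[0.0234, 0.0256])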


class _LogFiles():
    """ Holds the filenames of the Tensorflow Event logs that require parsing.

    Parameters
    ----------
    logs_folder: str
        The folder that contains the Tensorboard log files
    """
    def __init__(self, logs_folder: str) -> None:
        logger.debug("Initializing: %s: (logs_folder: '%s')",
                     self.__class__.__name__, logs_folder)
        self._logs_folder = logs_folder
        self._filenames = self._get_log_filenames()
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def session_ids(self) -> list[int]:
        """ list[int]: Sorted list of `ints` of available session ids. """
        return list(sorted(self._filenames))

    def _get_log_filenames(self) -> dict[int, str]:
        """ Get the Tensorflow event filenames for all existing sessions.

        Returns
        -------
        dict[int, str]
            The full path of each log file for each training session id that has been run
        """
        logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
        retval: dict[int, str] = {}
        for dirpath, _, filenames in os.walk(self._logs_folder):
            if not any(filename.startswith("events.out.tfevents") for filename in filenames):
                continue
            session_id = self._get_session_id(dirpath)
            if session_id is None:
                logger.warning("Unable to load session data for model")
                return retval
            retval[session_id] = self._get_log_filename(dirpath, filenames)
        logger.debug("logfiles: %s", retval)
        return retval

    @classmethod
    def _get_session_id(cls, folder: str) -> int | None:
        """ Obtain the session id for the given folder.

        Parameters
        ----------
        folder: str
            The full path to the folder that contains the session's Tensorflow Event Log

        Returns
        -------
        int or ``None``
            The session ID for the given folder. If no session id can be determined, return
            ``None``
        """
        session = os.path.split(os.path.split(folder)[0])[1]
        session_id = session[session.rfind("_") + 1:]
        retval = None if not session_id.isdigit() else int(session_id)
        logger.debug("folder: '%s', session_id: %s", folder, retval)
        return retval

    @classmethod
    def _get_log_filename(cls, folder: str, filenames: list[str]) -> str:
        """ Obtain the session log file for the given folder. If multiple log files exist for the
        given folder, then the most recent log file is used, as earlier files are assumed to be
        obsolete.

        Parameters
        ----------
        folder: str
            The full path to the folder that contains the session's Tensorflow Event Log
        filenames: list[str]
            List of filenames that exist within the given folder

        Returns
        -------
        str
            The full path of the selected log file
        """
        logfiles = [fname for fname in filenames if fname.startswith("events.out.tfevents")]
        retval = os.path.join(folder, sorted(logfiles)[-1])  # Take last item if multi matches
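        # Event filenames embed a creation timestamp (e.g. "events.out.tfevents.<epoch>.<host>"),
        # so a lexical sort keeps them in chronological order and the newest file sorts last.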
        logger.debug("logfiles: %s, selected: '%s'", logfiles, retval)
        return retval

    def refresh(self) -> None:
        """ Refresh the list of log filenames. """
        logger.debug("Refreshing log filenames")
        self._filenames = self._get_log_filenames()

    def get(self, session_id: int) -> str:
        """ Obtain the log filename for the given session id.

        Parameters
        ----------
        session_id: int
            The session id to obtain the log filename for

        Returns
        -------
        str
            The full path to the log file for the requested session id
        """
        retval = self._filenames.get(session_id, "")
        logger.debug("session_id: %s, log_filename: '%s'", session_id, retval)
        return retval


class _CacheData():
    """ Holds cached data that has been retrieved from Tensorflow Event Files and is compressed
    in memory for a single or live training session

    Parameters
    ----------
    labels: list[str]
        The labels for the loss values
    timestamps: :class:`np.ndarray`
        The timestamps of the event steps (iterations) for the session
    loss: :class:`np.ndarray`
        The loss values collected for A and B sides for the session
    """
    def __init__(self, labels: list[str], timestamps: np.ndarray, loss: np.ndarray) -> None:
        self.labels = labels
        self._loss = zlib.compress(T.cast(bytes, loss))
        self._timestamps = zlib.compress(T.cast(bytes, timestamps))
        self._timestamps_shape = timestamps.shape
        self._loss_shape = loss.shape
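
        # The arrays are stored as zlib-compressed copies of their raw buffers; dtype and shape
        # are tracked separately (above) so they can be rebuilt. Round-trip sketch:
        #   buf = zlib.compress(loss)  # an ndarray exposes the buffer protocol
        #   out = np.frombuffer(zlib.decompress(buf), dtype="float32").reshape(loss.shape)
        #   np.array_equal(out, loss)  # True for a C-contiguous float32 array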

    @property
    def loss(self) -> np.ndarray:
        """ :class:`numpy.ndarray`: The loss values for this session """
        retval: np.ndarray = np.frombuffer(zlib.decompress(self._loss), dtype="float32")
        if len(self._loss_shape) > 1:
            retval = retval.reshape(-1, *self._loss_shape[1:])
        return retval

    @property
    def timestamps(self) -> np.ndarray:
        """ :class:`numpy.ndarray`: The timestamps for this session """
        retval: np.ndarray = np.frombuffer(zlib.decompress(self._timestamps), dtype="float64")
        if len(self._timestamps_shape) > 1:
            retval = retval.reshape(-1, *self._timestamps_shape[1:])
        return retval

    def add_live_data(self, timestamps: np.ndarray, loss: np.ndarray) -> None:
        """ Add live data to the end of the stored data

        Parameters
        ----------
        loss: :class:`numpy.ndarray`
            The latest loss values to add to the cache
        timestamps: :class:`numpy.ndarray`
            The latest timestamps to add to the cache
        """
        new_buffer: list[bytes] = []
        new_shapes: list[tuple[int, ...]] = []
        for data, buffer, dtype, shape in zip([timestamps, loss],
                                              [self._timestamps, self._loss],
                                              ["float64", "float32"],
                                              [self._timestamps_shape, self._loss_shape]):

            old = np.frombuffer(zlib.decompress(buffer), dtype=dtype)
            if data.ndim > 1:
                old = old.reshape(-1, *data.shape[1:])

            new = np.concatenate((old, data))

            logger.debug("old_shape: %s new_shape: %s", shape, new.shape)
            new_buffer.append(zlib.compress(new))
            new_shapes.append(new.shape)
            del old

        self._timestamps = new_buffer[0]
        self._loss = new_buffer[1]
        self._timestamps_shape = new_shapes[0]
        self._loss_shape = new_shapes[1]


class _Cache():
    """ Holds parsed Tensorflow log event data in a compressed cache in memory. """
    def __init__(self) -> None:
        logger.debug("Initializing: %s", self.__class__.__name__)
        self._data: dict[int, _CacheData] = {}
        self._carry_over: dict[int, EventData] = {}
        self._loss_labels: list[str] = []
        logger.debug("Initialized: %s", self.__class__.__name__)

    def is_cached(self, session_id: int) -> bool:
        """ Check if the given session_id's data is already cached

        Parameters
        ----------
        session_id: int
            The session ID to check

        Returns
        -------
        bool
            ``True`` if the data already exists in the cache otherwise ``False``.
        """
        return self._data.get(session_id) is not None

    def cache_data(self,
                   session_id: int,
                   data: dict[int, EventData],
                   labels: list[str],
                   is_live: bool = False) -> None:
        """ Add a full session's worth of event data to :attr:`_data`.

        Parameters
        ----------
        session_id: int
            The session id to add the data for
        data: dict[int, :class:`EventData`]
            The extracted event data dictionary generated from :class:`_EventParser`
        labels: list[str]
            List of `str` for the labels of each loss value output
        is_live: bool, optional
            ``True`` if the data to be cached is from a live training session otherwise ``False``.
            Default: ``False``
        """
        logger.debug("Caching event data: (session_id: %s, labels: %s, data points: %s, "
                     "is_live: %s)", session_id, labels, len(data), is_live)

        if labels:
            logger.debug("Setting loss labels: %s", labels)
            self._loss_labels = labels

        if not data:
            logger.debug("No data to cache")
            return

        timestamps, loss = self._to_numpy(data, is_live)

        if not is_live or (is_live and not self._data.get(session_id)):
            self._data[session_id] = _CacheData(self._loss_labels, timestamps, loss)
        else:
            self._add_latest_live(session_id, loss, timestamps)

    def _to_numpy(self,
                  data: dict[int, EventData],
                  is_live: bool) -> tuple[np.ndarray, np.ndarray]:
        """ Extract each step's data into separate numpy arrays for loss and timestamps.

        Timestamps are stored as float64 as the extra accuracy is needed for correct timings.
        Arrays are returned at the length of the shortest available data (i.e. truncated records
        are dropped)

        Parameters
        ----------
        data: dict
            The incoming tensorflow event data in dictionary form per step
        is_live: bool
            ``True`` if the data to be cached is from a live training session otherwise ``False``.

        Returns
        -------
        timestamps: :class:`numpy.ndarray`
            float64 array of all iterations' timestamps
        loss: :class:`numpy.ndarray`
            float32 array of all iterations' loss
        """
        if is_live and self._carry_over:
            logger.debug("Processing carry over: %s", self._carry_over)
            self._collect_carry_over(data)

        times, loss = self._process_data(data, is_live)

        if is_live and not all(len(val) == len(self._loss_labels) for val in loss):
            # TODO Many attempts have been made to fix this for live graph logging, and the issue
            # of non-consistent loss record sizes keeps coming up. In the meantime we shall swallow
            # any loss values that are of incorrect length so the graph remains functional. This
            # will, most likely, lead to a mismatch on iteration count so a proper fix should be
            # implemented.

            # Timestamps and loss appear to remain consistent with each other, but sometimes loss
            # appears inconsistent, e.g. (lengths):
            # [2, 2, 2, 2, 2, 2, 2, 0] - last loss collection has zero length
            # [1, 2, 2, 2, 2, 2, 2, 2] - 1st loss collection has length 1
            # [2, 2, 2, 3, 2, 2, 2] - 4th loss collection has length 3

            logger.debug("Inconsistent loss found in collection: %s", loss)
            for idx in reversed(range(len(loss))):
                if len(loss[idx]) != len(self._loss_labels):
                    logger.debug("Removing loss/timestamps at position %s", idx)
                    del loss[idx]
                    del times[idx]

        n_times, n_loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
        logger.debug("Converted to numpy: (data points: %s, timestamps shape: %s, loss shape: %s)",
                     len(data), n_times.shape, n_loss.shape)

        return n_times, n_loss

    def _collect_carry_over(self, data: dict[int, EventData]) -> None:
        """ For live data, collect carried over data from the previous update and merge into the
        current data dictionary.

        Parameters
        ----------
        data: dict[int, :class:`EventData`]
            The latest raw data dictionary
        """
        logger.debug("Carry over keys: %s, data keys: %s", list(self._carry_over), list(data))
        for key in list(self._carry_over):
            if key not in data:
                logger.debug("Carry over found for item %s which does not exist in current "
                             "data: %s. Skipping.", key, list(data))
                continue
            carry_over = self._carry_over.pop(key)
            update = data[key]
            logger.debug("Merging carry over data: %s in to %s", carry_over, update)
            timestamp = update.timestamp
            update.timestamp = carry_over.timestamp if not timestamp else timestamp
            update.loss = carry_over.loss + update.loss
            logger.debug("Merged carry over data: %s", update)

    def _process_data(self,
                      data: dict[int, EventData],
                      is_live: bool) -> tuple[list[float], list[list[float]]]:
        """ Process the incoming event data.

        Live data requires different processing as often we will only have partial data for the
        current step, so we need to cache carried over partial data to be picked up at the next
        query. In addition to this, if training is unexpectedly interrupted, there may also be
        partial data which needs to be cleansed prior to creating a numpy array

        Parameters
        ----------
        data: dict
            The incoming tensorflow event data in dictionary form per step
        is_live: bool
            ``True`` if the data to be cached is from a live training session otherwise ``False``.

        Returns
        -------
        timestamps: list
            Cleaned list of complete timestamps for the latest live query
        loss: list
            Cleaned list of complete loss for the latest live query
        """
        timestamps, loss = zip(*[(data[idx].timestamp, data[idx].loss)
                                 for idx in sorted(data)])
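        # ``data`` maps step number -> EventData; sorting by step and unpacking yields two
        # parallel sequences, e.g. (illustrative): timestamps = (t1, t2, ...) and
        # loss = ([a1, b1], [a2, b2], ...), ideally one loss value per label.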

        l_loss: list[list[float]] = list(loss)
        l_timestamps: list[float] = list(timestamps)

        if len(l_loss[-1]) != len(self._loss_labels):
            logger.debug("Truncated loss found. loss count: %s", len(l_loss))
            idx = sorted(data)[-1]
            if is_live:
                logger.debug("Setting carried over data: %s", data[idx])
                self._carry_over[idx] = data[idx]
            logger.debug("Removing truncated loss: (timestamp: %s, loss: %s)",
                         l_timestamps[-1], loss[-1])
            del l_loss[-1]
            del l_timestamps[-1]

        return l_timestamps, l_loss

    def _add_latest_live(self, session_id: int, loss: np.ndarray, timestamps: np.ndarray) -> None:
        """ Append the latest received live training data to the cached data.

        Parameters
        ----------
        session_id: int
            The training session ID to update the cache for
        loss: :class:`numpy.ndarray`
            The latest loss values returned from the iterator
        timestamps: :class:`numpy.ndarray`
            The latest time stamps returned from the iterator
        """
        logger.debug("Adding live data to cache: (session_id: %s, loss: %s, timestamps: %s)",
                     session_id, loss.shape, timestamps.shape)
        if not np.any(loss) and not np.any(timestamps):
            return

        self._data[session_id].add_live_data(timestamps, loss)

    def get_data(self, session_id: int | None, metric: T.Literal["loss", "timestamps"]
                 ) -> dict[int, dict[str, np.ndarray | list[str]]] | None:
        """ Retrieve the decompressed cached data from the cache for the given session id.

        Parameters
        ----------
        session_id: int or ``None``
            If session_id is provided, then the cached data for that session is returned. If
            session_id is ``None`` then the cached data for all sessions is returned
        metric: ['loss', 'timestamps']
            The metric to return the data for.

        Returns
        -------
        dict or ``None``
            The `session_id`(s) as key, the values are a dictionary containing the requested
            metric information for each session returned. ``None`` if no data is stored for the
            given session_id
        """
        if session_id is None:
            raw = self._data
        else:
            data = self._data.get(session_id)
            if not data:
                return None
            raw = {session_id: data}

        retval: dict[int, dict[str, np.ndarray | list[str]]] = {}
        for idx, data in raw.items():
            array = data.loss if metric == "loss" else data.timestamps
            val: dict[str, np.ndarray | list[str]] = {str(metric): array}
            if metric == "loss":
                val["labels"] = data.labels
            retval[idx] = val

        logger.debug("Obtained cached data: %s",
                     {session_id: {k: v.shape if isinstance(v, np.ndarray) else v
                                   for k, v in data.items()}
                      for session_id, data in retval.items()})
        return retval


class TensorBoardLogs():
    """ Parse data from TensorBoard logs.

    Processes the input logs folder and stores the individual filenames per session.

    Caches timestamp and loss data on request and returns this data from the cache.

    Parameters
    ----------
    logs_folder: str
        The folder that contains the Tensorboard log files
    is_training: bool
        ``True`` if the events are being read whilst Faceswap is training otherwise ``False``
    """
    def __init__(self, logs_folder: str, is_training: bool) -> None:
        logger.debug("Initializing: %s: (logs_folder: %s, is_training: %s)",
                     self.__class__.__name__, logs_folder, is_training)
        self._is_training = False
        self._training_iterator = None

        self._log_files = _LogFiles(logs_folder)
        self.set_training(is_training)

        self._cache = _Cache()

        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def session_ids(self) -> list[int]:
        """ list[int]: Sorted list of integers of available session ids. """
        return self._log_files.session_ids

    def set_training(self, is_training: bool) -> None:
        """ Set the internal training flag to the given `is_training` value.

        If a new training session is being instigated, refresh the log filenames

        Parameters
        ----------
        is_training: bool
            ``True`` to indicate that the logs to be read are from the currently training
            session otherwise ``False``
        """
        if self._is_training == is_training:
            logger.debug("Training flag already set to %s. Returning", is_training)
            return

        logger.debug("Setting is_training to %s", is_training)
        self._is_training = is_training
        if is_training:
            self._log_files.refresh()
            log_file = self._log_files.get(self.session_ids[-1])
            logger.debug("Setting training iterator for log file: '%s'", log_file)
            self._training_iterator = tf.compat.v1.io.tf_record_iterator(log_file)
        else:
            logger.debug("Removing training iterator")
            del self._training_iterator
            self._training_iterator = None

    def _cache_data(self, session_id: int) -> None:
        """ Cache TensorBoard logs for the given session ID on first access.

        Populates :attr:`_cache` with timestamps and loss data.

        If this is a training session and the data is being queried for the training session ID
        then get the latest available data and append to the cache

        Parameters
        ----------
        session_id: int
            The session ID to cache the data for
        """
        live_data = self._is_training and session_id == max(self.session_ids)
        iterator = self._training_iterator if live_data else tf.compat.v1.io.tf_record_iterator(
            self._log_files.get(session_id))
        assert iterator is not None
        parser = _EventParser(iterator, self._cache, live_data)
        parser.cache_events(session_id)

    def _check_cache(self, session_id: int | None = None) -> None:
        """ Check if the given session_id has been cached and if not, cache it.

        Parameters
        ----------
        session_id: int, optional
            The Session ID to return the data for. Set to ``None`` to return all session
            data. Default: ``None``
        """
        if session_id is not None and not self._cache.is_cached(session_id):
            self._cache_data(session_id)
        elif self._is_training and session_id == self.session_ids[-1]:
            self._cache_data(session_id)
        elif session_id is None:
            for idx in self.session_ids:
                if not self._cache.is_cached(idx):
                    self._cache_data(idx)

    def get_loss(self, session_id: int | None = None) -> dict[int, dict[str, np.ndarray]]:
        """ Read the loss from the TensorBoard event logs

        Parameters
        ----------
        session_id: int, optional
            The Session ID to return the loss for. Set to ``None`` to return all session
            losses. Default ``None``

        Returns
        -------
        dict
            The session id(s) as key, with a further dictionary as value containing the loss name
            and list of loss values for each step
        """
        logger.debug("Getting loss: (session_id: %s)", session_id)
        retval: dict[int, dict[str, np.ndarray]] = {}
        for idx in [session_id] if session_id else self.session_ids:
            self._check_cache(idx)
            full_data = self._cache.get_data(idx, "loss")
            if not full_data:
                continue
            data = full_data[idx]
            loss = data["loss"]
            assert isinstance(loss, np.ndarray)
            retval[idx] = {title: loss[:, idx] for idx, title in enumerate(data["labels"])}

        logger.debug({key: {k: v.shape for k, v in val.items()}
                      for key, val in retval.items()})
        return retval

    def get_timestamps(self, session_id: int | None = None) -> dict[int, np.ndarray]:
        """ Read the timestamps from the TensorBoard logs.

        As loss timestamps are slightly different for each loss, we collect the timestamp from the
        `batch_total` key.

        Parameters
        ----------
        session_id: int, optional
            The Session ID to return the timestamps for. Set to ``None`` to return all session
            timestamps. Default ``None``

        Returns
        -------
        dict
            The session id(s) as key with list of timestamps per step as value
        """

        logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
                     session_id, self._is_training)
        retval: dict[int, np.ndarray] = {}
        for idx in [session_id] if session_id else self.session_ids:
            self._check_cache(idx)
            data = self._cache.get_data(idx, "timestamps")
            if not data:
                continue
            timestamps = data[idx]["timestamps"]
            assert isinstance(timestamps, np.ndarray)
            retval[idx] = timestamps
        logger.debug({k: v.shape for k, v in retval.items()})
        return retval


class _EventParser():  # pylint:disable=too-few-public-methods
    """ Parses Tensorflow events and populates data to :class:`_Cache`.

    Parameters
    ----------
    iterator: :func:`tf.compat.v1.io.tf_record_iterator`
        The iterator to use for reading Tensorflow event logs
    cache: :class:`_Cache`
        The cache object to store the collected parsed events to
    live_data: bool
        ``True`` if the iterator to be loaded is a training iterator for reading live data
        otherwise ``False``
    """
    def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None:
        logger.debug("Initializing: %s: (iterator: %s, cache: %s, live_data: %s)",
                     self.__class__.__name__, iterator, cache, live_data)
        self._live_data = live_data
        self._cache = cache
        self._iterator = self._get_latest_live(iterator) if live_data else iterator
        self._loss_labels: list[str] = []
        logger.debug("Initialized: %s", self.__class__.__name__)

    @classmethod
    def _get_latest_live(cls, iterator: Iterator[bytes]) -> Generator[bytes, None, None]:
        """ Obtain the latest event logs for live training data.

        The live data iterator remains open so that it can be re-queried

        Parameters
        ----------
        iterator: :func:`tf.compat.v1.io.tf_record_iterator`
            The live training iterator to use for reading Tensorflow event logs

        Yields
        ------
        bytes
            A serialized Tensorflow event record for a single step
        """
        i = 0
        while True:
            try:
                yield next(iterator)
                i += 1
            except StopIteration:
                logger.debug("End of data reached")
                break
            except tf.errors.DataLossError as err:
                # Truncated records are ignored. The iterator holds the offset, so the record will
                # be completed at the next call.
                logger.debug("Truncated record. Original Error: %s", err)
                break
        logger.debug("Collected %s records from live log file", i)

    def cache_events(self, session_id: int) -> None:
        """ Parse the Tensorflow event logs and add to :attr:`_cache`.

        Parameters
        ----------
        session_id: int
            The session id that the data is being cached for
        """
        assert self._iterator is not None
        data: dict[int, EventData] = {}
        try:
            for record in self._iterator:
                event = event_pb2.Event.FromString(record)  # pylint:disable=no-member
                if not event.summary.value:
                    continue
                if event.summary.value[0].tag == "keras":
                    self._parse_outputs(event)
                if event.summary.value[0].tag.startswith("batch_"):
                    data[event.step] = self._process_event(event,
                                                           data.get(event.step, EventData()))

        except tf_errors.DataLossError as err:
            logger.warning("The logs for Session %s are corrupted and cannot be displayed. "
                           "The totals do not include this session. Original error message: "
                           "'%s'", session_id, str(err))

        self._cache.cache_data(session_id, data, self._loss_labels, is_live=self._live_data)

    def _parse_outputs(self, event: event_pb2.Event) -> None:
        """ Parse the outputs from the stored model structure for mapping loss names to
        model outputs.

        Loss names are added to :attr:`_loss_labels`

        Notes
        -----
        The master model does not actually contain the specified output name, so we dig into the
        sub-model to obtain the name of the output layers

        Parameters
        ----------
        event: :class:`tensorflow.core.util.event_pb2`
            The event data containing the keras model structure to be parsed
        """
        serializer = get_serializer("json")
        struct = event.summary.value[0].tensor.string_val[0]

        config = serializer.unmarshal(struct)["config"]
        model_outputs = self._get_outputs(config)

        for side_outputs, side in zip(model_outputs, ("a", "b")):
            logger.debug("side: '%s', outputs: '%s'", side, side_outputs)
            layer_name = side_outputs[0][0]

            output_config = next(layer for layer in config["layers"]
                                 if layer["name"] == layer_name)["config"]
            layer_outputs = self._get_outputs(output_config)
            for output in layer_outputs:  # Drill into sub-model to get the actual output names
                loss_name = output[0][0]
                if loss_name[-2:] not in ("_a", "_b"):  # Rename losses to reflect the side output
                    new_name = f"{loss_name.replace('_both', '')}_{side}"
                    logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
                    loss_name = new_name
                if loss_name not in self._loss_labels:
                    logger.debug("Adding loss name: '%s'", loss_name)
                    self._loss_labels.append(loss_name)
        logger.debug("Collated loss labels: %s", self._loss_labels)

    @classmethod
    def _get_outputs(cls, model_config: dict[str, T.Any]) -> np.ndarray:
        """ Obtain the output names, instance index and output index for the given model.

        If there is only a single output, the shape of the array is expanded to remain consistent
        with multi-model outputs

        Parameters
        ----------
        model_config: dict
            The saved Keras model configuration dictionary

        Returns
        -------
        :class:`numpy.ndarray`
            The layer output names, their instance index and their output index
        """
        outputs = np.array(model_config["output_layers"])
        logger.debug("Obtained model outputs: %s, shape: %s", outputs, outputs.shape)
        if outputs.ndim == 2:  # Insert extra dimension for non learn mask models
            outputs = np.expand_dims(outputs, axis=1)
            logger.debug("Expanded dimensions for single output model. outputs: %s, shape: %s",
                         outputs, outputs.shape)
        return outputs

    @classmethod
    def _process_event(cls, event: event_pb2.Event, step: EventData) -> EventData:
        """ Process a single Tensorflow event.

        Adds the timestamp to the step data if a total loss value is received, processes the
        labels for any new loss entries and adds the side loss value to the step data.

        Parameters
        ----------
        event: :class:`tensorflow.core.util.event_pb2`
            The event data to be processed
        step: :class:`EventData`
            The :class:`EventData` currently being populated with the extracted data from the
            tensorflow event for this step

        Returns
        -------
        :class:`EventData`
            The given step :class:`EventData` with the given event data added to it.
        """
        summary = event.summary.value[0]

        if summary.tag == "batch_total":
            step.timestamp = event.wall_time
            return step

        loss = summary.simple_value
        if not loss:
            # Need to convert a tensor to a float for TF2.8 logged data. This may be due to a
            # change in logging, or to the workaround put in place in the FS training function
            # for the following bug in TF 2.8/2.9 when writing records:
            # https://github.com/keras-team/keras/issues/16173
            loss = float(tf.make_ndarray(summary.tensor))

        step.loss.append(loss)

        return step