faceswap/lib/gui/stats.py

#!/usr/bin/env python3
""" Stats functions for the GUI """
import time
import os
import warnings
from math import ceil, sqrt

import numpy as np

from lib.Serializer import PickleSerializer


class SavedSessions(object):
    """ Saved Training Session """
    def __init__(self, filename):
        self.serializer = PickleSerializer
        self.sessions = self.load_sessions(filename)

    def load_sessions(self, filename):
        """ Load previously saved sessions """
        stats = list()
        if os.path.isfile(filename):
            with open(filename, self.serializer.roptions) as sessions:
                stats = self.serializer.unmarshal(sessions.read())
        return stats

    def save_sessions(self, filename):
        """ Save the session file """
        with open(filename, self.serializer.woptions) as session:
            session.write(self.serializer.marshal(self.sessions))
        print("Saved session stats to: {}".format(filename))


class CurrentSession(object):
    """ The current training session """
    def __init__(self):
        self.stats = {"iterations": 0,
                      "batchsize": None,    # Set and reset by wrapper
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}
        self.timestats = {"start": None,
                          "elapsed": None}
        self.modeldir = None    # Set and reset by wrapper
        self.filename = None
        self.historical = None

    def initialise_session(self, currentloss):
        """ Initialise the training session """
        self.load_historical()
        for item in currentloss:
            self.stats["losskeys"].append(item[0])
            self.stats["loss"].append(list())
        self.timestats["start"] = time.time()

    def load_historical(self):
        """ Load historical data and add current session to the end """
        self.filename = os.path.join(self.modeldir, "trainingstats.fss")
        self.historical = SavedSessions(self.filename)
        self.historical.sessions.append(self.stats)

    def add_loss(self, currentloss):
        """ Add a loss item from the training process """
        if self.stats["iterations"] == 0:
            self.initialise_session(currentloss)
        self.stats["iterations"] += 1
        self.add_timestats()
        for idx, item in enumerate(currentloss):
            self.stats["loss"][idx].append(float(item[1]))

    def add_timestats(self):
        """ Add timestats to loss dict and timestats """
        now = time.time()
        self.stats["timestamps"].append(now)
        elapsed_time = now - self.timestats["start"]
        self.timestats["elapsed"] = time.strftime("%H:%M:%S",
                                                  time.gmtime(elapsed_time))

    def save_session(self):
        """ Save the session file to the modeldir """
        if self.stats["iterations"] > 0:
            print("Saving session stats...")
            self.historical.save_sessions(self.filename)
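
# The "currentloss" argument to add_loss() is assumed to be an iterable of
# (name, value) pairs, e.g. [("loss_A", 0.042), ("loss_B", 0.051)], as
# implied by the item[0]/item[1] indexing above. A sketch of the wrapper's
# expected call sequence (names illustrative):
#
#     session = CurrentSession()
#     session.modeldir = "/path/to/model"      # set by wrapper
#     session.stats["batchsize"] = 64          # set by wrapper
#     session.add_loss([("loss_A", 0.042), ("loss_B", 0.051)])
#     session.save_session()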


class SessionsTotals(object):
    """ The compiled totals of all saved sessions """
    def __init__(self, all_sessions):
        self.stats = {"split": [],
                      "iterations": 0,
                      "batchsize": [],
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}

        self.initiate(all_sessions)
        self.compile(all_sessions)

    def initiate(self, sessions):
        """ Initiate correct losskey titles and number of loss lists """
        for losskey in sessions[0]["losskeys"]:
            self.stats["losskeys"].append(losskey)
            self.stats["loss"].append(list())

    def compile(self, sessions):
        """ Compile all of the sessions into totals """
        current_split = 0
        for session in sessions:
            iterations = session["iterations"]
            current_split += iterations
            self.stats["split"].append(current_split)
            self.stats["iterations"] += iterations
            self.stats["timestamps"].extend(session["timestamps"])
            self.stats["batchsize"].append(session["batchsize"])
            self.add_loss(session["loss"])

    def add_loss(self, session_loss):
        """ Add loss vals to each of their respective lists """
        for idx, loss in enumerate(session_loss):
            self.stats["loss"][idx].extend(loss)


class SessionsSummary(object):
    """ Calculations for analysis summary stats """
    def __init__(self, raw_data):
        self.summary = list()
        self.summary_stats_compile(raw_data)

    def summary_stats_compile(self, raw_data):
        """ Compile summary stats """
        raw_summaries = list()
        for idx, session in enumerate(raw_data):
            raw_summaries.append(self.summarise_session(idx, session))

        totals_summary = self.summarise_totals(raw_summaries)
        raw_summaries.append(totals_summary)
        self.format_summaries(raw_summaries)

    # Compile Session Summaries
    @staticmethod
    def summarise_session(idx, session):
        """ Compile stats for session passed in """
        starttime = session["timestamps"][0]
        endtime = session["timestamps"][-1]
        elapsed = endtime - starttime
        # Bump elapsed to 0.1s if no time is recorded
        # to hack around div by zero error
        elapsed = 0.1 if elapsed == 0 else elapsed
        rate = (session["batchsize"] * session["iterations"]) / elapsed
        return {"session": idx + 1,
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate,
                "batch": session["batchsize"],
                "iterations": session["iterations"]}

    @staticmethod
    def summarise_totals(raw_summaries):
        """ Compile the stats for all sessions combined """
        elapsed = 0
        rate = 0
        batchset = set()
        iterations = 0
        total_summaries = len(raw_summaries)

        for idx, summary in enumerate(raw_summaries):
            if idx == 0:
                starttime = summary["start"]
            if idx == total_summaries - 1:
                endtime = summary["end"]
            elapsed += summary["elapsed"]
            rate += summary["rate"]
            batchset.add(summary["batch"])
            iterations += summary["iterations"]
        batch = ",".join(str(bs) for bs in batchset)

        return {"session": "Total",
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate / total_summaries,
                "batch": batch,
                "iterations": iterations}

    def format_summaries(self, raw_summaries):
        """ Format the summaries nicely for display """
        for summary in raw_summaries:
            summary["start"] = time.strftime("%x %X",
                                             time.gmtime(summary["start"]))
            summary["end"] = time.strftime("%x %X",
                                           time.gmtime(summary["end"]))
            summary["elapsed"] = time.strftime("%H:%M:%S",
                                               time.gmtime(summary["elapsed"]))
            summary["rate"] = "{0:.1f}".format(summary["rate"])
        self.summary = raw_summaries


class Calculations(object):
    """ Class to hold calculations against raw session data """
    def __init__(self,
                 session,
                 display="loss",
                 selections=("raw", ),    # tuple avoids a mutable default
                 avg_samples=10,
                 flatten_outliers=False,
                 is_totals=False):

        warnings.simplefilter("ignore", np.RankWarning)

        self.session = session
        if display.lower() == "loss":
            display = self.session["losskeys"]
        else:
            display = [display]
        self.args = {"display": display,
                     "selections": selections,
                     "avg_samples": int(avg_samples),
                     "flatten_outliers": flatten_outliers,
                     "is_totals": is_totals}
        self.iterations = 0
        self.stats = None
        self.refresh()

    def refresh(self):
        """ Refresh the stats """
        self.iterations = 0
        self.stats = self.get_raw()
        self.get_calculations()
        self.remove_raw()

    def get_raw(self):
        """ Add raw data to stats dict """
        raw = dict()
        for idx, item in enumerate(self.args["display"]):
            if item.lower() == "rate":
                data = self.calc_rate(self.session)
            else:
                data = self.session["loss"][idx][:]
            if self.args["flatten_outliers"]:
                data = self.flatten_outliers(data)
            if self.iterations == 0:
                self.iterations = len(data)
            raw["raw_{}".format(item)] = data
        return raw

    def remove_raw(self):
        """ Remove raw values from stats if not requested """
        if "raw" in self.args["selections"]:
            return
        for key in list(self.stats.keys()):
            if key.startswith("raw"):
                del self.stats[key]

    def calc_rate(self, data):
        """ Calculate rate per iteration
            NB: For totals, gaps between sessions can be large,
            so the time difference has to be reset for each session's
            rate calculation """
        batchsize = data["batchsize"]
        if self.args["is_totals"]:
            split = data["split"]
        else:
            batchsize = [batchsize]
            split = [len(data["timestamps"])]

        prev_split = 0
        rate = list()

        for idx, current_split in enumerate(split):
            prev_time = data["timestamps"][prev_split]
            timestamp_chunk = data["timestamps"][prev_split:current_split]
            for item in timestamp_chunk:
                current_time = item
                timediff = current_time - prev_time
                iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff
                rate.append(iter_rate)
                prev_time = current_time
            prev_split = current_split

        if self.args["flatten_outliers"]:
            rate = self.flatten_outliers(rate)
        return rate
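
    # Worked example: with a batchsize of 64 and timestamps one second
    # apart, calc_rate() yields ~64.0 items/second per iteration. The
    # first entry of each session chunk is 0, as its time delta is zero.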

    @staticmethod
    def flatten_outliers(data):
        """ Replace the outliers in a provided list with the mean """
        retdata = list()
        samples = len(data)
        mean = (sum(data) / samples)
        limit = sqrt(sum([(item - mean)**2 for item in data]) / samples)

        for item in data:
            if (mean - limit) <= item <= (mean + limit):
                retdata.append(item)
            else:
                retdata.append(mean)
        return retdata
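
    # "limit" above is the population standard deviation, so any value more
    # than one sigma from the mean is swapped for the mean. For example,
    # flatten_outliers([1.0, 1.1, 9.0]) returns [1.0, 1.1, 3.7] (the mean).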

    def get_calculations(self):
        """ Perform the required calculations """
        for selection in self.get_selections():
            if selection[0] == "raw":
                continue
            method = getattr(self, "calc_{}".format(selection[0]))
            key = "{}_{}".format(selection[0], selection[1])
            raw = self.stats["raw_{}".format(selection[1])]
            self.stats[key] = method(raw)

    def get_selections(self):
        """ Compile a list of data to be calculated """
        for summary in self.args["selections"]:
            for item in self.args["display"]:
                yield summary, item

    def calc_avg(self, data):
        """ Calculate rolling average """
        avgs = list()
        presample = ceil(self.args["avg_samples"] / 2)
        postsample = self.args["avg_samples"] - presample
        datapoints = len(data)

        if datapoints <= (self.args["avg_samples"] * 2):
            print("Not enough data to compile rolling average")
            return avgs

        for idx in range(0, datapoints):
            if idx < presample or idx >= datapoints - postsample:
                avgs.append(None)
                continue
            avg = sum(data[idx - presample:idx + postsample]) \
                / self.args["avg_samples"]
            avgs.append(avg)
        return avgs
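
    # The window is centred on each index: with avg_samples=10, each output
    # averages data[idx - 5:idx + 5] (ten points around idx), and the edges
    # are padded with None so the smoothed series stays aligned with the raw.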

    @staticmethod
    def calc_trend(data):
        """ Compile trend data """
        points = len(data)
        if points < 10:
            dummy = [None for _ in range(points)]
            return dummy
        x_range = range(points)
        fit = np.polyfit(x_range, data, 3)
        poly = np.poly1d(fit)
        trend = poly(x_range)
        return trend
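

# A minimal, self-contained smoke test with synthetic data (assumptions:
# run from the faceswap root so "lib.Serializer" resolves; real session
# dicts come from SavedSessions rather than being hand-built like this):
if __name__ == "__main__":
    import random

    DEMO = {"losskeys": ["loss_A", "loss_B"],
            "loss": [[random.uniform(0.02, 0.06) for _ in range(100)],
                     [random.uniform(0.02, 0.06) for _ in range(100)]],
            "batchsize": 64,
            "iterations": 100,
            "timestamps": [float(i) for i in range(100)]}
    calcs = Calculations(DEMO,
                         display="loss",
                         selections=("raw", "avg", "trend"),
                         avg_samples=10)
    # Expect raw_, avg_ and trend_ keys for each loss column
    print(sorted(calcs.stats.keys()))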