#!/usr/bin/env python3
""" Stats functions for the GUI """
|
|
|
|
import time
|
|
import os
|
|
import warnings
|
|
|
|
from math import ceil, sqrt
|
|
|
|
import numpy as np
|
|
|
|
from lib.Serializer import PickleSerializer
|
|
|
|
|
|
class SavedSessions(object):
|
|
""" Saved Training Session """
|
|
def __init__(self, sessions_data):
|
|
self.serializer = PickleSerializer
|
|
self.sessions = self.load_sessions(sessions_data)
|
|
|
|
def load_sessions(self, filename):
|
|
""" Load previously saved sessions """
|
|
stats = list()
|
|
if os.path.isfile(filename):
|
|
with open(filename, self.serializer.roptions) as sessions:
|
|
stats = self.serializer.unmarshal(sessions.read())
|
|
return stats
|
|
|
|
def save_sessions(self, filename):
|
|
""" Save the session file """
|
|
with open(filename, self.serializer.woptions) as session:
|
|
session.write(self.serializer.marshal(self.sessions))
|
|
print("Saved session stats to: {}".format(filename))
|
|
|
|
|
|
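
# NOTE: The sessions file is a pickled list holding one stats dict per
# training session. Illustrative layout only -- the loss key names and
# values below are made up; real keys come from the training losses:
#   [{"iterations": 1000, "batchsize": 64, "timestamps": [...],
#     "losskeys": ["loss_A", "loss_B"], "loss": [[...], [...]]}, ...]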


class CurrentSession(object):
    """ The current training session """
    def __init__(self):
        self.stats = {"iterations": 0,
                      "batchsize": None,  # Set and reset by wrapper
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}
        self.timestats = {"start": None,
                          "elapsed": None}
        self.modeldir = None  # Set and reset by wrapper
        self.filename = None
        self.historical = None

    def initialise_session(self, currentloss):
        """ Initialise the training session """
        self.load_historical()
        for item in currentloss:
            self.stats["losskeys"].append(item[0])
            self.stats["loss"].append(list())
        self.timestats["start"] = time.time()

    def load_historical(self):
        """ Load historical data and add current session to the end """
        self.filename = os.path.join(self.modeldir, "trainingstats.fss")
        self.historical = SavedSessions(self.filename)
        self.historical.sessions.append(self.stats)

    def add_loss(self, currentloss):
        """ Add a loss item from the training process """
        if self.stats["iterations"] == 0:
            self.initialise_session(currentloss)

        self.stats["iterations"] += 1
        self.add_timestats()

        for idx, item in enumerate(currentloss):
            self.stats["loss"][idx].append(float(item[1]))
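
    # NOTE: `currentloss` is expected to be an iterable of (name, value)
    # pairs, e.g. [("loss_A", 0.023), ("loss_B", 0.031)] -- the names here
    # are only illustrative. item[0] seeds "losskeys" on the first call and
    # float(item[1]) is appended to the matching loss list on every call.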

    def add_timestats(self):
        """ Add timestats to loss dict and timestats """
        now = time.time()
        self.stats["timestamps"].append(now)
        elapsed_time = now - self.timestats["start"]
        self.timestats["elapsed"] = time.strftime("%H:%M:%S",
                                                  time.gmtime(elapsed_time))

    def save_session(self):
        """ Save the session file to the modeldir """
        if self.stats["iterations"] > 0:
            print("Saving session stats...")
            self.historical.save_sessions(self.filename)


class SessionsTotals(object):
    """ The compiled totals of all saved sessions """
    def __init__(self, all_sessions):
        self.stats = {"split": [],
                      "iterations": 0,
                      "batchsize": [],
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}

        self.initiate(all_sessions)
        self.compile(all_sessions)

    def initiate(self, sessions):
        """ Initiate correct losskey titles and number of loss lists """
        for losskey in sessions[0]["losskeys"]:
            self.stats["losskeys"].append(losskey)
            self.stats["loss"].append(list())

    def compile(self, sessions):
        """ Compile all of the sessions into totals """
        current_split = 0
        for session in sessions:
            iterations = session["iterations"]
            current_split += iterations
            self.stats["split"].append(current_split)
            self.stats["iterations"] += iterations
            self.stats["timestamps"].extend(session["timestamps"])
            self.stats["batchsize"].append(session["batchsize"])
            self.add_loss(session["loss"])
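
    # NOTE: "split" records the cumulative iteration count at the end of each
    # session, marking session boundaries within the concatenated lists.
    # For example, sessions of 100 and 250 iterations give split == [100, 350].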

    def add_loss(self, session_loss):
        """ Add loss vals to each of their respective lists """
        for idx, loss in enumerate(session_loss):
            self.stats["loss"][idx].extend(loss)


class SessionsSummary(object):
    """ Calculations for analysis summary stats """

    def __init__(self, raw_data):
        self.summary = list()
        self.summary_stats_compile(raw_data)

    def summary_stats_compile(self, raw_data):
        """ Compile summary stats """
        raw_summaries = list()
        for idx, session in enumerate(raw_data):
            raw_summaries.append(self.summarise_session(idx, session))

        totals_summary = self.summarise_totals(raw_summaries)
        raw_summaries.append(totals_summary)
        self.format_summaries(raw_summaries)

    # Compile Session Summaries
    @staticmethod
    def summarise_session(idx, session):
        """ Compile stats for session passed in """
        starttime = session["timestamps"][0]
        endtime = session["timestamps"][-1]
        elapsed = endtime - starttime
        # Bump elapsed to 0.1s if no time is recorded
        # to hack around div by zero error
        elapsed = 0.1 if elapsed == 0 else elapsed
        rate = (session["batchsize"] * session["iterations"]) / elapsed
        return {"session": idx + 1,
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate,
                "batch": session["batchsize"],
                "iterations": session["iterations"]}
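
    # NOTE: "rate" is the number of examples processed per second, i.e.
    # batchsize * iterations / elapsed seconds. For example, a batchsize of
    # 64 over 1000 iterations in 3200 seconds gives a rate of 20.0.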

    @staticmethod
    def summarise_totals(raw_summaries):
        """ Compile the stats for all sessions combined """
        elapsed = 0
        rate = 0
        batchset = set()
        iterations = 0
        total_summaries = len(raw_summaries)

        for idx, summary in enumerate(raw_summaries):
            if idx == 0:
                starttime = summary["start"]
            if idx == total_summaries - 1:
                endtime = summary["end"]
            elapsed += summary["elapsed"]
            rate += summary["rate"]
            batchset.add(summary["batch"])
            iterations += summary["iterations"]
        batch = ",".join(str(bs) for bs in batchset)

        return {"session": "Total",
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate / total_summaries,
                "batch": batch,
                "iterations": iterations}
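
    # NOTE: The combined "rate" is the simple mean of the per-session rates
    # rather than a time-weighted average, and "elapsed" is the sum of the
    # per-session elapsed times, so gaps between sessions are excluded.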

    def format_summaries(self, raw_summaries):
        """ Format the summaries nicely for display """
        for summary in raw_summaries:
            summary["start"] = time.strftime("%x %X",
                                             time.gmtime(summary["start"]))
            summary["end"] = time.strftime("%x %X",
                                           time.gmtime(summary["end"]))
            summary["elapsed"] = time.strftime("%H:%M:%S",
                                               time.gmtime(summary["elapsed"]))
            summary["rate"] = "{0:.1f}".format(summary["rate"])
        self.summary = raw_summaries
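
    # NOTE: "%x %X" renders the locale's date and time representations.
    # Elapsed time is rendered as HH:MM:SS via time.gmtime, so a session
    # running longer than 24 hours will only show the remainder.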


class Calculations(object):
    """ Class to hold calculations against raw session data """
    def __init__(self,
                 session,
                 display="loss",
                 selections=["raw"],
                 avg_samples=10,
                 flatten_outliers=False,
                 is_totals=False):

        warnings.simplefilter("ignore", np.RankWarning)

        self.session = session
        if display.lower() == "loss":
            display = self.session["losskeys"]
        else:
            display = [display]
        self.args = {"display": display,
                     "selections": selections,
                     "avg_samples": int(avg_samples),
                     "flatten_outliers": flatten_outliers,
                     "is_totals": is_totals}
        self.iterations = 0
        self.stats = None
        self.refresh()
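
    # A minimal usage sketch (assuming `session` is a stats dict shaped like
    # CurrentSession.stats; the key names depend on the session's losskeys):
    #   calcs = Calculations(session, display="loss",
    #                        selections=["raw", "avg", "trend"],
    #                        avg_samples=10)
    #   calcs.stats  # e.g. {"raw_loss_A": [...], "avg_loss_A": [...], ...}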

    def refresh(self):
        """ Refresh the stats """
        self.iterations = 0
        self.stats = self.get_raw()
        self.get_calculations()
        self.remove_raw()

    def get_raw(self):
        """ Add raw data to stats dict """
        raw = dict()
        for idx, item in enumerate(self.args["display"]):
            if item.lower() == "rate":
                data = self.calc_rate(self.session)
            else:
                data = self.session["loss"][idx][:]

            if self.args["flatten_outliers"]:
                data = self.flatten_outliers(data)

            if self.iterations == 0:
                self.iterations = len(data)

            raw["raw_{}".format(item)] = data
        return raw

    def remove_raw(self):
        """ Remove raw values from stats if not requested """
        if "raw" in self.args["selections"]:
            return
        for key in list(self.stats.keys()):
            if key.startswith("raw"):
                del self.stats[key]

    def calc_rate(self, data):
        """ Calculate rate per iteration
            NB: For totals, gaps between sessions can be large
            so the time difference has to be reset for each session's
            rate calculation """
        batchsize = data["batchsize"]
        if self.args["is_totals"]:
            split = data["split"]
        else:
            batchsize = [batchsize]
            split = [len(data["timestamps"])]

        prev_split = 0
        rate = list()

        for idx, current_split in enumerate(split):
            prev_time = data["timestamps"][prev_split]
            timestamp_chunk = data["timestamps"][prev_split:current_split]
            for item in timestamp_chunk:
                current_time = item
                timediff = current_time - prev_time
                iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff
                rate.append(iter_rate)
                prev_time = current_time
            prev_split = current_split

        if self.args["flatten_outliers"]:
            rate = self.flatten_outliers(rate)
        return rate
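
    # NOTE: calc_rate yields one value per iteration: batchsize divided by
    # the time delta between consecutive timestamps. When compiling totals,
    # the "split" boundaries reset prev_time at the start of each session so
    # the (potentially long) gap between sessions does not distort the rate.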

    @staticmethod
    def flatten_outliers(data):
        """ Remove the outliers from a provided list """
        retdata = list()
        samples = len(data)
        mean = (sum(data) / samples)
        limit = sqrt(sum([(item - mean)**2 for item in data]) / samples)

        for item in data:
            if (mean - limit) <= item <= (mean + limit):
                retdata.append(item)
            else:
                retdata.append(mean)
        return retdata
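
    # NOTE: `limit` is the population standard deviation, so any point more
    # than one standard deviation from the mean is replaced with the mean.
    # The output keeps the same length and indexing as the input series.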

    def get_calculations(self):
        """ Perform the required calculations """
        for selection in self.get_selections():
            if selection[0] == "raw":
                continue
            method = getattr(self, "calc_{}".format(selection[0]))
            key = "{}_{}".format(selection[0], selection[1])
            raw = self.stats["raw_{}".format(selection[1])]
            self.stats[key] = method(raw)
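
    # NOTE: Selections are dispatched by name: anything other than "raw" is
    # routed to a matching calc_<selection> method via getattr, so adding a
    # new calculation only requires defining another calc_* method.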

    def get_selections(self):
        """ Compile a list of data to be calculated """
        for summary in self.args["selections"]:
            for item in self.args["display"]:
                yield summary, item

    def calc_avg(self, data):
        """ Calculate rolling average """
        avgs = list()
        presample = ceil(self.args["avg_samples"] / 2)
        postsample = self.args["avg_samples"] - presample
        datapoints = len(data)

        if datapoints <= (self.args["avg_samples"] * 2):
            print("Not enough data to compile rolling average")
            return avgs

        for idx in range(0, datapoints):
            if idx < presample or idx >= datapoints - postsample:
                avgs.append(None)
                continue
            else:
                avg = sum(data[idx - presample:idx + postsample]) \
                        / self.args["avg_samples"]
                avgs.append(avg)
        return avgs
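
    # NOTE: The rolling average is centred: with avg_samples=10, presample=5
    # and postsample=5, so each point is the mean of data[idx - 5:idx + 5].
    # The first and last few points are padded with None so the output stays
    # aligned with the raw series.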

    @staticmethod
    def calc_trend(data):
        """ Compile trend data """
        points = len(data)
        if points < 10:
            dummy = [None for i in range(points)]
            return dummy
        x_range = range(points)
        fit = np.polyfit(x_range, data, 3)
        poly = np.poly1d(fit)
        trend = poly(x_range)
        return trend
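
    # NOTE: The trend line is a cubic (degree 3) polynomial fitted with
    # numpy.polyfit and evaluated across the full range of points; the
    # RankWarning suppressed in __init__ covers poorly conditioned fits on
    # short or flat series. Fewer than 10 points returns None placeholders.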