#!/usr/bin/env python3
""" Stats functions for the GUI """

import time
import os
import warnings

from math import ceil, sqrt

import numpy as np

from lib.Serializer import PickleSerializer


class SavedSessions():
    """ Load and save stats for past training sessions.

        sessions_data: path to the serialized stats file """
    def __init__(self, sessions_data):
        self.serializer = PickleSerializer
        self.sessions = self.load_sessions(sessions_data)

    def load_sessions(self, filename):
        """ Load previously saved sessions from ``filename``.

            Returns an empty list when the file does not exist """
        stats = list()
        if os.path.isfile(filename):
            with open(filename, self.serializer.roptions) as sessions:
                stats = self.serializer.unmarshal(sessions.read())
        return stats

    def save_sessions(self, filename):
        """ Serialize the current sessions list out to ``filename`` """
        with open(filename, self.serializer.woptions) as session:
            session.write(self.serializer.marshal(self.sessions))
        print("Saved session stats to: {}".format(filename))


class CurrentSession():
    """ Accumulates the stats for the training session currently running """
    def __init__(self):
        # "batchsize" and modeldir are set and reset by the wrapper
        self.stats = {"iterations": 0,
                      "batchsize": None,
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}
        self.timestats = {"start": None,
                          "elapsed": None}
        self.modeldir = None
        self.filename = None
        self.historical = None

    def initialise_session(self, currentloss):
        """ Initialise the training session: attach to historical stats and
            seed one loss list per (key, value) pair in ``currentloss`` """
        self.load_historical()
        for item in currentloss:
            self.stats["losskeys"].append(item[0])
            self.stats["loss"].append(list())
        self.timestats["start"] = time.time()

    def load_historical(self):
        """ Load historical data and add current session to the end """
        self.filename = os.path.join(self.modeldir, "trainingstats.fss")
        self.historical = SavedSessions(self.filename)
        self.historical.sessions.append(self.stats)

    def add_loss(self, currentloss):
        """ Add a loss item from the training process.

            The first call lazily initialises the session """
        if self.stats["iterations"] == 0:
            self.initialise_session(currentloss)

        self.stats["iterations"] += 1
        self.add_timestats()

        for idx, item in enumerate(currentloss):
            self.stats["loss"][idx].append(float(item[1]))

    def add_timestats(self):
        """ Add timestats to loss dict and timestats """
        now = time.time()
        self.stats["timestamps"].append(now)
        elapsed_time = now - self.timestats["start"]
        self.timestats["elapsed"] = time.strftime("%H:%M:%S",
                                                  time.gmtime(elapsed_time))

    def save_session(self):
        """ Save the session file to the modeldir.

            No-op when no iterations have been recorded """
        if self.stats["iterations"] > 0:
            print("Saving session stats...")
            self.historical.save_sessions(self.filename)


class SessionsTotals():
    """ The compiled totals of all saved sessions """
    def __init__(self, all_sessions):
        self.stats = {"split": [],
                      "iterations": 0,
                      "batchsize": [],
                      "timestamps": [],
                      "loss": [],
                      "losskeys": []}

        self.initiate(all_sessions)
        self.compile(all_sessions)

    def initiate(self, sessions):
        """ Initiate correct losskey titles and number of loss lists from
            the first session """
        for losskey in sessions[0]["losskeys"]:
            self.stats["losskeys"].append(losskey)
            self.stats["loss"].append(list())

    def compile(self, sessions):
        """ Compile all of the sessions into totals.

            "split" stores the cumulative iteration count at each session
            boundary so per-session chunks can be recovered later """
        current_split = 0
        for session in sessions:
            iterations = session["iterations"]
            current_split += iterations
            self.stats["split"].append(current_split)
            self.stats["iterations"] += iterations
            self.stats["timestamps"].extend(session["timestamps"])
            self.stats["batchsize"].append(session["batchsize"])
            self.add_loss(session["loss"])

    def add_loss(self, session_loss):
        """ Add loss vals to each of their respective lists """
        for idx, loss in enumerate(session_loss):
            self.stats["loss"][idx].extend(loss)


class SessionsSummary():
    """ Calculations for analysis summary stats """
    def __init__(self, raw_data):
        self.summary = list()
        self.summary_stats_compile(raw_data)

    def summary_stats_compile(self, raw_data):
        """ Compile per-session summaries plus a combined totals row,
            then format the lot for display """
        raw_summaries = list()
        for idx, session in enumerate(raw_data):
            raw_summaries.append(self.summarise_session(idx, session))

        totals_summary = self.summarise_totals(raw_summaries)
        raw_summaries.append(totals_summary)

        self.format_summaries(raw_summaries)

    # Compile Session Summaries
    @staticmethod
    def summarise_session(idx, session):
        """ Compile stats for session passed in """
        starttime = session["timestamps"][0]
        endtime = session["timestamps"][-1]
        elapsed = endtime - starttime
        # Bump elapsed to 0.1s if no time is recorded
        # to hack around div by zero error
        elapsed = 0.1 if elapsed == 0 else elapsed
        rate = (session["batchsize"] * session["iterations"]) / elapsed
        return {"session": idx + 1,
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate,
                "batch": session["batchsize"],
                "iterations": session["iterations"]}

    @staticmethod
    def summarise_totals(raw_summaries):
        """ Compile the stats for all sessions combined.

            Assumes ``raw_summaries`` is non-empty and chronologically
            ordered: start comes from the first entry, end from the last """
        elapsed = 0
        rate = 0
        batchset = set()
        iterations = 0
        total_summaries = len(raw_summaries)

        for idx, summary in enumerate(raw_summaries):
            if idx == 0:
                starttime = summary["start"]
            if idx == total_summaries - 1:
                endtime = summary["end"]
            elapsed += summary["elapsed"]
            rate += summary["rate"]
            batchset.add(summary["batch"])
            iterations += summary["iterations"]

        batch = ",".join(str(bs) for bs in batchset)

        return {"session": "Total",
                "start": starttime,
                "end": endtime,
                "elapsed": elapsed,
                "rate": rate / total_summaries,
                "batch": batch,
                "iterations": iterations}

    def format_summaries(self, raw_summaries):
        """ Format the summaries nicely for display """
        for summary in raw_summaries:
            summary["start"] = time.strftime("%x %X",
                                             time.gmtime(summary["start"]))
            summary["end"] = time.strftime("%x %X",
                                           time.gmtime(summary["end"]))
            summary["elapsed"] = time.strftime("%H:%M:%S",
                                               time.gmtime(summary["elapsed"]))
            summary["rate"] = "{0:.1f}".format(summary["rate"])
        self.summary = raw_summaries


class Calculations():
    """ Class to hold calculations against raw session data.

        session: a stats dict from CurrentSession or SessionsTotals
        display: "loss" to plot every loss key, or a single metric name
        selections: calculations to perform ("raw", "avg", "trend", ...);
                    defaults to ["raw"]
        avg_samples: window size for the rolling average
        flatten_outliers: replace values > 1 std-dev from mean with the mean
        is_totals: True when ``session`` comes from SessionsTotals """
    def __init__(self,
                 session,
                 display="loss",
                 selections=None,
                 avg_samples=10,
                 flatten_outliers=False,
                 is_totals=False):

        warnings.simplefilter("ignore", np.RankWarning)

        # BUGFIX: the original default was the mutable literal ["raw"],
        # shared between every instance. None is the safe sentinel.
        selections = ["raw"] if selections is None else selections

        self.session = session
        if display.lower() == "loss":
            display = self.session["losskeys"]
        else:
            display = [display]
        self.args = {"display": display,
                     "selections": selections,
                     "avg_samples": int(avg_samples),
                     "flatten_outliers": flatten_outliers,
                     "is_totals": is_totals}
        self.iterations = 0
        self.stats = None
        self.refresh()

    def refresh(self):
        """ Refresh the stats """
        self.iterations = 0
        self.stats = self.get_raw()
        self.get_calculations()
        self.remove_raw()

    def get_raw(self):
        """ Add raw data to stats dict """
        raw = dict()
        for idx, item in enumerate(self.args["display"]):
            if item.lower() == "rate":
                data = self.calc_rate(self.session)
            else:
                data = self.session["loss"][idx][:]
            if self.args["flatten_outliers"]:
                data = self.flatten_outliers(data)
            if self.iterations == 0:
                self.iterations = len(data)
            raw["raw_{}".format(item)] = data
        return raw

    def remove_raw(self):
        """ Remove raw values from stats if not requested """
        if "raw" in self.args["selections"]:
            return
        for key in list(self.stats.keys()):
            if key.startswith("raw"):
                del self.stats[key]

    def calc_rate(self, data):
        """ Calculate rate per iteration.

            NB: For totals, gaps between sessions can be large so time
            difference has to be reset for each session's rate
            calculation """
        batchsize = data["batchsize"]
        if self.args["is_totals"]:
            split = data["split"]
        else:
            batchsize = [batchsize]
            split = [len(data["timestamps"])]

        prev_split = 0
        rate = list()

        for idx, current_split in enumerate(split):
            prev_time = data["timestamps"][prev_split]
            timestamp_chunk = data["timestamps"][prev_split:current_split]
            for item in timestamp_chunk:
                current_time = item
                timediff = current_time - prev_time
                iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff
                rate.append(iter_rate)
                prev_time = current_time
            prev_split = current_split

        if self.args["flatten_outliers"]:
            rate = self.flatten_outliers(rate)
        return rate

    @staticmethod
    def flatten_outliers(data):
        """ Remove the outliers from a provided list.

            Any value more than one (population) standard deviation from
            the mean is replaced with the mean """
        retdata = list()
        samples = len(data)
        mean = (sum(data) / samples)
        limit = sqrt(sum((item - mean) ** 2 for item in data) / samples)

        for item in data:
            if (mean - limit) <= item <= (mean + limit):
                retdata.append(item)
            else:
                retdata.append(mean)
        return retdata

    def get_calculations(self):
        """ Perform the required calculations """
        for selection in self.get_selections():
            if selection[0] == "raw":
                continue
            method = getattr(self, "calc_{}".format(selection[0]))
            key = "{}_{}".format(selection[0], selection[1])
            raw = self.stats["raw_{}".format(selection[1])]
            self.stats[key] = method(raw)

    def get_selections(self):
        """ Compile a list of data to be calculated """
        for summary in self.args["selections"]:
            for item in self.args["display"]:
                yield summary, item

    def calc_avg(self, data):
        """ Calculate rolling average.

            Edge positions where a full window does not fit are padded
            with None so the plot stays aligned with the raw data """
        avgs = list()
        presample = ceil(self.args["avg_samples"] / 2)
        postsample = self.args["avg_samples"] - presample
        datapoints = len(data)

        if datapoints <= (self.args["avg_samples"] * 2):
            print("Not enough data to compile rolling average")
            return avgs

        for idx in range(datapoints):
            if idx < presample or idx >= datapoints - postsample:
                avgs.append(None)
            else:
                avg = sum(data[idx - presample:idx + postsample]) \
                      / self.args["avg_samples"]
                avgs.append(avg)
        return avgs

    @staticmethod
    def calc_trend(data):
        """ Compile trend data.

            Fits a cubic polynomial; returns a list of None when there
            are too few points to fit meaningfully """
        points = len(data)
        if points < 10:
            dummy = [None] * points
            return dummy
        x_range = range(points)
        fit = np.polyfit(x_range, data, 3)
        poly = np.poly1d(fit)
        trend = poly(x_range)
        return trend