faceswap/lib/model/memory_saving_gradients.py
torzdf dba7d4162d: VRAM Improvement Options (#671)
* Implement ping-pong training
* Disable tensorboard for pingpong training
* Implement Memory Saving Gradients
2019-03-17 09:25:39 +00:00


#!/usr/bin/env python3
""" Memory saving gradients.
Adapted from: https://github.com/openai/gradient-checkpointing
The MIT License
Copyright (c) 2018 OpenAI (http://openai.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import contextlib
import logging
import time
import sys
import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge # pylint: disable=no-name-in-module
from toposort import toposort
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
sys.setrecursionlimit(10000)
# refers back to current module if we decide to split helpers out
util = sys.modules[__name__]
# getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated"
setattr(tf.GraphKeys, "VARIABLES", "variables")
# save original gradients since tf.gradient could be monkey-patched to point
# to our version
from tensorflow.python.ops import gradients as tf_grads_lib # pylint: disable=no-name-in-module
tf_gradients = tf_grads_lib.gradients
MIN_CHECKPOINT_NODE_SIZE = 1024 # use lower value during testing
# specific versions we can use to do process-wide replacement of tf.gradients
def gradients_speed(ys, xs, grad_ys=None, **kwargs):
return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs)
def gradients_memory(ys, xs, grad_ys=None, **kwargs):
return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs)
def gradients_collection(ys, xs, grad_ys=None, **kwargs):
return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs)
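# Example (an illustrative sketch, not part of the upstream module): code that calls
# tf.gradients can be pointed at the memory-saving version by monkey-patching it
# process-wide before the graph is built, e.g.:
#
#     import tensorflow as tf
#     from lib.model import memory_saving_gradients
#     tf.__dict__["gradients"] = memory_saving_gradients.gradients_memory
#
# The "lib.model" import path is an assumption based on this file's location within
# faceswap; adjust it to wherever the module actually lives.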
def gradients(ys, xs,  # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    Memory-efficient gradient implementation inspired by "Training Deep Nets with Sublinear
    Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174).

    ys, xs, grad_ys and kwargs are the arguments to the standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list of tensors from the forward pass of the neural net that should be re-used
          when calculating the gradients in the backward pass. All other tensors that do
          not appear in this list will be re-computed.
        - a string specifying how this list should be determined. Currently we support:
            - 'speed': checkpoint all outputs of convolutions and matmuls. These ops are
              usually the most expensive, so checkpointing them maximizes the running speed
              (this is a good option if nonlinearities, concats, batchnorms, etc. are
              taking up a lot of memory).
            - 'memory': try to minimize the memory usage (currently using a very simple
              strategy that identifies a number of bottleneck tensors in the graph to
              checkpoint).
            - 'collection': look for a tensorflow collection named 'checkpoints', which
              holds the tensors to checkpoint. See the example below.
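
    Example of the 'collection' strategy (an illustrative sketch only; 'loss' and
    'bottleneck' stand for hypothetical tensors in the caller's own forward graph):

        tf.add_to_collection('checkpoints', bottleneck)
        grads = gradients(loss, tf.trainable_variables(), checkpoints='collection')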
'''
# print("Calling memsaving gradients with", checkpoints)
if not isinstance(ys, list):
ys = [ys]
if not isinstance(xs, list):
xs = [xs]
bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
inclusive=True)
debug_print("bwd_ops: {}".format(bwd_ops))
# forward ops are all ops that are candidates for recomputation
fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
inclusive=True,
within_ops=bwd_ops)
debug_print("fwd_ops: {}".format(fwd_ops))
# exclude ops with no inputs
fwd_ops = [op for op in fwd_ops if op.inputs]
# don't recompute xs, remove variables
xs_ops = _to_ops(xs)
fwd_ops = [op for op in fwd_ops if op not in xs_ops]
fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
ts_all = [t for t in ts_all if '/read' not in t.name]
ts_all = set(ts_all) - set(xs) - set(ys)
# construct list of tensors to checkpoint during forward pass, if not
# given as input
if type(checkpoints) is not list:
if checkpoints == 'collection':
checkpoints = tf.get_collection('checkpoints')
elif checkpoints == 'speed':
# checkpoint all expensive ops to maximize running speed
checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')
elif checkpoints == 'memory':
# remove very small tensors and some weird ops
def fixdims(t): # tf.Dimension values are not compatible with int, convert manually
try:
return [int(e if e.value is not None else 64) for e in t]
                except Exception:  # pylint: disable=broad-except
                    return [0]  # unknown shape
ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
ts_all = [t for t in ts_all if 'entropy' not in t.name]
ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
ts_all = [t for t in ts_all if 'Switch' not in t.name]
ts_all = [t for t in ts_all if 'dropout' not in t.name]
# DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
ts_all = [t for t in ts_all if 'Cast' not in t.name]
# filter out all tensors that are inputs of the backward graph
with util.capture_ops() as bwd_ops:
tf_gradients(ys, xs, grad_ys, **kwargs)
bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in the forward graph that are inputs to the bwd graph
ts_filtered = list(set(bwd_inputs).intersection(ts_all))
debug_print("Using tensors {}".format(ts_filtered))
            # try two slightly different ways of getting bottleneck tensors
            # to checkpoint
for ts in [ts_filtered, ts_all]:
# get all bottlenecks in the graph
bottleneck_ts = []
for t in ts:
b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                    # check that there are no shortcuts
b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
bottleneck_ts.append(t) # we have a bottleneck!
else:
debug_print("Rejected bottleneck candidate and ops {}".format(
[t] + list(set(ts_all) - set(b_inp) - set(f_inp))))
# success? or try again without filtering?
if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # enough bottlenecks found!
break
if not bottleneck_ts:
raise Exception('unable to find bottleneck tensors! please provide checkpoint '
'nodes manually, or use checkpoints="speed".')
# sort the bottlenecks
bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]
# save an approximately optimal number ~ sqrt(N)
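            # Worked example of the sqrt(N) heuristic below (illustrative numbers only):
            # with N = 100 filtered tensors and 40 sorted bottlenecks, ceil(sqrt(N)) = 10
            # is less than 40, so step = ceil(40 / 10) = 4 and sorted_bottlenecks[4::4]
            # keeps 9 (roughly sqrt(N)) checkpoints. This is the O(sqrt(N)) memory versus
            # one-extra-forward-pass trade-off described by Chen et al. 2016.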
N = len(ts_filtered)
if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
checkpoints = sorted_bottlenecks
else:
step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
checkpoints = sorted_bottlenecks[step::step]
else:
raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))
checkpoints = list(set(checkpoints).intersection(ts_all))
# at this point automatic selection happened and checkpoints is list of nodes
assert isinstance(checkpoints, list)
debug_print("Checkpoint nodes used: {}".format(checkpoints))
# better error handling of special cases
# xs are already handled as checkpoint nodes, so no need to include them
xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
if xs_intersect_checkpoints:
debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
xs_intersect_checkpoints))
ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
debug_print("ys: {}, checkpoints:{}, intersect: {}".format(
ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no memory benefit while creating
    # new edge cases, so exclude them
if ys_intersect_checkpoints:
debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
format_ops(ys_intersect_checkpoints)))
# remove initial and terminal nodes from checkpoints list if present
checkpoints = list(set(checkpoints) - set(ys) - set(xs))
# check that we have some nodes to checkpoint
if not checkpoints:
        raise Exception('no checkpoint nodes found or given as input!')
# disconnect dependencies between checkpointed tensors
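    # (each checkpoint is wrapped in tf.stop_gradient to create a boundary tensor; the
    # copied subgraphs built below are rerouted to read from these boundaries rather than
    # from the original forward graph, so recomputation can be scheduled independently)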
checkpoints_disconnected = {}
for x in checkpoints:
if x.op and x.op.name is not None:
grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
else:
grad_node = tf.stop_gradient(x)
checkpoints_disconnected[x] = grad_node
# partial derivatives to the checkpointed tensors and xs
ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
stop_at_ts=checkpoints, within_ops=fwd_ops)
debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
debug_print("ops_to_copy = {}".format(ops_to_copy))
debug_print("Processing list {}".format(ys))
_, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
for origin_op, op in info._transformed_ops.items():
op._set_device(origin_op.node_def.device)
copied_ops = info._transformed_ops.values()
debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
ge.reroute_ts(checkpoints_disconnected.values(),
checkpoints_disconnected.keys(),
can_modify=copied_ops)
debug_print("Rewired {} in place of {} restricted to {}".format(
checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))
# get gradients with respect to current boundary + original x's
copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
boundary = list(checkpoints_disconnected.values())
dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
debug_print("Got gradients {}".format(dv))
debug_print("for %s", copied_ys)
debug_print("with respect to {}".format(boundary+xs))
inputs_to_do_before = [y.op for y in ys]
if grad_ys is not None:
inputs_to_do_before += grad_ys
wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
# partial derivatives to the checkpointed nodes
# dictionary of "node: backprop" for nodes in the boundary
d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
dv[:len(checkpoints_disconnected)])}
# partial derivatives to xs (usually the params of the neural net)
d_xs = dv[len(checkpoints_disconnected):]
# incorporate derivatives flowing through the checkpointed nodes
checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
for ts in checkpoints_sorted_lists[::-1]:
debug_print("Processing list {}".format(ts))
checkpoints_other = [r for r in checkpoints if r not in ts]
checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]
        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoint nodes
ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
seed_ops=[r.op for r in ts],
stop_at_ts=checkpoints_other)
debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
debug_print("ops_to_copy = {}".format(ops_to_copy))
if not ops_to_copy: # we're done!
break
_, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
for origin_op, op in info._transformed_ops.items():
op._set_device(origin_op.node_def.device)
copied_ops = info._transformed_ops.values()
debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
debug_print("Rewired %s in place of %s restricted to %s",
checkpoints_disconnected_other, checkpoints_other, copied_ops)
# gradient flowing through the checkpointed node
boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
substitute_backprops = [d_checkpoints[r] for r in ts]
dv = tf_gradients(boundary,
checkpoints_disconnected_other+xs,
grad_ys=substitute_backprops, **kwargs)
debug_print("Got gradients {}".format(dv))
debug_print("for {}".format(boundary))
debug_print("with respect to {}".format(checkpoints_disconnected_other+xs))
debug_print("with boundary backprop substitutions {}".format(substitute_backprops))
inputs_to_do_before = [d_checkpoints[r].op for r in ts]
wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
# partial derivatives to the checkpointed nodes
for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
if dr is not None:
if d_checkpoints[r] is None:
d_checkpoints[r] = dr
else:
d_checkpoints[r] += dr
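        # _unsparsify densifies tf.IndexedSlices gradients (as produced by ops such as
        # tf.gather) so they can be accumulated together with dense gradients below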
def _unsparsify(var_x):
if not isinstance(var_x, tf.IndexedSlices):
return var_x
assert var_x.dense_shape is not None, \
"memory_saving_gradients encountered sparse gradients of unknown shape"
indices = var_x.indices
while indices.shape.ndims < var_x.values.shape.ndims:
indices = tf.expand_dims(indices, -1)
return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)
# partial derivatives to xs (usually the params of the neural net)
d_xs_new = dv[len(checkpoints_other):]
for j in range(len(xs)):
if d_xs_new[j] is not None:
if d_xs[j] is None:
d_xs[j] = _unsparsify(d_xs_new[j])
else:
d_xs[j] += _unsparsify(d_xs_new[j])
return d_xs
def tf_toposort(ts_inp, within_ops=None):
""" Tensorflow topological sort """
all_ops = ge.get_forward_walk_ops([x.op for x in ts_inp], within_ops=within_ops)
deps = {}
for tf_op in all_ops:
for outp in tf_op.outputs:
deps[outp] = set(tf_op.inputs)
sorted_ts = toposort(deps)
# only keep the tensors from our original list
ts_sorted_lists = []
for lst in sorted_ts:
keep = list(set(lst).intersection(ts_inp))
if keep:
ts_sorted_lists.append(keep)
return ts_sorted_lists
def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
""" Fast backward ops """
bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts])
return list(ops)
@contextlib.contextmanager
def capture_ops():
"""Decorator to capture ops created in the block.
with capture_ops() as ops:
# create some ops
print(ops) # => prints ops created.
"""
micros = int(time.time()*10**6)
scope_name = str(micros)
op_list = []
with tf.name_scope(scope_name):
yield op_list
graph = tf.get_default_graph()
op_list.extend(ge.select_ops(scope_name+"/.*", graph=graph))
def _to_op(tensor_or_op):
""" Convert to op """
if hasattr(tensor_or_op, "op"):
return tensor_or_op.op
return tensor_or_op
def _to_ops(iterable):
""" Convert to ops """
if not _is_iterable(iterable):
return iterable
return [_to_op(i) for i in iterable]
def _is_iterable(obj):
""" Check if object is iterable """
try:
_ = iter(obj)
except Exception: # pylint: disable=broad-except
return False
return True
def debug_print(msg, *args):
""" Debug logging """
formatted_args = [format_ops(arg) for arg in args]
logger.debug("%s: %s", msg, formatted_args)
def format_ops(ops, sort_outputs=True):
"""Helper method for printing ops. Converts Tensor/Operation op to op.name,
rest to str(op)."""
if hasattr(ops, '__iter__') and not isinstance(ops, str):
lst = [(op.name if hasattr(op, "name") else str(op)) for op in ops]
if sort_outputs:
return sorted(lst)
return lst
return ops.name if hasattr(ops, "name") else str(ops)
def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before):
""" Add control inputs """
for tf_op in wait_to_do_ops:
ctl_inp = [i for i in inputs_to_do_before
if tf_op.control_inputs is None or i not in tf_op.control_inputs]
ge.add_control_inputs(tf_op, ctl_inp)