-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
Copy pathoptimizer.py
executable file
·214 lines (169 loc) · 9.75 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helper wrapper for a Tensorflow optimizer."""
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import List, Union
from . import autosummary
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
# Resolve the NCCL ops module, whose import path depends on the TensorFlow version.
# Only ImportError triggers the fallback, so that unrelated failures (and
# KeyboardInterrupt/SystemExit) are not silently swallowed.
try:
    # TensorFlow 1.13
    from tensorflow.python.ops import nccl_ops
except ImportError:
    # Older TensorFlow versions
    import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
    """Convenience layer on top of a tf.train.Optimizer class.

    Handles the following on behalf of the caller:
    - Averaging of gradients over all registered losses (multi-GPU training).
    - Optional dynamic loss scaling, with gradients cast to FP32.
    - Skipping the weight update when any gradient contains NaN/Inf.
    - Reporting statistics through autosummary.
    - Reasonable default settings.

    Usage: call register_gradients() once per GPU, then apply_updates() exactly once.
    """

    def __init__(self,
                 name: str = "Train",
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 learning_rate: TfExpressionEx = 0.001,
                 use_loss_scaling: bool = False,
                 loss_scaling_init: float = 64.0,
                 loss_scaling_inc: float = 0.0005,
                 loss_scaling_dec: float = 1.0,
                 **kwargs):
        """Store the configuration; no optimizer instance is created yet.

        Args:
            name: Human-readable name; "/" is replaced by "." to form the id used in scopes and summaries.
            tf_optimizer: Name of the optimizer class, resolved with util.get_obj_by_name().
            learning_rate: Learning rate, converted to a tensor.
            use_loss_scaling: Enable dynamic loss scaling.
            loss_scaling_init: Initial value of the per-device log2 loss scaling variable.
            loss_scaling_inc: Amount added to that variable after an update with finite gradients.
            loss_scaling_dec: Amount subtracted from that variable after an overflow.
            **kwargs: Extra keyword arguments forwarded to the optimizer class constructor.
        """
        # Naming.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)

        # Underlying optimizer; instantiated lazily, once per device.
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)

        # Dynamic loss scaling settings.
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec

        # Internal bookkeeping.
        self._grad_shapes = None          # [shape, ...]
        self._dev_opt = OrderedDict()     # device => optimizer
        self._dev_grads = OrderedDict()   # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Compute and record the gradients of `loss` with respect to `trainable_vars`.

        Meant to be called once per GPU, before apply_updates(). The loss and all
        variables must live on the same device, and every call must pass variables
        with the same shapes as the first call.
        """
        assert not self._updates_applied

        # A dict of variables (e.g. Network.trainables) is accepted in place of a list.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        # The first call fixes the reference shapes; later calls are checked against them.
        if self._grad_shapes is None:
            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]

        assert len(trainable_vars) == len(self._grad_shapes)
        for var, expected_shape in zip(trainable_vars, self._grad_shapes):
            assert tfutil.shape_to_list(var.shape) == expected_shape

        device = loss.device
        assert all(var.device == device for var in trainable_vars)

        with tf.name_scope(self.id + "_grad"), tf.device(device):
            # Create the optimizer for this device the first time the device is seen.
            if device not in self._dev_opt:
                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[device] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[device] = []
            optimizer = self._dev_opt[device]

            scaled_loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))

            # Gating is disabled to reduce memory usage.
            raw_pairs = optimizer.compute_gradients(scaled_loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)

            # Disconnected gradients come back as None; substitute zeros.
            grad_var_pairs = []
            for grad, var in raw_pairs:
                if grad is None:
                    grad = tf.zeros_like(var)
                grad_var_pairs.append((grad, var))

            self._dev_grads[device].append(grad_var_pairs)

    def apply_updates(self) -> tf.Operation:
        """Build and return the training op that applies all registered gradients.

        May only be called once. The returned op is a tf.group named "TrainingOp".
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True

        devices = list(self._dev_grads)
        total_grads = sum(map(len, self._dev_grads.values()))
        assert len(devices) >= 1 and total_grads >= 1

        train_ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Step 1: cast to FP32 and add up the gradients registered on each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]
            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    partial_sums = []
                    for per_call in zip(*self._dev_grads[dev]):
                        shared_var = per_call[0][1]
                        assert all(v is shared_var for g, v in per_call)
                        fp32_grads = [tf.cast(g, tf.float32) for g, v in per_call]
                        if len(fp32_grads) == 1:
                            summed = fp32_grads[0]
                        else:
                            summed = tf.add_n(fp32_grads)
                        partial_sums.append((summed, shared_var))
                    dev_grads[dev] = partial_sums

            # Step 2: sum the per-device results across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        per_device = [dev_grads[dev][var_idx][0] for dev in devices]
                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            per_device = nccl_ops.all_sum(per_device)
                        for dev, reduced in zip(devices, per_device):
                            var = dev_grads[dev][var_idx][1]
                            dev_grads[dev][var_idx] = (reduced, var)

            # Step 3: apply the updates on each device separately.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Average over all registered gradients and undo the loss scaling.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # A single boolean telling whether every gradient is finite.
                    with tf.name_scope("CheckOverflow"):
                        finite_flags = [tf.reduce_all(tf.is_finite(g)) for g, v in grads]
                        grad_ok = tf.reduce_all(tf.stack(finite_flags))

                    # Update the weights only when the gradients are finite; adjust loss scaling.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)

                        if self.use_loss_scaling:
                            update_op = tf.cond(
                                grad_ok,
                                lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))
                        else:
                            update_op = tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)
                        train_ops.append(update_op)

                    # Statistics are reported from the last device only.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            train_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                            train_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
                            if self.use_loss_scaling:
                                train_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))

            # Initialize optimizer and loss scaling variables, then bundle everything into one op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*train_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Re-run the initializers of every variable owned by the per-device optimizers."""
        tfutil.assert_tf_initialized()
        initializers = []
        for opt in self._dev_opt.values():
            for var in opt.variables():
                initializers.append(var.initializer)
        tfutil.run(initializers)

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Return the log2 loss scaling variable for `device`, creating it on first use.

        Returns None when loss scaling is disabled.
        """
        if not self.use_loss_scaling:
            return None

        ls_var = self._dev_ls_var.get(device)
        if ls_var is None:
            # Create the variable under an absolute scope and outside any control dependencies.
            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
                ls_var = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
            self._dev_ls_var[device] = ls_var
        return ls_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Multiply `value` by the current loss scaling factor; a no-op when scaling is disabled."""
        assert tfutil.is_tf_expression(value)
        if self.use_loss_scaling:
            log2_scale = self.get_loss_scaling_var(value.device)
            return value * tfutil.exp2(log2_scale)
        return value

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Divide `value` by the current loss scaling factor; a no-op when scaling is disabled."""
        assert tfutil.is_tf_expression(value)
        if self.use_loss_scaling:
            log2_scale = self.get_loss_scaling_var(value.device)
            return value * tfutil.exp2(-log2_scale)  # pylint: disable=invalid-unary-operand-type
        return value