Tensorflow:不要更新如果渐变是南
问题描述:
我有一个深的模型来训练CIFAR-10。培训可以在CPU中正常工作。但是,当我使用GPU支持时,它会导致某些批次的梯度变为NaN(我使用tf.check_numerics
进行了检查),并且它随机发生,但足够早。我相信这个问题与我的GPU有关。Tensorflow:不要更新如果渐变是南
我的问题是:如果至少有一个梯度具有NaN并强制模型进入下一批,那么是否有更新?
编辑:或许我应该详细说明我的问题。
这是我如何申请梯度:
with tf.control_dependencies([tf.check_numerics(grad, message='Gradient %s check failed, possible NaNs' % var.name) for grad, var in grads]):
# Apply the gradients to adjust the shared variables.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
我曾经想过用tf.check_numerics
首先要验证有提示NaN的梯度,并且,然后,如果有NaN的(检查失败)我可以“通过”而不使用opt.apply_gradients
。但是,有没有办法在tf.control_dependencies
上发现错误?
答
我可以弄明白,虽然不是最优雅的方式。 我的解决方案如下: 1)首先检查所有梯度 2)如果梯度不含NaNs,则应用它们3)否则,应用伪更新(使用零值),这需要渐变覆盖。
这是我的代码:
首先定义自定义梯度:
@tf.RegisterGradient("ZeroGrad")
def _zero_grad(unused_op, grad):
return tf.zeros_like(grad)
然后定义异常处理功能:
#this is added for gradient check of NaNs
def check_numerics_with_exception(grad, var):
try:
tf.check_numerics(grad, message='Gradient %s check failed, possible NaNs' % var.name)
except:
return tf.constant(False, shape=())
else:
return tf.constant(True, shape=())
然后创造条件节点:
num_nans_grads = tf.Variable(1.0, name='num_nans_grads')
check_all_numeric_op = tf.reduce_sum(tf.cast(tf.stack([tf.logical_not(check_numerics_with_exception(grad, var)) for grad, var in grads]), dtype=tf.float32))
with tf.control_dependencies([tf.assign(num_nans_grads, check_all_numeric_op)]):
# Apply the gradients to adjust the shared variables.
def fn_true_apply_grad(grads, global_step):
apply_gradients_true = opt.apply_gradients(grads, global_step=global_step)
return apply_gradients_true
def fn_false_ignore_grad(grads, global_step):
#print('batch update ignored due to nans, fake update is applied')
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "ZeroGrad"}):
for (grad, var) in grads:
tf.assign(var, tf.identity(var, name="Identity"))
apply_gradients_false = opt.apply_gradients(grads, global_step=global_step)
return apply_gradients_false
apply_gradient_op = tf.cond(tf.equal(num_nans_grads, 0.), lambda : fn_true_apply_grad(grads, global_step), lambda : fn_false_ignore_grad(grads, global_step))