Tensorflow:为什么tf.case给我错误的结果?
问题描述:
我试图使用tf.case
(https://www.tensorflow.org/api_docs/python/tf/case)来有条件地更新张量。如图所示,我试图在global_step == 2
时将learning_rate
更新为0.01
,并且在global_step == 4
时将0.001
更新为0.01
。Tensorflow:为什么tf.case给我错误的结果?
但是,当global_step == 2
,我已经得到learning_rate = 0.001
。经过进一步检查,看起来tf.case
给我错误的结果global_step == 2
(我得到0.001
而不是0.01
)。即使0.01
的谓词正在评估为True,并且0.001
的谓词正在评估为False,也会发生这种情况。
我做错了什么,或者这是一个错误?
TF版本:1.0.0
代码:
import tensorflow as tf
global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')
# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
pred = tf.equal(global_step, step)
fn_tensor = tf.constant(new_rate, dtype=tf.float32)
cases.append((pred, lambda: fn_tensor))
case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)
print tf.__version__
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in xrange(6):
print sess.run([global_step, case_tensors, update, updated_learning_rate])
sess.run(train_op)
结果:
1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
答
这是在回答https://github.com/tensorflow/tensorflow/issues/8776
事实证明,tf.case
行为我如果在fn_tensors
中,lambda表达式返回一个在lambda之外创建的张量,那么它就是undefined。正确的用法是定义lambda表达式,使它们返回一个新创建的张量。
根据连锁Github的问题,这种用法是必需的,因为tf.case
必须以挂钩张量的输入谓词的正确的分支创建张量本身。