Tensorflow:为什么tf.case给我错误的结果?

问题描述:

我试图使用tf.casehttps://www.tensorflow.org/api_docs/python/tf/case)来有条件地更新张量。如图所示,我试图在global_step == 2时将learning_rate更新为0.01,并且在global_step == 4时将0.001更新为0.01Tensorflow:为什么tf.case给我错误的结果?

但是,当global_step == 2,我已经得到learning_rate = 0.001。经过进一步检查,看起来tf.case给我错误的结果global_step == 2(我得到0.001而不是0.01)。即使0.01的谓词正在评估为True,并且0.001的谓词正在评估为False,也会发生这种情况。

我做错了什么,或者这是一个错误?

TF版本:1.0.0

代码:

import tensorflow as tf 

global_step = tf.Variable(0, dtype=tf.int64) 
train_op = tf.assign(global_step, global_step + 1) 
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate') 

# Update the learning_rate tensor conditionally 
# When global_step == 2, update to 0.01 
# When global_step == 4, update to 0.001 
cases = [] 
case_tensors = [] 
for step, new_rate in [(2, 0.01), (4, 0.001)]: 
    pred = tf.equal(global_step, step) 
    fn_tensor = tf.constant(new_rate, dtype=tf.float32) 
    cases.append((pred, lambda: fn_tensor)) 
    case_tensors.append((pred, fn_tensor)) 
update = tf.case(cases, default=lambda: learning_rate) 
updated_learning_rate = tf.assign(learning_rate, update) 

print tf.__version__ 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    for _ in xrange(6): 
     print sess.run([global_step, case_tensors, update, updated_learning_rate]) 
     sess.run(train_op) 

结果:

1.0.0 
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1] 
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1] 
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001] 
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001] 
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001] 
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001] 

这是在回答https://github.com/tensorflow/tensorflow/issues/8776

事实证明,tf.case行为我如果在fn_tensors中,lambda表达式返回一个在lambda之外创建的张量,那么它就是undefined。正确的用法是定义lambda表达式,使它们返回一个新创建的张量。

根据连锁Github的问题,这种用法是必需的,因为tf.case必须以挂钩张量的输入谓词的正确的分支创建张量本身。