如何迭代tensorflow中的张量?

问题描述:

我收到以下错误:如何迭代tensorflow中的张量?

TypeError: 'Tensor' object is not iterable.

我试图用一个占位符和FIFOQueue养活数据。但是这里的问题是我无法批量处理数据。任何人都可以提供解决方案吗?

我是TensorFlow中的新成员,混淆了占位符和张量的概念。

下面是代码:

#-*- coding:utf-8 -*- 
import tensorflow as tf 
import sys 

q = tf.FIFOQueue(1000,tf.string) 
label_ph = tf.placeholder(tf.string,name="label") 
enqueue_op = q.enqueue_many(label_ph) 
qr = tf.train.QueueRunner(q,enqueue_op) 
m = q.dequeue() 

sess_conf = tf.ConfigProto() 
sess_conf.gpu_options.allow_growth = True 
sess = tf.Session(config=sess_conf) 
sess.run(tf.global_variables_initializer()) 
coord = tf.train.Coordinator() 
tf.train.start_queue_runners(coord=coord, sess=sess) 

image_batch = tf.train.batch(
     m,batch_size=3, 
     enqueue_many=True, 
     capacity=9 
     ) 

for i in range(0, 10): 
    print "-------------------------" 
    #print(sess.run(q.dequeue())) 
    a = ['a','b','c','a1','b1','c1','a','b','c2','a','b','c3',] 
    sess.run(enqueue_op,{label_ph:a}) 
    b = sess.run(m) 
    print b 
q.close() 
coord.request_stop() 

我认为你正在运行到同样的问题我。当您运行会话时,您实际上无法访问数据,您可以改为访问数据图。所以你应该像图中的节点那样考虑张量对象,而不是像你可以做的事情那样的大块数据。如果你想对图做些事情,你必须调用tf。*函数,或者在sess.run()调用中获取变量。当你这样做时,tensorflow会找出如何根据它的依赖关系获取数据并运行计算。

至于你的问题,看看这个页面上的QueueRunner例子。 https://www.tensorflow.org/programmers_guide/threading_and_queues

另一种方法,你可以做到这一点(这是我切换到)是你可以洗牌你的数据在CPU上,然后一次复制它。然后,您可以跟踪自己正在进行的步骤,并获取该步骤的数据。我帮助保持gpu数据并减少内存副本。

all_shape = [num_batches, batch_size, data_len] 
    local_shape = [batch_size, data_len] 

    ## declare the data variables 
    model.all_data = tf.Variable(tf.zeros(all_shape), dtype=tf.float32) 
    model.step_data=tf.Variable(tf.zeros(local_shape), dtype=tf.float32) 
    model.step = tf.Variable(0, dtype=tf.float32, trainable=False, name="step") 

    ## then for each step in the epoch grab the data 
    index = tf.to_int32(model.step) 
    model.step_data = model.all_data[index] 

    ## inc the step counter 
    tf.assign_add(model.step, 1.0, name="inc_counter") 
+0

我需要使用批处理。那么你能否提供批量解决方案? – JerryWind

+0

我上面的代码是一般的想法。您需要[num_batches,batch_size,data_len]的3D张量,然后为每个批次抓取所需的切片。 – ReverseFall