如何迭代tensorflow中的张量?
问题描述:
我收到以下错误:如何迭代tensorflow中的张量?
TypeError: 'Tensor' object is not iterable.
我试图用一个占位符和FIFOQueue
养活数据。但是这里的问题是我无法批量处理数据。任何人都可以提供解决方案吗?
我是TensorFlow中的新成员,混淆了占位符和张量的概念。
下面是代码:
#-*- coding:utf-8 -*-
import tensorflow as tf
import sys
q = tf.FIFOQueue(1000,tf.string)
label_ph = tf.placeholder(tf.string,name="label")
enqueue_op = q.enqueue_many(label_ph)
qr = tf.train.QueueRunner(q,enqueue_op)
m = q.dequeue()
sess_conf = tf.ConfigProto()
sess_conf.gpu_options.allow_growth = True
sess = tf.Session(config=sess_conf)
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(coord=coord, sess=sess)
image_batch = tf.train.batch(
m,batch_size=3,
enqueue_many=True,
capacity=9
)
for i in range(0, 10):
print "-------------------------"
#print(sess.run(q.dequeue()))
a = ['a','b','c','a1','b1','c1','a','b','c2','a','b','c3',]
sess.run(enqueue_op,{label_ph:a})
b = sess.run(m)
print b
q.close()
coord.request_stop()
答
我认为你正在运行到同样的问题我。当您运行会话时,您实际上无法访问数据,您可以改为访问数据图。所以你应该像图中的节点那样考虑张量对象,而不是像你可以做的事情那样的大块数据。如果你想对图做些事情,你必须调用tf。*函数,或者在sess.run()调用中获取变量。当你这样做时,tensorflow会找出如何根据它的依赖关系获取数据并运行计算。
至于你的问题,看看这个页面上的QueueRunner例子。 https://www.tensorflow.org/programmers_guide/threading_and_queues
另一种方法,你可以做到这一点(这是我切换到)是你可以洗牌你的数据在CPU上,然后一次复制它。然后,您可以跟踪自己正在进行的步骤,并获取该步骤的数据。我帮助保持gpu数据并减少内存副本。
all_shape = [num_batches, batch_size, data_len]
local_shape = [batch_size, data_len]
## declare the data variables
model.all_data = tf.Variable(tf.zeros(all_shape), dtype=tf.float32)
model.step_data=tf.Variable(tf.zeros(local_shape), dtype=tf.float32)
model.step = tf.Variable(0, dtype=tf.float32, trainable=False, name="step")
## then for each step in the epoch grab the data
index = tf.to_int32(model.step)
model.step_data = model.all_data[index]
## inc the step counter
tf.assign_add(model.step, 1.0, name="inc_counter")
我需要使用批处理。那么你能否提供批量解决方案? – JerryWind
我上面的代码是一般的想法。您需要[num_batches,batch_size,data_len]的3D张量,然后为每个批次抓取所需的切片。 – ReverseFall