为什么这个tensorflow循环需要这么多的内存?

问题描述:

我有一个复杂的网络的一个人为的版本:为什么这个tensorflow循环需要这么多的内存?

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

for i in range(int(1e6)): 
    a = a * b 

我的直觉是,这应该需要很少的内存。只是初始数组分配的空间和一系列利用节点的命令并在每一步中覆盖存储在张量'a'中的内存。但内存使用增长相当迅速。

这里发生了什么,当我计算张量并覆盖很多次时,如何减少内存使用量?

编辑:

由于雅罗斯拉夫的建议解决方案横空出世使用while_loop以尽量减少对图中的节点的数量要。这样做效果很好,速度更快,所需内存更少,并且全部包含在图表中。

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

cond = lambda _i, _1, _2: tf.less(_i, int(1e6)) 
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b] 

i = tf.constant(0) 
output = tf.while_loop(cond, body, [i, a, b]) 

with tf.Session() as sess: 
    result = sess.run(output) 
    print(result) 

a*b命令转换为tf.mul(a, b),这相当于tf.mul(a, b, g=tf.get_default_graph())。此命令将Mul节点添加到当前的Graph对象,因此您正尝试将100万个Mul节点添加到当前图形。这也是有问题的,因为你不能序列化大于2GB的Graph对象,所以一旦你处理这么大的图,有一些检查可能会失败。

我建议阅读Programming Models for Deep Learning MXNet人。 TensorFlow是其术语中的“象征性”编程,您将它视为必不可少的。

为了得到你想要使用Python循环中,您可以构建乘法运算一次,反复运行它的东西,用feed_dict养活更新

mul_op = a*b 
result = sess.run(a) 
for i in range(int(1e6)): 
    result = sess.run(mul_op, feed_dict={a: result}) 

为了更高的效率,你可以使用tf.Variable对象和var.assign避免Python的< - > TensorFlow数据传输

+0

这很有道理。我只想调用sess.run一次,并且一步处理从输入到输出的所有计算,因为我正在反向传播大图。是否可以在不添加额外节点的情况下执行此循环? – jstaker7

+0

如果您在'tf.Variable'对象中将输入保存到* b中,您将隔离它的依赖关系,因此您可以在该节点上执行'sess.run一百万次,而无需评估图中的任何其他内容 –

+0

不确定我很了解。假设'b'是一个可训练的权重矩阵,在back-prop期间得到更新。如果我多次调用sess.run,不但会招致多次sess.run调用的额外开销,而且会将该计算与梯度计算断开连接,并需要做一些繁琐的工作以确保其得到正确更新。这些假设是否正确?我想我希望有更好的方法来处理这个图表。 – jstaker7