Tensorflow中BasicRNNCell的相关参数shape说明
按道理看完RNN的原理之后,我们就应该来用某种框架来实现了。可偏偏在RNN的实现上,对于一个初学者来说Tensorflow的表达总是显得那么生涩难懂,比起CNN那确实是差了一点。比如里面的参数就显示不是那么的友好。num_units
到底指啥?原谅我最开始以为指的是RNN单元的个数。zero_state()
中的参数为啥要指定batch_size
?
1.结论
先回忆一下RNN的计算公式:
output_size = 10
batch_size = 32
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32,shape=[batch_size,150])
h0 = cell.zero_state(batch_size=batch_size,dtype=tf.float32)
output,h1 = cell.call(input,h0)
上面是一个RNN单元最简单的定义形式,可是每个参数又到底是什么含义呢?我们知道一个最基本的RNN单元中有三个可训练的参数,以及两个输入变量。所以我们在构造RNN的时候就需要指定各个参数的维度了。可上面6行代码中,各个参数又是谁跟谁呢? 下图就是直接结果。
结合着上图和代码,可以发现:
第一:第3行代码的num_units=output_size
就告诉我们,最终输出的类别数是output_size
(例如:10个数字的可能性;),以及参数的第二个维度为output_size
;
第二:第4行代码的shape=[batch_size,150]
就告诉了我们余下所有参数的形状;
2.怎么来的
class BasicRNNCell(RNNCell):
"""The most basic RNN cell.
Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
def __init__(self, num_units, activation=None, reuse=None):
super(BasicRNNCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
def _linear(args,
output_size,
bias,
bias_initializer=None,
kernel_initializer=None):
-------此处删除了很多行--------------------------
with vs.variable_scope(outer_scope) as inner_scope:
inner_scope.set_partitioner(None)
if bias_initializer is None:
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
biases = vs.get_variable(
_BIAS_VARIABLE_NAME, [output_size],
dtype=dtype,
initializer=bias_initializer)
return nn_ops.bias_add(res, biases)
从Class BasicRNNCell
的源码第22行可以看出num_units
和output_size
是一回事;从第43行可以看出,output_size
指的是偏置B的维度,只要弄清楚了这两点,其他的就一目了然了
2.验证
2.1 从计算维度来验证
import tensorflow as tf
output_size = 10
batch_size = 32
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
print(cell.output_size)
input = tf.placeholder(dtype=tf.float32,shape=[batch_size,150])
h0 = cell.zero_state(batch_size=batch_size,dtype=tf.float32)
output,h1 = cell.call(input,h0)
print(output)
print(h1)
按照上面的推断各个参数的维度为:input: [32,150]; W: [150,10]; h0: [32,10]; U: [10,10] B: [10]
;所以最终输出的维度就应该为[32,10]
10
Tensor("Tanh:0", shape=(32, 10), dtype=float32)
Tensor("Tanh:0", shape=(32, 10), dtype=float32)
2.1 从计算结果来验证
import tensorflow as tf
import numpy as np
from tensorflow.python.ops import variable_scope as vs
output_size = 4
batch_size = 3
dim = 5
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32, shape=[batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
output, h1 = cell.call(input, h0)
x = np.array([[1, 2, 1, 1, 1], [2, 0, 0, 1, 1], [2, 1, 0, 1, 0]])
scope = vs.get_variable_scope()
with vs.variable_scope(scope,reuse=True) as outer_scope:
weights = vs.get_variable(
"kernel", [9, output_size],
dtype=tf.float32,
initializer= None)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
a,b,w= sess.run([output,h1,weights],feed_dict={input:x})
print('output:')
print(a)
print('h1:')
print(b)
print("weights:")
print(w)# shape = (9,4)
state = np.zeros(shape=(3,4))# shape = (3,4)
all_input = np.concatenate((x,state),axis=1)# shape = (3,9)
result = np.tanh(np.matmul(all_input,w))
print('result:')
print(result)
由以上代码可知:input: [3,5]; W: [5,4]; h0: [3,4]; U: [4,4] B: [4]
;所以最终输出的维度就应该为[3,4]
注:
1.Tensorflow在实现的时候把W,U合并成了一个矩阵,把input和h0也合并成了一个矩阵,所以weight的形状才为(9,4);
2.此处偏置为0;
结果:
>>
weights:
[[ 0.590749 0.31745368 -0.27975678 0.33500886]
[-0.02256793 -0.34533614 -0.09014118 -0.5189797 ]
[-0.24466929 0.17519772 0.20553339 -0.25595042]
[-0.48579523 0.67465043 0.62406075 -0.32061592]
[-0.0713594 0.3825792 0.6132684 0.00536895]
[ 0.43795645 0.55633724 0.31295568 -0.37173718]
[ 0.6170727 0.14996111 -0.321027 -0.624057 ]
[ 0.42747557 0.4424585 -0.59979784 0.23592204]
[-0.0294565 0.3372593 -0.14695019 0.07108325]]
output:
[[-0.2507479 0.69584984 0.7542856 -0.8549179 ]
[ 0.5541449 0.9344188 0.5900975 0.3405997 ]
[ 0.5870382 0.74615407 -0.0255884 -0.16797088]]
h1:
[[-0.2507479 0.69584984 0.7542856 -0.8549179 ]
[ 0.5541449 0.9344188 0.5900975 0.3405997 ]
[ 0.5870382 0.74615407 -0.0255884 -0.16797088]]
result:
[[-0.25074791 0.69584978 0.75428552 -0.85491791]
[ 0.55414493 0.93441886 0.59009744 0.34059968]
[ 0.58703823 0.74615404 -0.02558841 -0.16797091]]
result是通过numpy计算得到的输出值!