Self-study notes: connecting LSTM theory to practice through TensorFlow code, and the data structures of state and output
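I have read several explanations of how LSTM works and have mostly understood the principle, but understanding the code is still some way off.
I found a very simple example online, but it would not run on TensorFlow 1.3. After spending some time fixing the compatibility problems, the code below runs.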
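# Tested on TensorFlow 1.4
# Adapted from an example copied from the web; the original example does not run on TensorFlow 1.3 or later
import tensorflow as tf
import numpy as np

units_num = 10  # number of hidden units in each LSTM layer
batch_size = 2  # number of sequences in a batch
# Random input of shape [batch_size, time_steps, input_dim] = [2, 5, 7]
X = tf.random_normal(shape=[batch_size, 5, 7], dtype=tf.float32)
#X = tf.reshape(X, [-1, 5, 6])

def cell():
    # Must be a function so that every layer gets its own cell object;
    # reusing a single cell object for both layers raises an error in TensorFlow 1.x
    return tf.nn.rnn_cell.BasicLSTMCell(units_num)  # 10 hidden units

# Other cells such as GRUCell or BasicRNNCell can be used instead; their state is a
# plain tensor rather than an LSTMStateTuple, so the .c/.h accesses below would change
lstm_multi = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(2)], state_is_tuple=True)
#mlstm_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)],state_is_tuple=True)
state = lstm_multi.zero_state(batch_size, tf.float32)
output, state = tf.nn.dynamic_rnn(lstm_multi, X, initial_state=state, time_major=False)
# output is batch-major, so output[-1] selects the last batch element (5*10),
# not the last time step; the last time step would be output[:, -1, :] (2*10)
finaloutput = output[-1]
statec = state[-1].c  # cell state c of the last (top) layer
stateh = state[-1].h  # hidden state h of the last (top) layer

with tf.Session() as sess:
    # Deprecated (see the warning in the output); tf.global_variables_initializer() is the replacement
    sess.run(tf.initialize_all_variables())
    # X is re-sampled on every sess.run call, so the values printed by
    # separate calls below come from different random inputs
    print('----------------output--------------------')
    print(sess.run(output))  # 2*5*10
    print('----------------finaloutput--------------------')
    print(sess.run(finaloutput))  # 5*10: all time steps of the last batch element
    print('----------------state--------------------')
    print(sess.run(state))  # 2 layers, each an LSTMStateTuple(c, h) of 2*10 arrays
    print('-----------------statec-------------------')
    print(sess.run(statec))  # 2*10
    print('----------------stateh--------------------')
    print(sess.run(stateh))  # 2*10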
'''
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
----------------output--------------------
[[[-0.00637375 -0.02218931 -0.00364614 0.00871621 0.01000606 0.00688382
0.0173985 -0.01417144 -0.00514521 0.00809157]
[-0.02739093 0.00731545 -0.01414015 0.00776639 0.01201528 -0.00530205
-0.00373759 -0.00243396 0.01967375 0.0176452 ]
[-0.05032041 0.01292216 -0.01505091 0.00758813 0.00881143 0.0026364
0.00501611 -0.01001035 0.01516667 0.03455975]
[-0.03608039 -0.01159106 -0.00147149 0.01253749 0.01457058 0.00761829
0.0266578 -0.0391868 -0.01949262 0.03059825]
[-0.01878959 -0.0512354 -0.00916542 0.0181381 0.01850957 -0.00772178
0.05093804 -0.04717548 -0.03797766 0.00213585]]
[[-0.00301182 -0.04609392 -0.04139089 -0.00270795 -0.01418001 -0.02598612
0.03196315 0.01515245 0.00444449 -0.02480585]
[ 0.0216604 -0.07657862 -0.0436976 -0.00264414 -0.01624844 -0.0388252
0.04456366 -0.00641088 -0.00966516 -0.02996463]
[ 0.03529613 -0.06516693 -0.038265 -0.0112583 -0.01810452 -0.03949273
0.0325823 -0.00466524 -0.00729331 -0.03314594]
[ 0.05150301 -0.08670978 -0.06885949 -0.02202794 -0.05140235 -0.05352144
0.03970223 0.02086095 -0.00696997 -0.06665371]
[ 0.05955734 -0.06664775 -0.06081596 0.00371452 -0.03970711 -0.02188079
0.03479016 0.03526397 0.00916961 -0.06597719]]]
----------------finaloutput--------------------
[[-0.04641263 0.01628485 -0.03593837 -0.01126753 -0.02915012 -0.00858944
-0.00375725 0.03282307 0.01455755 -0.00436677]
[-0.07043284 0.02731676 -0.04114634 -0.02013678 -0.04268811 -0.02129938
0.00629975 0.02324007 0.00693215 0.01216679]
[-0.07693886 0.01834853 -0.04679859 -0.03363486 -0.0439391 -0.06869462
0.00808731 0.0087936 0.00530997 0.02094404]
[-0.05964042 -0.02323203 -0.01288006 -0.02311298 -0.00924513 -0.1003315
0.02314057 -0.02797622 -0.01521164 0.03499829]
[-0.05991254 -0.04570793 -0.02727396 -0.02612418 -0.00901282 -0.12065098
0.02534271 -0.00855778 -0.00444907 0.01586689]]
----------------state--------------------
(LSTMStateTuple(c=array([[-0.66914773, 0.07839034, 0.33994618, 0.8456924 , 0.10953232,
-0.39234626, 0.13399592, -0.15174234, 0.00130345, -0.30228698],
[-0.38686785, -0.13705038, -0.1424486 , 0.0899242 , 0.24239531,
0.08536939, -0.27032876, -0.24645516, -0.00084634, 0.18780154]], dtype=float32), h=array([[ -2.64760315e-01, 4.65154909e-02, 1.34974703e-01,
3.55697572e-01, 3.73032168e-02, -2.32693255e-01,
7.82607347e-02, -6.97162896e-02, 3.86156840e-04,
-1.59545243e-01],
[ -1.43306926e-01, -7.28544220e-02, -6.16334565e-02,
4.14225310e-02, 1.12995833e-01, 4.02112342e-02,
-1.64336443e-01, -1.15731291e-01, -2.68079224e-04,
8.61014202e-02]], dtype=float32)), LSTMStateTuple(c=array([[-0.12525879, -0.02339883, -0.15596688, 0.09201161, -0.13064483,
0.1607396 , 0.14479131, 0.06554902, -0.0195029 , -0.08419077],
[-0.02751132, 0.08955061, 0.07464057, 0.07965946, 0.10505884,
0.16445892, -0.04261293, -0.04892696, 0.03463147, 0.07745198]], dtype=float32), h=array([[-0.06212955, -0.01064036, -0.08202581, 0.04435679, -0.06457365,
0.07770626, 0.0736967 , 0.03266109, -0.00938224, -0.04352244],
[-0.01347394, 0.04540392, 0.03609779, 0.04107576, 0.05197879,
0.08114899, -0.02111933, -0.02465794, 0.01692747, 0.03863694]], dtype=float32)))
-----------------statec-------------------
[[-0.14902619 0.21614444 -0.05753845 -0.07061018 -0.17564972 -0.08847831
-0.07436903 0.09744829 0.07963906 0.06430003]
[ 0.09381064 -0.0832063 0.00594144 0.05869648 0.03069258 0.02718895
0.00898496 -0.03545038 0.04609191 0.00290151]]
----------------stateh--------------------
[[-0.00620752 0.03671843 0.07395449 0.00736501 0.0685369 0.03242806
-0.03287118 -0.08238971 -0.01176278 0.08281735]
[ 0.04372304 -0.07892949 -0.01827243 -0.02095079 -0.01729526 -0.05310233
0.02035648 -0.05144129 -0.0138404 -0.01711897]]
'''
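Summary:
The number of hidden units is 10, so an input of shape 2*5*7 (batch size * time steps * input features) becomes an output of shape 2*5*10.
output holds the top layer's hidden state at every time step for the whole batch: 2*5*10.
state also covers the whole batch but only at the final time step: it contains one LSTMStateTuple per layer, and each c and h array in it is 2*10 (batch size * hidden units).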
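To check these shapes without TensorFlow, here is a minimal NumPy sketch of a single LSTM layer. It is not how BasicLSTMCell is implemented: the helper names (`sigmoid`, `lstm_forward`), the weight initialisation and the missing forget bias are simplifying assumptions for illustration. It only shows how a [batch, time, features] input turns into a [batch, time, units] output plus a final (c, h) pair.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, units, rng):
    """Run one LSTM layer over batch-major input x of shape [batch, time, features]."""
    batch, steps, features = x.shape
    # One weight matrix holding the four gates (input, candidate, forget, output).
    w = rng.normal(scale=0.1, size=(features + units, 4 * units))
    b = np.zeros(4 * units)
    c = np.zeros((batch, units))  # cell state, starts at zero like zero_state
    h = np.zeros((batch, units))  # hidden state, starts at zero
    outputs = []
    for t in range(steps):
        # Concatenate the input at step t with the previous hidden state.
        z = np.concatenate([x[:, t, :], h], axis=1) @ w + b
        i, g, f, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # new cell state
        h = sigmoid(o) * np.tanh(c)                   # new hidden state
        outputs.append(h)                             # the output at step t is h
    return np.stack(outputs, axis=1), (c, h)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 7))                # batch=2, time steps=5, features=7
output, (c, h) = lstm_forward(x, units=10, rng=rng)

print(output.shape)                           # (2, 5, 10)
print(c.shape, h.shape)                       # (2, 10) (2, 10)
print(np.allclose(output[:, -1, :], h))       # True: last time step of output is the final h
print(output[-1].shape)                       # (5, 10): last batch element, not last time step
```

In this sketch the last time step of output is the same array as the final h, which is why the final state needs no time dimension.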
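The state of each layer is split into state_c and state_h, which correspond to Ct (the cell state) and Ht (the hidden state) in the usual LSTM diagram.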
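state_c and state_h are each 2*10 arrays. The first dimension is 2 rather than 1 because batch_size is 2: every sequence in the batch keeps its own row of state. The data flow in the diagram is a separate matter and does not change the shape: state_h is used twice, once as the output of the current step and once as the hidden state passed to the next time step, while state_c is passed to the next time step and is also used to compute state_h.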
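The nesting of the state can also be mimicked in plain Python to see where the batch dimension comes from. In the sketch below, `StateTuple` and `zero_state` are hypothetical stand-ins written for these notes, not TensorFlow's LSTMStateTuple or MultiRNNCell.zero_state; they only reproduce the structure: one (c, h) pair per layer, each of shape [batch_size, units].

```python
from collections import namedtuple
import numpy as np

# Hypothetical stand-in for TensorFlow's LSTMStateTuple, used only to show the nesting.
StateTuple = namedtuple("StateTuple", ["c", "h"])

def zero_state(num_layers, batch_size, units):
    """Build one (c, h) pair per layer, each array of shape [batch_size, units]."""
    return tuple(
        StateTuple(c=np.zeros((batch_size, units)), h=np.zeros((batch_size, units)))
        for _ in range(num_layers)
    )

for batch_size in (1, 2, 4):
    state = zero_state(num_layers=2, batch_size=batch_size, units=10)
    # state[-1] is the top layer; its c and h follow the batch size.
    print(batch_size, len(state), state[-1].c.shape, state[-1].h.shape)
```

Changing batch_size changes only the first dimension of c and h, while the number of layers changes the length of the outer tuple.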