tensorflow源码解读 循环网络
用tensorflow搭建循环神经网络时,需要调用tf.nn.rnn_cell模块下的api,虽然写在nn包里,但总感觉rnn模块像一个黑盒,完全不像nn模块下的风格,通过源码可以发现,rnn模块本质上是一种实现了Layers的类,放在tf.Layers模块下感觉更好理解。
上图是RNN模块的UML图,Layers是所有类的基类,所以rnn是一种layers,在第一次调用__call__调用build方法创建变量,计算逻辑都包含在call函数中。RNN模块采用装饰者模式构建,底下是四大基本类型,以Wrapper是装饰类型。注意MutiRNNCell装饰的是一个RnnCell列表。以上所有的类都可以看做是一种RnnCell类,RnnCell类每调用一次__call__()方法,就会在timestamp维度进行一次计算,若想执行Rnn操作需要循环执行__call__()方法。
为了方便,tensorflow提供了一些自动在timestamp上循环执行__call__()的函数:
cell:任意一种RnnCell
inputs:RnnCell需要的输入格式
initial_state:初始化时可以调用RnnCell.zero_state()方法
其他函数也是这种调用方式