TensorFlow RNNCell and Defining a Custom Cell
When reproducing papers we often run into models that modify the RNN or LSTM slightly, or that define their own recurrent structure, for example the memory networks and Quasi-RNN papers introduced earlier; in those cases the recurrent network has to be built according to a custom definition. This post therefore summarizes how RNNCell is used and how to define your own RNNCell to fit your needs.
Much of this post follows the earlier article "tensorflow中RNNcell源码分析以及自定义RNNCell的方法", but that article targets a rather old version; here everything is rewritten against the TensorFlow 1.6.0 code, and the Quasi RNN code is added as an example.
1. How RNNCell is used in TensorFlow
Let's look directly at how TensorFlow defines RNNCell at the source level. The code is as follows:
class RNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`. The optional
third input argument, `scope`, is allowed for backwards compatibility
purposes; but should be left off for new subclasses.
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.state_size`.
"""
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: if `self.state_size` is an integer, this should be a `2-D Tensor`
with shape `[batch_size, self.state_size]`. Otherwise, if
`self.state_size` is a tuple of integers, this should be a tuple
with shapes `[batch_size, s] for s in self.state_size`.
scope: VariableScope for the created subgraph; defaults to class name.
Returns:
A pair containing:
- Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
- New state: Either a single `2-D` tensor, or a tuple of tensors matching
the arity and shapes of `state`.
"""
if scope is not None:
with vs.variable_scope(scope,
custom_getter=self._rnn_get_variable) as scope:
return super(RNNCell, self).__call__(inputs, state, scope=scope)
else:
scope_attrname = "rnncell_scope"
scope = getattr(self, scope_attrname, None)
if scope is None:
scope = vs.variable_scope(vs.get_variable_scope(),
custom_getter=self._rnn_get_variable)
setattr(self, scope_attrname, scope)
with scope:
return super(RNNCell, self).__call__(inputs, state)
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
if context.in_graph_mode():
trainable = (variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables()))
else:
trainable = variable._trainable # pylint: disable=protected-access
if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights:
self._non_trainable_weights.append(variable)
return variable
@property
def state_size(self):
"""size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of Integers
or TensorShapes.
"""
raise NotImplementedError("Abstract method")
@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")
def build(self, _):
# This tells the parent Layer object that it's OK to call
# self.add_variable() inside the call() method.
pass
def zero_state(self, batch_size, dtype):
"""Return zero-filled state tensor(s).
Args:
batch_size: int, float, or unit Tensor representing the batch size.
dtype: the data type to use for the state.
Returns:
If `state_size` is an int or TensorShape, then the return value is a
`N-D` tensor of shape `[batch_size, state_size]` filled with zeros.
If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size, s]` for each s in `state_size`.
"""
# Try to use the last cached zero_state. This is done to avoid recreating
# zeros, especially when eager execution is enabled.
state_size = self.state_size
is_eager = context.in_eager_mode()
if is_eager and hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")
if (last_batch_size == batch_size and
last_dtype == dtype and
last_state_size == state_size):
return last_output
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
output = _zero_state_tensors(state_size, batch_size, dtype)
if is_eager:
self._last_zero_state = (state_size, batch_size, dtype, output)
return output
RNNCell is an abstract base class; every concrete RNN cell inherits from it and implements the call() function. From the definition above we can see that it mainly has two properties, state_size and output_size, which give the size of the hidden state and of the output, respectively. It then has two functions: zero_state(), which initializes the initial state h0 as an all-zero tensor, and call(), which defines the actual per-step operation of the cell (for a plain RNN this is a single activation over a linear transform, the GRU has two gates, the LSTM three gates; the differences between RNN variants live almost entirely in this function). With this abstract class in mind, let's look at how BasicRNNCell, GRUCell and BasicLSTMCell are defined in TensorFlow, to see how the different RNN variants differ and how they are implemented.
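As a quick illustration of these interfaces, a cell can either be stepped manually through __call__ or unrolled with tf.nn.dynamic_rnn. The sketch below uses made-up sizes (batch_size, seq_len, input_dim, num_units) and TensorFlow 1.x graph mode:

import tensorflow as tf

# Made-up sizes, just for illustration.
batch_size, seq_len, input_dim, num_units = 32, 10, 8, 16

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
cell = tf.nn.rnn_cell.BasicRNNCell(num_units)

print(cell.state_size)   # 16: size of the hidden state
print(cell.output_size)  # 16: size of the per-step output

# zero_state builds the all-zero initial state h0.
init_state = cell.zero_state(batch_size, tf.float32)

# dynamic_rnn invokes the cell once per time step along axis 1.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)
# outputs: [batch_size, seq_len, output_size]; final_state: [batch_size, state_size]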
Note that in recent versions of TensorFlow the cell implementations have been refactored so that variable creation and the per-step computation are split between build() and call(). When you add a build() method, remember to set self.built = True at its end; build() is run exactly once, before call() is first executed.
class BasicRNNCell(_LayerRNNCell):
"""The most basic RNN cell.
Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
"""
def __init__(self, num_units, activation=None, reuse=None, name=None):
super(BasicRNNCell, self).__init__(_reuse=reuse, name=name)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
self.built = True
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
output = self._activation(gate_inputs)
return output, output
The most basic RNN structure is shown in the figure above. In BasicRNNCell, state_size and output_size are defined to be the same, and h_t is identical to the output (call returns output twice, i.e. no separate output transform is defined). The core computation is the first lines of call: the input and the previous state are concatenated, passed through a linear map and then an activation, which is exactly the vanilla RNN: output = new_state = f(W * input + U * state + B). A quick numerical check of this equivalence is sketched below, after which we look at the GRU definition.
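The implementation stores W and U as a single kernel of shape [input_depth + num_units, num_units], so concatenating inputs and state before one matmul is algebraically the same as input @ W + state @ U. A small NumPy check of this identity, with made-up sizes:

import numpy as np

# Made-up sizes for a quick check of the concat-then-matmul trick.
input_depth, num_units, batch = 5, 3, 2
x = np.random.randn(batch, input_depth)
h = np.random.randn(batch, num_units)
W = np.random.randn(input_depth, num_units)   # input-to-hidden weights
U = np.random.randn(num_units, num_units)     # hidden-to-hidden weights

kernel = np.concatenate([W, U], axis=0)        # shape [input_depth + num_units, num_units]
lhs = np.concatenate([x, h], axis=1) @ kernel  # what BasicRNNCell.call computes (before bias/tanh)
rhs = x @ W + h @ U                            # the textbook form W * input + U * state

print(np.allclose(lhs, rhs))  # True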
class GRUCell(_LayerRNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
Args:
num_units: int, The number of units in the GRU cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
kernel_initializer: (optional) The initializer to use for the weight and
projection matrices.
bias_initializer: (optional) The initializer to use for the bias.
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
"""
def __init__(self,
num_units,
activation=None,
reuse=None,
kernel_initializer=None,
bias_initializer=None,
name=None):
super(GRUCell, self).__init__(_reuse=reuse, name=name)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
self._gate_kernel = self.add_variable(
"gates/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, 2 * self._num_units],
initializer=self._kernel_initializer)
self._gate_bias = self.add_variable(
"gates/%s" % _BIAS_VARIABLE_NAME,
shape=[2 * self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.constant_initializer(1.0, dtype=self.dtype)))
self._candidate_kernel = self.add_variable(
"candidate/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units],
initializer=self._kernel_initializer)
self._candidate_bias = self.add_variable(
"candidate/%s" % _BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.zeros_initializer(dtype=self.dtype)))
self.built = True
def call(self, inputs, state):
"""Gated recurrent unit (GRU) with nunits cells."""
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._gate_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
value = math_ops.sigmoid(gate_inputs)
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
r_state = r * state
candidate = math_ops.matmul(
array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
candidate = nn_ops.bias_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_h = u * state + (1 - u) * c
return new_h, new_h
Compared with BasicRNNCell, essentially only the call function changes: a reset gate and an update gate are added, denoted r and u, and c is the candidate state used for the update. The corresponding formulas are as follows:
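Transcribed from the call implementation above (the figure from the original post is not reproduced here), where \sigma is the sigmoid, \odot is element-wise multiplication, [\cdot,\cdot] is concatenation, and W_r, W_u denote the two halves of gate_kernel:

r_t = \sigma(W_r\,[x_t, h_{t-1}] + b_r)
u_t = \sigma(W_u\,[x_t, h_{t-1}] + b_u)
\tilde{c}_t = \tanh(W_c\,[x_t,\; r_t \odot h_{t-1}] + b_c)
h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{c}_t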
Next, let's look at how BasicLSTMCell is implemented. Compared with the GRU, the LSTM adds an output gate and an extra internal cell state c; h and c are then returned together as a tuple as the LSTM's state. Its internal structure and formulas are shown in the figure below.
The formulas actually implemented are the following:
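Written out from the call implementation below, where i, j, f, o are the four slices produced by the single matmul and forget_bias is the constant added to the forget gate:

[i_t,\; j_t,\; f_t,\; o_t] = W\,[x_t, h_{t-1}] + b
c_t = c_{t-1} \odot \sigma(f_t + \text{forget\_bias}) + \sigma(i_t) \odot \tanh(j_t)
h_t = \tanh(c_t) \odot \sigma(o_t)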
class BasicLSTMCell(_LayerRNNCell):
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training.
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
that follows.
"""
def __init__(self, num_units, forget_bias=1.0,
state_is_tuple=True, activation=None, reuse=None, name=None):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must set to `0.0` manually when restoring from CudnnLSTM-trained
checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
activation: Activation function of the inner states. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such
cases.
When restoring from CudnnLSTM-trained checkpoints, must use
`CudnnCompatibleLSTMCell` instead.
"""
super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation or math_ops.tanh
@property
def state_size(self):
return (LSTMStateTuple(self._num_units, self._num_units)
if self._state_is_tuple else 2 * self._num_units)
@property
def output_size(self):
return self._num_units
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + h_depth, 4 * self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
self.built = True
def call(self, inputs, state):
"""Long short-term memory cell (LSTM).
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
`[batch_size, self.state_size]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
`[batch_size, 2 * self.state_size]`.
Returns:
A pair containing the new hidden state, and the new state (either a
`LSTMStateTuple` or a concatenated state, depending on
`state_is_tuple`).
"""
sigmoid = math_ops.sigmoid
one = constant_op.constant(1, dtype=dtypes.int32)
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, h], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=gate_inputs, num_or_size_splits=4, axis=one)
forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
# Note that using `add` and `multiply` instead of `+` and `*` gives a
# performance improvement. So using those at the cost of readability.
add = math_ops.add
multiply = math_ops.multiply
new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
multiply(sigmoid(i), self._activation(j)))
new_h = multiply(self._activation(new_c), sigmoid(o))
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
2. How to define a custom RNNCell in TensorFlow
Having seen how the basic RNN, GRU and LSTM cells are implemented, it should not be hard to see how to define your own RNNCell: inherit from the abstract RNNCell class and implement the five pieces __init__, state_size, output_size, build and call, putting the functionality you need inside call (note: build is called only once). A generic skeleton is sketched below; after that we walk through the code used for the Quasi RNN paper reproduced earlier. For its formulas and a detailed introduction, see that blog post.
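Before the Quasi RNN code itself, here is a minimal, generic skeleton of a custom cell (not the Quasi RNN; MyCell, its tanh update and the variable names are placeholders chosen for this sketch, assuming the TensorFlow 1.6 Layer/RNNCell API shown above):

import tensorflow as tf

class MyCell(tf.nn.rnn_cell.RNNCell):
    """Skeleton custom cell: h_t = tanh([x_t, h_{t-1}] @ kernel + bias)."""

    def __init__(self, num_units, reuse=None, name=None):
        super(MyCell, self).__init__(_reuse=reuse, name=name)
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def build(self, inputs_shape):
        # Runs exactly once, before the first call(); create the variables here.
        input_depth = inputs_shape[1].value
        self._kernel = self.add_variable(
            "kernel", shape=[input_depth + self._num_units, self._num_units])
        self._bias = self.add_variable(
            "bias", shape=[self._num_units],
            initializer=tf.zeros_initializer())
        self.built = True

    def call(self, inputs, state):
        # Put the per-step computation of your own model here.
        gate_inputs = tf.matmul(tf.concat([inputs, state], 1), self._kernel)
        output = tf.tanh(tf.nn.bias_add(gate_inputs, self._bias))
        return output, output

A cell defined this way can be passed to tf.nn.dynamic_rnn exactly like the built-in cells.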
The relevant code: