RNN Training Algorithm: Backpropagation Through Time (BPTT)
See the basic framework: https://goodgoodstudy.blog.****.net/article/details/109245095
Problem Statement

Consider the recurrent network model:

$$
x(k) = f[Wx(k-1)] \tag{1}
$$
where $x(k) \in R^N$ is the vector of network node states, $W \in R^{N\times N}$ is the matrix of connection weights between nodes, the output nodes of the network are $\{x_i(k) \mid i \in O\}$, and $O$ is the index set of all output (or "observed") units.
The goal of training is to reduce the error between the observed states and their desired values, i.e., to minimize the loss function:

$$
E = \frac{1}{2}\sum_{k=1}^K \sum_{i\in O} [x_i(k) - d_i(k)]^2 \tag{2}
$$
where $d_i(k)$ is the desired value of node $i$ at time $k$.
$W$ is updated by gradient descent:

$$
W_+ = W - \eta \frac{dE}{dW}
$$
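The model (1) and loss (2) above can be sketched numerically. A minimal NumPy example, assuming a tanh nonlinearity and toy dimensions (the sizes, target values, and unit choices here are illustrative placeholders, not from the original):

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 4, 6                 # state size and number of time steps (toy values)
O = [0, 1]                  # indices of the output ("observed") units
f = np.tanh                 # nonlinearity; tanh is one common choice

W = 0.5 * rng.standard_normal((N, N))   # weight matrix
d = rng.standard_normal((K + 1, N))     # desired values d_i(k) (placeholder targets)

# Forward pass, eq. (1): x(k) = f[W x(k-1)], starting from a given x(0)
xs = [rng.standard_normal(N)]
for k in range(K):
    xs.append(f(W @ xs[-1]))

# Loss, eq. (2): E = 1/2 * sum_{k=1..K} sum_{i in O} [x_i(k) - d_i(k)]^2
E = 0.5 * sum((xs[k][i] - d[k, i]) ** 2 for k in range(1, K + 1) for i in O)
print(E)
```

Only the units in $O$ contribute to $E$; the remaining states are hidden.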
Notation
$$
W \equiv \begin{bmatrix} \text{-----}\; w_1^T \;\text{-----} \\ \vdots \\ \text{-----}\; w_N^T \;\text{-----} \end{bmatrix}_{N\times N}
$$
Flatten the matrix $W$ into a column vector, denoted $w$:

$$
w = [w_1^T, \cdots, w_N^T]^T \in R^{N^2}
$$
Stack the states at all time steps into a column vector, denoted $x$:

$$
x = [x^T(1), \cdots, x^T(K)]^T \in R^{NK}
$$
Treating RNN training as a constrained optimization problem, equation (1) becomes the constraint:

$$
g(k) \equiv f[Wx(k-1)] - x(k) = 0, \quad k=1,\ldots,K \tag{3}
$$
Let

$$
g = [g^T(1), \ldots, g^T(K)]^T \in R^{NK}
$$
Since $g$ vanishes identically along the trajectory, its total derivative with respect to $w$ is zero:

$$
0 = \frac{dg(x(w),w)}{dw} = \frac{\partial g(x(w),w)}{\partial x}\frac{\partial x(w)}{\partial w} + \frac{\partial g(x(w),w)}{\partial w} \tag{4}
$$
Because $\partial g/\partial x$ is invertible (it is block lower-triangular with $-I$ on the diagonal, as shown below), (4) gives $\frac{\partial x}{\partial w} = -\left(\frac{\partial g}{\partial x}\right)^{-1}\frac{\partial g}{\partial w}$, and substituting into $\frac{dE}{dw} = \frac{\partial E}{\partial x}\frac{\partial x}{\partial w}$ yields

$$
\frac{dE}{dw} = -\frac{\partial E}{\partial x} \left(\frac{\partial g}{\partial x}\right)^{-1} \frac{\partial g}{\partial w} \tag{5}
$$
The three factors in (5) are as follows:

1.

$$
\begin{aligned} \frac{\partial E}{\partial x} &= [e(1), \ldots, e(K)] \\ e_i(k) &= \begin{cases} x_i(k) - d_i(k), &\text{if } i\in O, \\ 0, &\text{otherwise,} \end{cases} \quad k = 1,\ldots,K. \end{aligned}
$$
2.

$$
\frac{\partial g}{\partial x} = \begin{bmatrix} -I & 0 & 0 & \ldots & 0 \\ D(1)W & -I & 0 & \ldots & 0 \\ 0 & D(2)W & -I & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & D(K-1)W & -I \end{bmatrix}_{NK\times NK}
$$
where

$$
D(j) = \begin{bmatrix} f'(w_1^Tx(j)) & & 0 \\ & \ddots & \\ 0 & & f'(w_N^Tx(j)) \end{bmatrix}
$$
3.

$$
\frac{\partial g}{\partial w} = \begin{bmatrix} D(0)X(0) \\ D(1)X(1) \\ \vdots \\ D(K-1)X(K-1) \end{bmatrix}
$$
where

$$
X(k) \triangleq \begin{bmatrix} x^T(k) & & & \\ & x^T(k) & & \\ & & \ddots & \\ & & & x^T(k) \end{bmatrix}_{N\times N^2}
$$
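As a sanity check on the block structure of $\partial g/\partial x$, the sketch below (again assuming $f = \tanh$; the dimensions and the evaluation point are arbitrary) builds the analytic Jacobian block by block and compares it against a central finite difference of $g$:

```python
import numpy as np

rng = np.random.default_rng(1)
N, K = 3, 4
f = np.tanh
df = lambda a: 1.0 - np.tanh(a) ** 2    # f' for f = tanh

W = 0.5 * rng.standard_normal((N, N))
x0 = rng.standard_normal(N)             # given initial state x(0)
xs = rng.standard_normal((K, N))        # arbitrary point x = [x(1), ..., x(K)]

def g(xflat):
    """Stacked constraints g(k) = f[W x(k-1)] - x(k), eq. (3)."""
    x = xflat.reshape(K, N)
    prev = np.vstack([x0, x[:-1]])      # x(k-1) for k = 1..K
    return (f(prev @ W.T) - x).ravel()

# Analytic Jacobian dg/dx: -I blocks on the diagonal, D(k)W on the subdiagonal
J = -np.eye(N * K)
for k in range(1, K):
    D = np.diag(df(W @ xs[k - 1]))      # D(k), with x(k) stored at xs[k-1] (0-based)
    J[k * N:(k + 1) * N, (k - 1) * N:k * N] = D @ W

# Central finite-difference Jacobian for comparison
eps = 1e-6
xflat = xs.ravel()
J_num = np.column_stack([
    (g(xflat + eps * e) - g(xflat - eps * e)) / (2 * eps)
    for e in np.eye(N * K)
])
print(np.max(np.abs(J - J_num)))        # small if the block structure is right
```

Note that $g(1)$ depends on the given $x(0)$, which is not part of $x$, so the first block row contains only $-I$.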
Backpropagation
Let

$$
\delta = \frac{\partial E}{\partial x} \left(\frac{\partial g}{\partial x}\right)^{-1} \in R^{1\times NK} \tag{6}
$$
and then compute

$$
\frac{dE}{dw} = -\delta \frac{\partial g}{\partial w}
$$
Rearranging (6):

$$
\delta \frac{\partial g}{\partial x} = \frac{\partial E}{\partial x}
$$
Partition $\delta$ by time step:

$$
\delta = \begin{bmatrix} \delta(1) & \ldots & \delta(K) \end{bmatrix}_{1\times NK}, \quad \delta(k) \in R^{1\times N}
$$
Then

$$
\begin{bmatrix} \delta(1) & \ldots & \delta(K) \end{bmatrix} \begin{bmatrix} -I & 0 & 0 & \ldots & 0 \\ D(1)W & -I & 0 & \ldots & 0 \\ 0 & D(2)W & -I & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & D(K-1)W & -I \end{bmatrix} = [e(1), \ldots, e(K)]
$$
Solving column block by column block, from the last backwards:

$$
\begin{aligned} \delta(K) &= -e(K) \\ \delta(k) &= \delta(k+1)D(k)W - e(k), \quad k = 1,\ldots,K-1 \end{aligned}
$$
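This backward recursion can be sketched directly (tanh assumed as before; the targets and observed-unit set are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(2)
N, K = 4, 6
O = [0, 1]                             # observed units (illustrative choice)
f = np.tanh
df = lambda a: 1.0 - np.tanh(a) ** 2   # f'

W = 0.5 * rng.standard_normal((N, N))
d = rng.standard_normal((K + 1, N))    # placeholder targets d(k)

xs = [rng.standard_normal(N)]          # x(0)
for k in range(K):                     # forward pass, eq. (1)
    xs.append(f(W @ xs[-1]))

def e(k):
    """Error row vector e(k): nonzero only on the observed units in O."""
    out = np.zeros(N)
    for i in O:
        out[i] = xs[k][i] - d[k, i]
    return out

# Backward recursion for the adjoint row vectors delta(k)
delta = [None] * (K + 1)
delta[K] = -e(K)
for k in range(K - 1, 0, -1):
    D = np.diag(df(W @ xs[k]))         # D(k) = diag(f'(w_i^T x(k)))
    delta[k] = delta[k + 1] @ D @ W - e(k)
```

Each $\delta(k)$ accumulates the error at step $k$ plus the error propagated back from step $k+1$ through $D(k)W$, which is exactly the "through time" part of BPTT.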
Therefore

$$
\begin{aligned} \frac{dE}{dw} &= -\delta \frac{\partial g}{\partial w} \\ &= -\begin{bmatrix} \delta(1) & \ldots & \delta(K) \end{bmatrix} \begin{bmatrix} D(0)X(0) \\ D(1)X(1) \\ \vdots \\ D(K-1)X(K-1) \end{bmatrix} \\ &= -\sum_{k=1}^K \delta(k)D(k-1)X(k-1) \end{aligned}
$$
where

$$
\begin{aligned} &\delta(k)D(k-1)X(k-1) \\ &= \begin{bmatrix} \delta_1(k) & \ldots & \delta_N(k) \end{bmatrix}_{1\times N} \begin{bmatrix} f'(w_1^Tx(k-1)) & & 0 \\ & \ddots & \\ 0 & & f'(w_N^Tx(k-1)) \end{bmatrix}_{N\times N} \begin{bmatrix} x^T(k-1) & & \\ & \ddots & \\ & & x^T(k-1) \end{bmatrix}_{N\times N^2} \\ &= \begin{bmatrix} \delta_1(k) f'(w_1^Tx(k-1))x^T(k-1) & \ldots & \delta_N(k) f'(w_N^Tx(k-1))x^T(k-1) \end{bmatrix}_{1\times N^2} \end{aligned}
$$
Reshaping each $1\times N^2$ row back into an $N\times N$ matrix gives the gradient in matrix form, $\frac{dE}{dW} \in R^{N\times N}$:

$$
\begin{aligned} \frac{dE}{dW} &= -\sum_{k=1}^K \begin{bmatrix} \delta_1(k) f'(w_1^Tx(k-1))x^T(k-1) \\ \vdots \\ \delta_N(k) f'(w_N^Tx(k-1))x^T(k-1) \end{bmatrix}_{N\times N} \\ &= -\sum_{k=1}^K \begin{bmatrix} f'(w_1^Tx(k-1)) & & 0 \\ & \ddots & \\ 0 & & f'(w_N^Tx(k-1)) \end{bmatrix}_{N\times N} \begin{bmatrix} \delta_1(k) \\ \vdots \\ \delta_N(k) \end{bmatrix}_{N\times 1} x^T(k-1) \\ &= -\sum_{k=1}^K D(k-1)\delta^T(k)x^T(k-1) \end{aligned}
$$
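Putting everything together, here is a compact BPTT sketch that computes $dE/dW$ via the delta recursion and verifies it against a finite-difference approximation of $E$ (tanh activation, toy sizes, and random targets are all illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(3)
N, K = 4, 5
O = [0, 2]                             # observed units (illustrative choice)
f = np.tanh
df = lambda a: 1.0 - np.tanh(a) ** 2   # f'

x0 = rng.standard_normal(N)            # given initial state x(0)
d = rng.standard_normal((K + 1, N))    # placeholder targets d(k)
W = 0.5 * rng.standard_normal((N, N))

def forward(W):
    xs = [x0]
    for k in range(K):                 # eq. (1)
        xs.append(f(W @ xs[-1]))
    return xs

def loss(W):
    xs = forward(W)                    # eq. (2)
    return 0.5 * sum((xs[k][i] - d[k, i]) ** 2 for k in range(1, K + 1) for i in O)

def grad(W):
    """BPTT gradient dE/dW = -sum_k D(k-1) delta(k)^T x(k-1)^T."""
    xs = forward(W)
    e = np.zeros((K + 1, N))
    for k in range(1, K + 1):
        for i in O:
            e[k, i] = xs[k][i] - d[k, i]
    delta = np.zeros((K + 1, N))       # delta[k] stores the row vector delta(k)
    delta[K] = -e[K]
    for k in range(K - 1, 0, -1):      # delta(k) = delta(k+1) D(k) W - e(k)
        delta[k] = delta[k + 1] @ np.diag(df(W @ xs[k])) @ W - e[k]
    G = np.zeros((N, N))
    for k in range(1, K + 1):
        D = np.diag(df(W @ xs[k - 1]))
        G -= D @ np.outer(delta[k], xs[k - 1])
    return G

# Verify against a central finite difference of the loss
G = grad(W)
eps, G_num = 1e-6, np.zeros((N, N))
for i in range(N):
    for j in range(N):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        G_num[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)
print(np.max(np.abs(G - G_num)))       # small if the derivation is correct
```

The update $W_+ = W - \eta\, \mathrm{grad}(W)$ then implements the gradient descent step from the problem statement.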