RNN Training Algorithms: APRL (Atiya-Parlos Recurrent Learning)


Problem Description

Consider the recurrent network model:
$$x(k) = f[Wx(k-1)] \tag{1}$$
where $x(k) \in R^N$ is the vector of network node states, $W \in R^{N\times N}$ is the matrix of connection weights between the nodes, and the output nodes of the network are $\{x_i(k) \mid i \in O\}$, with $O$ the index set of all output (or "observed") units.
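As a concrete illustration, here is a minimal NumPy sketch of iterating the state update (1); the choice $f = \tanh$ and the random weights are assumptions made only for this example.

```python
import numpy as np

def run_network(W, x0, K, f=np.tanh):
    """Iterate x(k) = f(W x(k-1)) for k = 1..K; returns X with X[k] = x(k)."""
    N = W.shape[0]
    X = np.zeros((K + 1, N))
    X[0] = x0
    for k in range(1, K + 1):
        X[k] = f(W @ X[k - 1])        # equation (1)
    return X

# example usage with assumed sizes
rng = np.random.default_rng(0)
N, K = 5, 10
W = rng.normal(scale=0.5, size=(N, N))
X = run_network(W, rng.normal(size=N), K)
```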

The goal of training is to reduce the error between the observed states and their desired values, i.e. to minimize the loss function:
$$E = \frac{1}{2}\sum_{k=1}^K \sum_{i\in O} \big[x_i(k) - d_i(k)\big]^2 \tag{2}$$
where $d_i(k)$ is the desired value of node $i$ at time $k$.

Plain gradient descent would update $W$ as:
$$W_{+} = W - \eta \frac{dE}{dW}$$
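For concreteness, a small sketch that evaluates the loss (2) and the per-step output errors $x_i(k) - d_i(k)$; the names `d` (desired values) and `out_idx` (the index set $O$) are illustrative, and `X` is the state array from the sketch above.

```python
import numpy as np

def loss_and_errors(X, d, out_idx):
    """E = 1/2 * sum_k sum_{i in O} (x_i(k) - d_i(k))^2 and e(k) = dE/dx(k)."""
    K, N = d.shape[0], X.shape[1]
    E = 0.0
    e = np.zeros((K, N))                 # e(k) is zero on non-output units
    for k in range(1, K + 1):
        err = X[k, out_idx] - d[k - 1]   # x_i(k) - d_i(k) for i in O
        E += 0.5 * np.sum(err ** 2)
        e[k - 1, out_idx] = err
    return E, e
```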

Notation

Write $W$ in terms of its rows $w_i^T$:
$$W \equiv \begin{bmatrix} \text{--- } w_1^T \text{ ---} \\ \vdots \\ \text{--- } w_N^T \text{ ---} \end{bmatrix}_{N\times N}$$
Stack the rows of $W$ into a single column vector, denoted $w$:
$$w = [w_1^T, \cdots, w_N^T]^T \in R^{N^2}$$
Concatenate the states of all time steps into a column vector, denoted $x$:
$$x = [x^T(1), \cdots, x^T(K)]^T \in R^{NK}$$
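In NumPy this row-wise stacking is just a row-major flatten; a quick sanity check of the convention (the concrete sizes are arbitrary):

```python
import numpy as np

W = np.arange(9.0).reshape(3, 3)      # rows are w_1^T, w_2^T, w_3^T
w = W.reshape(-1)                     # w = [w_1^T, ..., w_N^T]^T, length N^2
assert np.array_equal(w[:3], W[0])    # first block of w is the first row of W

X = np.arange(12.0).reshape(4, 3)     # rows are x(1), ..., x(K)
x = X.reshape(-1)                     # x = [x^T(1), ..., x^T(K)]^T, length NK
```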
Viewing RNN training as a constrained optimization problem, equation (1) becomes the constraint:
$$g(k) \equiv f[Wx(k-1)] - x(k) = 0, \quad k = 1, \ldots, K \tag{3}$$

$$g = [g^T(1), \ldots, g^T(K)]^T \in R^{NK}$$


Since the constraint $g(x(w), w) = 0$ holds for every $w$ (the states are implicit functions of the weights), differentiating it with respect to $w$ gives
$$0 = \frac{dg(x(w),w)}{dw} = \frac{\partial g(x(w),w)}{\partial x}\frac{\partial x(w)}{\partial w} + \frac{\partial g(x(w),w)}{\partial w} \tag{4}$$
Solving (4) for $\partial x/\partial w$ and applying the chain rule to $E$ yields the gradient
$$\frac{dE}{dw} = -\frac{\partial E}{\partial x} \left(\frac{\partial g}{\partial x}\right)^{-1} \frac{\partial g}{\partial w} \tag{5}$$

The Atiya-Parlos Algorithm

The above is the classical gradient-descent viewpoint. Atiya and Parlos proposed a different optimization idea: the update is not taken along the gradient direction with respect to the parameters, yet it still decreases the cost function.

The idea of the algorithm is to interchange the roles of the network states $x(k)$ and the weight matrix $W$: the states are treated as the control variables, and the change of the weights is determined from the change of $x(k)$. In other words, we compute the gradient of $E$ with respect to the states $x(k)$ and assume the states take a small step in the negative gradient direction, $\Delta x_i(k) = -\eta\,\dfrac{\partial E}{\partial x_i(k)}$.

Next, we **determine the weight change $\Delta w$ so that the state change induced by the weight change is as close as possible to the target change $\Delta x$**.

The details of the algorithm are as follows. Let
$$\Delta x = -\eta \left(\frac{\partial E}{\partial x}\right)^T = -\eta\, e^T = -\eta\, [e(1), \ldots, e(K)]^T$$
where $e \equiv \partial E/\partial x$ is a row vector whose $k$-th block $e(k) = \partial E/\partial x(k)$ has entries $x_i(k) - d_i(k)$ for $i \in O$ and $0$ for the remaining units.

Rewrite the constraint $g = 0$ as
$$h(k) \equiv W x(k-1) - f^{-1}(x(k)) = 0, \quad k = 1, 2, \ldots, K$$
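For $f = \tanh$ (so $f^{-1} = \operatorname{arctanh}$, an assumption used throughout these examples), the rewritten constraint can be checked numerically against the simulated states:

```python
import numpy as np

def constraint_residual(W, X):
    """h(k) = W x(k-1) - arctanh(x(k)); ~0 for states generated by x(k) = tanh(W x(k-1))."""
    K = X.shape[0] - 1
    return np.stack([W @ X[k - 1] - np.arctanh(X[k]) for k in range(1, K + 1)])
```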
Requiring the constraint to remain satisfied to first order gives:
$$\frac{\partial h}{\partial x}\, \Delta x = -\frac{\partial h}{\partial w}\, \Delta w$$
Hence, given $\Delta x$, multiplying both sides by $(\partial h/\partial w)^T$ and solving in the least-squares sense:
$$\left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial x}\, \Delta x = -\left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial w}\, \Delta w$$
$$\Delta w = -\left[\left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial w}\right]^{-1} \left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial x}\, \Delta x$$
Note that this inverse does not necessarily exist, so a small regularization term is added:
$$\Delta w = -\left[\left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial w} + \epsilon I\right]^{-1} \left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial x}\, \Delta x$$
This is the update rule for the weights $W$.
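A direct, if inefficient, sketch of this update: with the stacked Jacobians `dh_dw` (shape $NK \times N^2$) and `dh_dx` (shape $NK \times NK$) built explicitly, solve the regularized normal equations. How to exploit their structure instead is worked out below.

```python
import numpy as np

def delta_w_dense(dh_dw, dh_dx, delta_x, eps=1e-6):
    """Solve [(dh/dw)^T (dh/dw) + eps*I] dw = -(dh/dw)^T (dh/dx) dx for dw."""
    A = dh_dw.T @ dh_dw + eps * np.eye(dh_dw.shape[1])
    b = -dh_dw.T @ (dh_dx @ delta_x)
    return np.linalg.solve(A, b)       # dw has length N^2
```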

Computation Details

1. Computing $\frac{\partial h}{\partial w}$

$$\frac{\partial h}{\partial w} = \begin{bmatrix} \frac{\partial h(1)}{\partial w} \\ \vdots \\ \frac{\partial h(K)}{\partial w} \end{bmatrix} = \begin{bmatrix} \frac{\partial\, Wx(0)}{\partial w} \\ \vdots \\ \frac{\partial\, Wx(K-1)}{\partial w} \end{bmatrix}$$

where

$$\frac{\partial\, Wx(k)}{\partial w} = \begin{bmatrix} \frac{\partial\, w_1^T x(k)}{\partial w} \\ \vdots \\ \frac{\partial\, w_N^T x(k)}{\partial w} \end{bmatrix} = \begin{bmatrix} x^T(k) & & & \\ & x^T(k) & & \\ & & \ddots & \\ & & & x^T(k) \end{bmatrix}_{N \times N^2} \triangleq X(k)$$

Hence

$$\frac{\partial h}{\partial w} = \begin{bmatrix} X(0) \\ \vdots \\ X(K-1) \end{bmatrix}_{NK \times N^2}$$

$$\left(\frac{\partial h}{\partial w}\right)^T \frac{\partial h}{\partial w} = \begin{bmatrix} X^T(0) & \cdots & X^T(K-1) \end{bmatrix} \begin{bmatrix} X(0) \\ \vdots \\ X(K-1) \end{bmatrix} = \sum_{k=0}^{K-1} X^T(k) X(k) = \begin{bmatrix} C & & & \\ & C & & \\ & & \ddots & \\ & & & C \end{bmatrix}_{N^2 \times N^2}, \qquad C \equiv \sum_{k=0}^{K-1} x(k)\, x^T(k)$$
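Since every diagonal block equals the same $N\times N$ correlation matrix $C$, the $N^2 \times N^2$ matrix never has to be formed or inverted explicitly; only $C$ is needed. A one-line sketch (with `X` as in the earlier examples, `X[k]` $= x(k)$):

```python
import numpy as np

def state_correlation(X, K):
    """C = sum_{k=0}^{K-1} x(k) x(k)^T, the block repeated on the diagonal."""
    return X[:K].T @ X[:K]             # shape (N, N)
```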

2. Computing $\frac{\partial h}{\partial x}$

$$\frac{\partial h}{\partial x} = \begin{bmatrix} \frac{\partial h(1)}{\partial x} \\ \vdots \\ \frac{\partial h(K)}{\partial x} \end{bmatrix} = \begin{bmatrix} \frac{\partial \big(Wx(0) - f^{-1}[x(1)]\big)}{\partial x} \\ \vdots \\ \frac{\partial \big(Wx(K-1) - f^{-1}[x(K)]\big)}{\partial x} \end{bmatrix}$$

where only the blocks with respect to $x(k-1)$ and $x(k)$ are nonzero:

$$\begin{aligned} \frac{\partial}{\partial x}\Big(Wx(k-1) - f^{-1}[x(k)]\Big) &= \begin{bmatrix} \cdots & \frac{\partial \left(Wx(k-1) - f^{-1}[x(k)]\right)}{\partial x(k-1)} & \frac{\partial \left(Wx(k-1) - f^{-1}[x(k)]\right)}{\partial x(k)} & \cdots \end{bmatrix} \\ &= \begin{bmatrix} \cdots & \frac{\partial\, Wx(k-1)}{\partial x(k-1)} & -\frac{\partial f^{-1}[x(k)]}{\partial x(k)} & \cdots \end{bmatrix} \\ &= \begin{bmatrix} \cdots & W & -\frac{\partial f^{-1}[x(k)]}{\partial x(k)} & \cdots \end{bmatrix} \end{aligned}$$

and

$$\begin{aligned} \frac{\partial f^{-1}[x(j)]}{\partial x(j)} &= \begin{bmatrix} \frac{\partial f^{-1}(x_1(j))}{\partial x_1(j)} & \ldots & \frac{\partial f^{-1}(x_1(j))}{\partial x_N(j)} \\ \vdots & \ddots & \vdots \\ \frac{\partial f^{-1}(x_N(j))}{\partial x_1(j)} & \ldots & \frac{\partial f^{-1}(x_N(j))}{\partial x_N(j)} \end{bmatrix} = \begin{bmatrix} \frac{\partial f^{-1}(x_1(j))}{\partial x_1(j)} & & \\ & \ddots & \\ & & \frac{\partial f^{-1}(x_N(j))}{\partial x_N(j)} \end{bmatrix} \\ &= \begin{bmatrix} \frac{1}{f'(w_1^T x(j-1))} & & \\ & \ddots & \\ & & \frac{1}{f'(w_N^T x(j-1))} \end{bmatrix} \quad \text{(derivative of the inverse function)} \\ &= D^{-1}(j-1) \end{aligned}$$

Note (derivative of the inverse function): if $y = f(x)$ and $x = f^{-1}(y)$, then $\frac{df^{-1}(y)}{dy} = \frac{dx}{dy} = 1\Big/\frac{dy}{dx} = \frac{1}{f'(x)}$.
Hence, since $x_i(j) = f(w_i^T x(j-1))$,
$$\frac{\partial f^{-1}(x_i(j))}{\partial x_i(j)} = \frac{1}{f'(w_i^T x(j-1))}$$

Putting these together,
$$\frac{\partial h}{\partial x} = \begin{bmatrix} -D^{-1}(0) & 0 & 0 & \ldots & 0 \\ W & -D^{-1}(1) & 0 & \ldots & 0 \\ 0 & W & -D^{-1}(2) & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & W & -D^{-1}(K-1) \end{bmatrix}_{NK \times NK}$$
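As a sanity check, $\partial h/\partial x$ can be assembled densely for a small network; here $f = \tanh$ is assumed, so $f'(w_i^T x(k-1)) = 1 - x_i^2(k)$ and the diagonal blocks follow directly from the states.

```python
import numpy as np

def dh_dx_dense(W, X, K):
    """Dense (N*K, N*K) Jacobian dh/dx for f = tanh."""
    N = W.shape[0]
    J = np.zeros((N * K, N * K))
    for k in range(1, K + 1):
        r = slice((k - 1) * N, k * N)
        J[r, r] = -np.diag(1.0 / (1.0 - X[k] ** 2))    # -D^{-1}(k-1)
        if k >= 2:
            J[r, slice((k - 2) * N, (k - 1) * N)] = W  # coupling to x(k-1)
    return J
```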

3. Computing $\frac{\partial h}{\partial x} \Delta x$

Define
$$\gamma = -\frac{1}{\eta}\, \frac{\partial h}{\partial x}\, \Delta x$$
which differs from the quantity we need only by a constant factor. Substituting
$$\Delta x = -\eta\, [e(1), \ldots, e(K)]^T$$
gives

$$\begin{aligned} \gamma &= \frac{\partial h}{\partial x}\, [e(1), \ldots, e(K)]^T \\ &= \begin{bmatrix} -D^{-1}(0) & 0 & 0 & \ldots & 0 \\ W & -D^{-1}(1) & 0 & \ldots & 0 \\ 0 & W & -D^{-1}(2) & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & W & -D^{-1}(K-1) \end{bmatrix} \begin{bmatrix} e^T(1) \\ e^T(2) \\ e^T(3) \\ \vdots \\ e^T(K) \end{bmatrix} \qquad \text{(each $e(k)$ is a row vector)} \\ &= \begin{bmatrix} -D^{-1}(0)\, e^T(1) \\ W e^T(1) - D^{-1}(1)\, e^T(2) \\ W e^T(2) - D^{-1}(2)\, e^T(3) \\ \vdots \\ W e^T(K-1) - D^{-1}(K-1)\, e^T(K) \end{bmatrix} \end{aligned}$$
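In practice $\gamma$ is computed block by block from this last expression, without ever forming the $NK \times NK$ matrix; a sketch under the same $\tanh$ assumption (`e` is the $(K, N)$ error array with `e[k-1]` $= e(k)$):

```python
import numpy as np

def gamma_blocks(W, X, e):
    """gamma(k) = W e^T(k-1) - D^{-1}(k-1) e^T(k), with no W-term for k = 1."""
    K, N = e.shape
    g = np.zeros((K, N))
    for k in range(1, K + 1):
        g[k - 1] = -(1.0 / (1.0 - X[k] ** 2)) * e[k - 1]   # -D^{-1}(k-1) e^T(k)
        if k >= 2:
            g[k - 1] += W @ e[k - 2]                        # + W e^T(k-1)
    return g                                                # g[k-1] holds the k-th block of gamma
```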

Algorithm Summary

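Collecting the pieces: $(\partial h/\partial w)^T(\partial h/\partial w)$ is block diagonal with block $C = \sum_{k=0}^{K-1} x(k)\,x^T(k)$, and $(\partial h/\partial x)\,\Delta x = -\eta\,\gamma$, so the update $\Delta w$ decouples row by row and can be written in matrix form as $\Delta W = \eta\, \Gamma\, (C + \epsilon I)^{-1}$ with $\Gamma = \sum_{k=1}^{K} \gamma(k)\, x^T(k-1)$. Below is a minimal batch-mode sketch of one such weight update, again assuming $f = \tanh$; it is an illustrative reading of the derivation above, not the paper's exact pseudocode.

```python
import numpy as np

def aprl_batch_update(W, X, e, eta=0.1, eps=1e-6):
    """One batch APRL update: W <- W + eta * Gamma @ inv(C + eps*I).

    W : (N, N) weights; X : (K+1, N) states with X[k] = x(k);
    e : (K, N) output errors with e[k-1] = e(k) (zero on non-output units).
    """
    K, N = e.shape
    g = np.zeros((K, N))                   # gamma(k), computed block by block
    for k in range(1, K + 1):
        g[k - 1] = -(1.0 / (1.0 - X[k] ** 2)) * e[k - 1]
        if k >= 2:
            g[k - 1] += W @ e[k - 2]
    Gamma = g.T @ X[:K]                    # sum_k gamma(k) x^T(k-1)
    C = X[:K].T @ X[:K]                    # sum_{k=0}^{K-1} x(k) x^T(k)
    return W + eta * Gamma @ np.linalg.inv(C + eps * np.eye(N))
```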

Reference

Atiya, A. F. and Parlos, A. G., "New Results on Recurrent Network Training: Unifying the Algorithms and Accelerating Convergence".
