Back-Propagation Through Time (BPTT)


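BPTT: Back-Propagation Through Time. The total loss $L$ of a recurrent network is the sum of the per-step losses $L_t$. Its gradient with respect to the shared input weight matrix $U$ is therefore the sum of the per-step gradients: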
$$\frac{\partial L}{\partial U} = \sum_t \frac{\partial L_t}{\partial U}$$

For example, at $t = 4$:

$$\frac{\partial L_4}{\partial U} = \frac{\partial L_4}{\partial y_4}\frac{\partial y_4}{\partial h_4}\frac{\partial h_4}{\partial U}$$

where $h_4 = \tanh(W h_3 + U x_4)$.

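Note that $h_3$ also depends on $U$, as do $h_2$, $h_1$, and so on. The factor $\partial h_4/\partial U$ must therefore include the path through every earlier hidden state. Summing over these paths gives the following result for a general time step $t$: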
$$\frac{\partial L_t}{\partial U} = \sum_{s=0}^{t} \frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_s}\frac{\partial h_s}{\partial U}$$

Here $\partial h_s/\partial U$ is the immediate partial derivative, taken with $h_{s-1}$ held constant.
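Below is a minimal sketch of the double sum above. It makes several assumptions that are not in the original. The RNN is scalar, so $W$ and $U$ become the numbers `w` and `u`. A hypothetical output weight `v` gives $y_t = v h_t$. The per-step loss is $L_t = \tfrac12 (y_t - d_t)^2$, with hypothetical targets $d_t$. Time steps are numbered from 1, and the initial state $h_0$ is fixed, so the $s = 0$ term is zero. The BPTT gradient is checked against a central finite difference.

```python
import math


def forward(w, u, v, xs, h0=0.0):
    """Scalar RNN: h_t = tanh(w*h_{t-1} + u*x_t), y_t = v*h_t.

    Returns hs = [h_0, h_1, ..., h_T] and ys = [y_1, ..., y_T].
    """
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    ys = [v * h for h in hs[1:]]
    return hs, ys


def loss(w, u, v, xs, targets):
    """Total loss L = sum_t L_t with the assumed L_t = 0.5 * (y_t - d_t)^2."""
    _, ys = forward(w, u, v, xs)
    return sum(0.5 * (y - d) ** 2 for y, d in zip(ys, targets))


def bptt_grad_u(w, u, v, xs, targets):
    """dL/du = sum_t sum_s dL_t/dy_t * dy_t/dh_t * dh_t/dh_s * dh_s/du."""
    hs, ys = forward(w, u, v, xs)
    grad = 0.0
    for t in range(1, len(xs) + 1):
        dLt_dyt = ys[t - 1] - targets[t - 1]
        dyt_dht = v
        dht_dhs = 1.0  # dh_t/dh_s, starting from s = t
        for s in range(t, 0, -1):
            # Immediate derivative of h_s w.r.t. u, holding h_{s-1} fixed.
            dhs_du = (1.0 - hs[s] ** 2) * xs[s - 1]
            grad += dLt_dyt * dyt_dht * dht_dhs * dhs_du
            # Extend the chain by dh_s/dh_{s-1} = (1 - h_s^2) * w.
            dht_dhs *= (1.0 - hs[s] ** 2) * w
    return grad


def numeric_grad_u(w, u, v, xs, targets, eps=1e-6):
    """Central finite difference, used only to check bptt_grad_u."""
    return (loss(w, u + eps, v, xs, targets)
            - loss(w, u - eps, v, xs, targets)) / (2 * eps)


xs = [0.5, -1.0, 0.3, 0.8]
targets = [0.1, 0.0, -0.2, 0.4]
g_bptt = bptt_grad_u(0.7, 0.9, 1.3, xs, targets)
g_num = numeric_grad_u(0.7, 0.9, 1.3, xs, targets)
print(g_bptt, g_num)  # compare the analytic and numeric gradients
```

The inner loop walks $s$ backwards from $t$. It reuses the running product for $\partial h_t/\partial h_s$ instead of recomputing it for every $s$.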

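Parameter sharing is a double-edged sword. The same weights are applied at every time step, so the network's predictions are stationary over time. However, when gradients are computed, this sharing creates dependencies across time steps.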

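As the distance between $t$ and $s$ grows, $\partial h_t/\partial h_s$ becomes a product of more and more Jacobians. For the tanh recurrence above, each Jacobian has the form $\partial h_k/\partial h_{k-1} = \operatorname{diag}(1 - h_k^2)\,W$. Such a long product tends to vanish or explode, so long-range dependencies are learned poorly: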
$$\frac{\partial h_t}{\partial h_s} = \frac{\partial h_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial h_{t-2}}\cdots\frac{\partial h_{s+1}}{\partial h_s}$$
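The effect of this product can be seen in a scalar sketch. The setting is an assumption: $W$ and $U$ are replaced by the scalars `w` and `u`, and `jacobian_chain` is a hypothetical helper name. The sketch accumulates $\partial h_t/\partial h_0$ one factor $(1 - h_k^2)\,w$ at a time.

```python
import math


def jacobian_chain(w, u, xs, h0=0.0):
    """Scalar RNN h_t = tanh(w*h_{t-1} + u*x_t).

    Returns [dh_1/dh_0, dh_2/dh_0, ..., dh_T/dh_0], each built as the
    product of the one-step factors dh_k/dh_{k-1} = (1 - h_k^2) * w.
    """
    h = h0
    prod = 1.0
    chain = []
    for x in xs:
        h = math.tanh(w * h + u * x)
        prod *= (1.0 - h ** 2) * w  # one Jacobian factor per step
        chain.append(prod)
    return chain


# |w| < 1: every factor has magnitude <= |w|, so the product shrinks toward 0.
shrink = jacobian_chain(0.5, 1.0, [0.3] * 20)
# u = 0, h_0 = 0: h stays at the fixed point 0, every factor equals w = 3,
# so the product grows as 3**t.
grow = jacobian_chain(3.0, 0.0, [1.0] * 20)
print(abs(shrink[-1]), grow[-1])
```

In the scalar case each factor has magnitude at most $|w|$, because $0 \le 1 - h_k^2 \le 1$. In the matrix case the norm of $W$ plays roughly the same role.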

Truncated BPTT

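Truncated BPTT splits a long sequence into shorter subsequences and back-propagates only within each subsequence. It considers only the computation inside each subsequence and ignores gradient paths that cross a subsequence boundary. In practice full BPTT is rarely used, and truncated BPTT is the usual choice.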
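Here is a minimal sketch of truncated BPTT, using the same assumed scalar RNN, hypothetical output weight `v`, and squared-error loss as a plain BPTT check. The chunk length `k` and the function names are hypothetical. The sketch follows a common convention that the original does not state: the hidden state's value is carried forward from one chunk to the next, but it is treated as a constant, so no gradient flows into earlier chunks.

```python
import math


def loss(w, u, v, xs, targets, h0=0.0):
    """Total loss sum_t 0.5 * (v*h_t - d_t)^2 for the scalar RNN."""
    h, total = h0, 0.0
    for x, d in zip(xs, targets):
        h = math.tanh(w * h + u * x)
        total += 0.5 * (v * h - d) ** 2
    return total


def truncated_bptt_grad_u(w, u, v, xs, targets, k, h0=0.0):
    """Gradient of the total loss w.r.t. u, back-propagating only inside
    consecutive chunks of length k.

    The hidden state is carried forward across chunk boundaries, but each
    chunk's starting state is treated as a constant.
    """
    grad = 0.0
    h_start = h0
    for start in range(0, len(xs), k):
        cx = xs[start:start + k]
        cd = targets[start:start + k]
        hs = [h_start]
        for x in cx:
            hs.append(math.tanh(w * hs[-1] + u * x))
        for t in range(1, len(cx) + 1):
            # g starts as dL_t/dh_t = (y_t - d_t) * v and is pushed back to s = 1.
            g = (v * hs[t] - cd[t - 1]) * v
            for s in range(t, 0, -1):
                grad += g * (1.0 - hs[s] ** 2) * cx[s - 1]
                g *= (1.0 - hs[s] ** 2) * w
        h_start = hs[-1]  # value carried forward; the gradient is cut here
    return grad


xs = [0.5, -1.0, 0.3, 0.8, -0.4, 0.9]
targets = [0.1, 0.0, -0.2, 0.4, 0.3, -0.1]
full = truncated_bptt_grad_u(0.7, 0.9, 1.3, xs, targets, k=len(xs))
trunc = truncated_bptt_grad_u(0.7, 0.9, 1.3, xs, targets, k=2)
print(full, trunc)
```

With `k = len(xs)` the function reduces to full BPTT. A smaller `k` is cheaper in memory and time, but it drops the cross-chunk terms, so the result is only an approximation of the true gradient. In PyTorch the same cut is usually made by calling `.detach()` on the hidden state between chunks.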