持续学习:(Elastic Weight Consolidation, EWC)Overcoming Catastrophic Forgetting in Neural Network

概述

原论文地址:https://arxiv.org/pdf/1612.00796.pdf

本博客参考了以下博客的理解
地址:https://blog.****.net/dhaiuda/article/details/103967676/

本博客仅是个人对此论文的理解,若有理解不当的地方欢迎大家指正。

本篇论文讲述了一种通过给权重添加正则,从而控制权重优化方向,从而达到持续学习效果的方法。其方法简单来讲分为以下三个步骤,其思想如图所示:

  • 选择出对于旧任务(old task)比较重要的权重
  • 对权重的重要程度进行排序
  • 在优化的时候,越重要的权重改变越小,保证其在小范围内改变,不会对旧任务产生较大的影响
    持续学习:(Elastic Weight Consolidation, EWC)Overcoming Catastrophic Forgetting in Neural Network
    在图中,灰色区域时旧任务的低误差区域,白色为新任务的低误差区域。如果用旧任务的权重初始化网络,用新任务的数据进行训练的话,优化的方向如蓝色箭头所示,离开了灰色区域,代表着其网络失去了在旧任务上的性能。通过控制优化方向,使得其能够处于两个区域的交集部分,便代表其在旧任务与新任务上都有良好的性能。

具体方法为:将模型的后验概率拟合为一个高斯分布,其中均值为旧任务的权重,方差为 Fisher 信息矩阵(Fisher Information Matrix)的对角元素的倒数。方差就代表了每个权重的重要程度。

1. 基础知识

1.1 基本概念

  • 灾难性遗忘(Catastrophic Forgetting):在网络顺序训练多个任务的时候,对于先前任务的重要权重无法保留。灾难性遗忘是网络结构的必然特征
  • 持续学习:在顺序学习任务的时候,不忘记之前训练过的任务。根据任务A训练网络之后,再根据任务B训练同一个网络,此时对任务A进行测试,还可以维持其性能。

1.2 贝叶斯法则

P(AB)=P(AB)P(B)P(A|B) = \frac{P(A \cap B)}{P(B)}
P(BA)=P(AB)P(A)P(B|A) = \frac{P(A \cap B)}{P(A)}

P(AB)P(B)=P(BA)P(A)P(A|B)P(B)=P(B|A)P(A)
所以可以得到
P(BA)=P(AB)P(B)P(A)P(B|A) = P(A|B)\frac{P( B)}{P(A)}

2. Elastic Weight Consolidation

2.1 参数定义

  • θ\theta:网络的参数
  • θA\theta^*_A:对于任务A,网络训练得到的最优参数
  • DD:全体数据集
  • DAD_A:任务 A 的数据集
  • DBD_B:任务 B 的数据集
  • FF:Fisher 信息矩阵
  • HH:Hessian 矩阵

2.2 EWC 方法推导

对于网络来讲,给定数据集,目的是寻找一个最优的参数,即
P(θD)P(\theta|D)
根据贝叶斯准则
P(BA)=P(AB)P(B)P(A)P(B|A) = P(A|B)\frac{P( B)}{P(A)}
可以得到最大后验概率:
P(θD)=P(Dθ)P(θ)P(D)P(\theta|D) = P(D|\theta)\frac{P( \theta)}{P(D)}
于是可以得到
logP(θD)=log(P(Dθ)P(θ)P(D))=logP(Dθ)+logP(θ)logP(D)\log P(\theta|D) = \log (P(D|\theta)\frac{P( \theta)}{P(D)})=\log P(D|\theta) + \log P( \theta) - \log P(D)
也就是论文中的公式(1)

如果这是两个任务的顺序学习,旧任务为任务 A,新任务为任务 B,那么可以数据集 DD 可以划分为 DAD_ADBD_B,则
P(θDA,DB)=P(θ,DA,DB)P(DA,DB)=P(θ,DBDA)P(DA)P(DBDA)P(DA)=P(θ,DBDA)P(DBDA)P(\theta|D_A,D_B)=\frac{P(\theta,D_A,D_B)}{P(D_A,D_B)}=\frac{P(\theta,D_B|D_A)P(D_A)}{P(D_B|D_A)P(D_A)}=\frac{P(\theta,D_B|D_A)}{P(D_B|D_A)}
又因为
P(θ,DBDA)=P(θ,DB,DA)=P(θ,DA,DB)P(DA)=P(θ,DA,DB)P(θ,DA)P(θ,DA)P(DA)=P(DBθ,DA)P(θDA)P(\theta,D_B|D_A)=P(\theta,D_B,D_A)=\frac{P(\theta,D_A,D_B)}{P(D_A)}=\frac{P(\theta,D_A,D_B)}{P(\theta,D_A)} \cdot \frac{P(\theta,D_A)}{P(D_A)}=P(D_B|\theta,D_A)P(\theta|D_A)
所以,可以得到
P(θDA,DB)=P(θ,DBDA)P(DBDA)=P(DBθ,DA)P(θDA)P(DBDA)P(\theta|D_A,D_B)=\frac{P(\theta,D_B|D_A)}{P(D_B|D_A)}=\frac{P(D_B|\theta,D_A)P(\theta|D_A)}{P(D_B|D_A)}
又因为DAD_ADBD_B 独立,所以可以得到
P(DBDA)=P(DB)P(D_B|D_A)=P(D_B)
P(DBθ,DA)=P(DBθ)P(D_B|\theta,D_A)=P(D_B|\theta)
所以
P(θDA,DB)=P(DBθ)P(θDA)P(DB)P(\theta|D_A,D_B)=\frac{P(D_B|\theta)P(\theta|D_A)}{P(D_B)}
同样对于两边取 log,可以得到
logP(θD)=logP(θDA,DB)=logP(DBθ)+logP(θDA)logP(DB)\log P(\theta|D)=\log P(\theta|D_A,D_B)= \log P(D_B|\theta)+logP(\theta|D_A)-\log P(D_B)
这个便是论文中的公式(2),也是这篇论文的核心内容。

在给定整个数据集,我们需要得到一个 θ\theta 使得概率最大,那么也就是分别优化上式的右边三项。

第一项很明显可以理解为任务B的损失函数,将其命名为 LB(θ)L_B(\theta),第三项对于 θ\theta 来讲是一个常数,那么网络的优化目标便是
maxθlogP(θD)=maxθ(LB(θ)+logP(θDA)) \mathop{max}\limits_{\theta}\log P(\theta|D)=\mathop{max}\limits_{\theta}(-L_B(\theta)+\log P(\theta|D_A))

minθ(LB(θ)logP(θDA))\mathop{min}\limits_{\theta}(L_B(\theta)-\log P(\theta|D_A))
现在,重点变成了如何优化后验概率 logP(θDA)\log P(\theta|D_A) ,作者采用了拉普拉斯近似的方法进行量化。

3. 拉普拉斯近似

由于后验概率并不容易进行衡量,所以我们将其先验 logP(DAθ)\log P(D_A|\theta) 拟合为一个高斯分布

3.1 高斯分布拟合

令先验 logP(DAθ)\log P(D_A|\theta) 服从高斯分布
P(DAθ)N(μ,σ)P(D_A|\theta) \sim N(\mu, \sigma)
那么由高斯分布的公式可以得到
P(DAθ)=12πσe(θμ)22σ2P(D_A|\theta)=\frac{1}{\sqrt{2 \pi}\sigma} e^{-\frac{(\theta-\mu)^2}{2\sigma^2}}
那么,可以得到
logP(DAθ)=log12πσ(θμ)22σ2\log P(D_A|\theta)=\log \frac{1}{\sqrt{2 \pi}\sigma} -\frac{(\theta-\mu)^2}{2\sigma^2}

f(θ)=logP(DAθ)f(\theta)=\log P(D_A|\theta)
θ=θA\theta = \theta_A^* 处进行泰勒展开,可以得到
f(θA)=0f'(\theta_A^*)=0
f(θ)=f(θA)+f(θA)(θθA)+f(θA)(θθA)22+o(θA)f(\theta)=f(\theta_A^*)+f'(\theta_A^*)(\theta-\theta_A^*)+f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2}+o(\theta_A^*)
所以
log12πσ(θμ)22σ2f(θA)+f(θA)(θθA)22\log \frac{1}{\sqrt{2 \pi}\sigma} -\frac{(\theta-\mu)^2}{2\sigma^2}\approx f(\theta_A^*)+f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2}
其中,log12πσ\log \frac{1}{\sqrt{2 \pi}\sigma}f(θA)f(\theta_A^*) 都是常数,可以得到
(θμ)22σ2=f(θA)(θθA)22-\frac{(\theta-\mu)^2}{2\sigma^2}= f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2}
因此,可以得到
μ=θA\mu = \theta_A^*
σ2=1f(θA)\sigma^2=-\frac{1}{f''(\theta_A^*)}
所以,可以得到
P(DAθ)N(θA,1f(θA))P(D_A|\theta) \sim N(\theta_A^*, -\frac{1}{f''(\theta_A^*)})
根据贝叶斯准则,
P(θDA)=P(PA,θ)P(θ)P(A)P(\theta|D_A) = \frac{P(P_A,\theta)P(\theta)}{P(A)}
其中,P(θ)P(\theta) 符合均匀分布,P(DA)P(D_A) 为常数,所以
P(θDA)N(θA,1f(θA))P(\theta|D_A) \sim N(\theta_A^*, -\frac{1}{f''(\theta_A^*)})
此时,优化函数
minθ(LB(θ)logP(θDA))\mathop{min}\limits_{\theta}(L_B(\theta)-\log P(\theta|D_A))
可以变换为
minθ(LB(θ)f(θA)(θθA)22\mathop{min}\limits_{\theta}(L_B(\theta)- f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2}
对于一个batch来说,即为
minθ(LB(θ)ifi(θA)(θiθA,i)22\mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i f''_i(\theta_A^*)\frac{(\theta_i-\theta_{A,i}^*)^2}{2}
那么 f(θA)f''(\theta_A^*) 该如何求解呢?

3.2 Fisher Information Matrix

3.2.1 Fisher Information Matrix 的含义

Fisher information 是概率分布梯度的协方差。为了更好的说明Fisher Information matrix 的含义,这里定义一个得分函数 SS
S(θ)=logp(xθ)S(\theta)=\nabla \log p(x|\theta)

Ep(xθ)[S(θ)]=Ep(xθ)[logp(xθ)]=logp(xθ)p(xθ)dθ=p(xθ)p(xθ)p(xθ)dθ=p(xθ)dθ=p(xθ)dθ=1=0\begin{aligned} \mathop{E}\limits_{p(x|\theta)}[S(\theta)] &=\mathop{E}\limits_{p(x|\theta)}[\nabla \log p(x|\theta)] \\ &= \int \nabla \log p(x|\theta) \cdot p(x|\theta) d\theta \\ &= \int \frac{\nabla p(x|\theta)}{p(x|\theta)} \cdot p(x|\theta) d\theta \\ &= \int \nabla p(x|\theta) d\theta \\ &= \nabla \int p(x|\theta) d\theta \\ & = \nabla 1=0 \end{aligned}
那么 Fisher Information matrix FF
F=Ep(Xθ)[(S(θ)0)(S(θ)0)T]F = \mathop{E}\limits_{p(X|\theta)}[(S(\theta)-0)(S(\theta)-0)^T]
对于每一个batch的数据 X={x1,x2,,xn}X = \{x_1,x_2,\cdots ,x_n\},则其定义为
F=1Ni=1Nlogp(xiθ)logp(xiθ)TF = \frac{1}{N}\sum_{i=1}^N \nabla \log p(x_i|\theta) \nabla \log p(x_i|\theta)^T

3.2.2 Fisher 信息矩阵与 Hessian 矩阵

Hessian矩阵为
Hlogp(xθ)=J(logp(xtheta))=J(p(xθ)p(xθ))=Hp(xθ)p(xθ)p(xθ)p(xθ)Tp(xθ)p(xθ)=Hp(xθ)p(xθ)p(xθ)p(xθ)Tp(xθ)p(xθ)=Hp(xθ)p(xθ)(p(xθ)p(xθ))(p(xθ)p(xθ))T\begin{aligned} H_{\log p(x|\theta)} &= J(\nabla \log p(x|theta)) = J(\frac{ \nabla p(x|\theta)}{ p(x|\theta)}) \\ &= \frac{H_{ p(x|\theta)} p(x|\theta)- \nabla p(x|\theta) \nabla p(x|\theta)^T}{ p(x|\theta) p(x|\theta)} \\ &= \frac{H_{ p(x|\theta)}}{ p(x|\theta)}-\frac{ \nabla p(x|\theta) \nabla p(x|\theta)^T}{ p(x|\theta) p(x|\theta)} \\ &= \frac{H_{ p(x|\theta)}}{ p(x|\theta)}-(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })^T \end{aligned}
Fisher 信息阵为
Ep(xθ)[Hlogp(xθ)]=Ep(xθ)[Hp(xθ)p(xθ)]Ep(xθ)[(p(xθ)p(xθ))(p(xθ)p(xθ))T]=Hp(xθ)p(xθ)p(xθ)dθEp(xθ)[logp(xθ)logp(xθ)T]=Hp(xθ)p(xθ)p(xθ)dθEp(xθ)[(S(θ)0)(S(θ)0)T]=Hp(xθ)dθF=H1F=0F=F\begin{aligned} \mathop{E}\limits_{p(x|\theta)}[H_{\log p(x|\theta)}] &= \mathop{E}\limits_{p(x|\theta)}[\frac{H_{ p(x|\theta)}}{ p(x|\theta)}]- \mathop{E}\limits_{p(x|\theta)}[(\frac{ \nabla p(x|\theta)}{p(x|\theta) })(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })^T] \\ &= \int \frac{H_{ p(x|\theta)}}{ p(x|\theta)} p(x|\theta) d\theta - \mathop{E}\limits_{p(x|\theta)}[\nabla \log p(x|\theta) \nabla \log p(x|\theta)^T] \\ & = \int \frac{H_{ p(x|\theta)}}{ p(x|\theta)} p(x|\theta) d\theta - \mathop{E}\limits_{p(x|\theta)}[(S(\theta)-0)(S(\theta)-0)^T] \\ &= \int {H_{ p(x|\theta)}} d\theta -F \\ &= H_1 -F = 0-F \\ &=-F \end{aligned}
所以,Fisher 信息矩阵是 Hessian 矩阵的负期望。

因为 f(x)f''(x)HH 的对角线元素,所以 1f(x)-\frac{1}{f''(x)}FF对角线元素的倒数。
所以,损失函数
minθ(LB(θ)ifi(θA)(θiθA,i)22\mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i f''_i(\theta_A^*)\frac{(\theta_i-\theta_{A,i}^*)^2}{2}
可以变为
minθ(LB(θ)iFi(θiθA,i)22\mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i F_i\frac{(\theta_i-\theta_{A,i}^*)^2}{2}
引入超参 λ\lambda 衡量两项的重要程度,可以得到最终的损失
minθ(LB(θ)λ2iFi(θiθA,i)2\mathop{min}\limits_{\theta}(L_B(\theta)- \frac{\lambda}{2}\sum_i F_i(\theta_i-\theta_{A,i}^*)^2
上式即为论文中的公式(3)
到此,论文的核心内容就已经结束了,后面的应用及实验结果在此不再展示。