《Incremental Classifier Learning with Generative Adversarial Networks》 阅读笔记

原文链接: Incremental Classifier Learning with Generative Adversarial Networks

本文主要是解决增量学习中灾难性遗忘的问题,文中指出灾难性遗忘主要是因为过去的数据在训练时得不到。通常的方法是先保留一部分过去类别的样本,然后配合蒸馏正则化来平衡新旧类别之间的关系。但是这些方法主要有四个问题
1)Loss 函数的设计对分类器来说都不是很有效
2)新旧类别样本不均衡
3)预先保留的样本量有限,可能与新类别差别很小
4)隐私保护不允许保留过去的样本

本文针对以上问题提出了对应的解决方案

1)使用了新的 Loss 函数,其中结合了交叉熵与蒸馏损失函数
2)用简单方法估计并消除新旧类样本的不均衡
3)用 GANs 来生成历史数据,并在生成时选择特征
4)用 GANs 生成不涉及隐私问题,因为它不时直接拷贝任何的图像

2.相关工作

在深度学习之前,人们利用线性分类器,弱分类器组合,最近邻分类器等来研究增量学习。随着深度学习的发展,出现了很多基于深度学习的模型,这些工作根据它们是否需要之前的数据而分为两类:

a) 第一类是不需要任何过去数据的。 [9] 提出了适用于域迁移学习的方法,它们固定了旧任务的输出层,并保持特征提取部分的共享权重不变。[10] 通过约束重要的权重的变化,训练新类的时候让它们仍接近于原先的数值,这种方法有一个问题就是在新老任务共享的参数部分会发生冲突。[13] 则是使用知识蒸馏[8] 的方法来保持模型在旧任务上的性能。[24] 也是用了知识蒸馏,但是它是用在了目标检测领域而不是分类。本文的方法属于这一种,虽然现在GANs 的生成图片性能还不是很好,但是本文通过实验证明,使用 GANs 还是比其他同样不需要加入旧样本的方法要好
b) 第二类是需要过去的数据。 [22][14][25]等,有的方法是需要全部旧数据,有的方法则需要部分旧数据,但是这一类的方法无疑都面临,现实中由于隐私保护等问题可得不到原始任务的数据的情况,使用条件有限。而用 GANs 生成是一个很好的解决方法。

3.增量学习

虽然增量学习严格的定义并不使用过去的数据,但是 iCaRL[22] 证明再训练中使用少量过去数据更有效。

目前增量学习面临2大主要挑战:
i)  需要保持在过去 n 类任务上的性能
ii) 平衡好新类与旧类

本文主要是对标 LwF[13] 和 iCaRL[22] 两个方法。

《Incremental Classifier Learning with Generative Adversarial Networks》 阅读笔记

LwF 使用了蒸馏来解决挑战 i),并使用权重衰减(weight decay)来解决挑战 ii),但是有一个问题就是在使用蒸馏无法保证新学习后的分类器在旧数据上可以得到和旧分类器在旧数据上一样的预测结果,而且找到一个合适的 weight decay 很困难。这两个问题可以通过选择少量的旧数据集的真实样例而大大改善。

iCaRL 则是通过从旧数据集中仔细的筛选样例来解决挑战 i),对每一类使用二值熵损失函数来解决挑战 ii)。但是在类与类之间,二值熵损失函数没有交叉熵损失函数有效。

《Incremental Classifier Learning with Generative Adversarial Networks》 阅读笔记

损失函数

所以本文通过结合 LwF 和 iCaRL 的长处提出了一个新的损失函数(交叉熵损失函数 + 蒸馏损失函数),这里使用了一个系数 λ 来平衡二者

《Incremental Classifier Learning with Generative Adversarial Networks》 阅读笔记

消除偏差

一般情况下因为新增的 m 类样本量远大于 n 类旧样本量(这里可能也是指用 GANs 生成的少量样本),这就会发导致新训练出来的分类器结果偏向新的 m 类,本文就是使用一个 0-1 之间的常量系数 β 乘以新加入的 m 类的最后的分类结果(概率值)来平衡二者。

《Incremental Classifier Learning with Generative Adversarial Networks》 阅读笔记

实验证明本文这种方法优于在 CIFAR-100 上 iCaRL ,并对选取合适的 β 做了大量实验。

4. 用 GANs 生成样本

人们手工选择样本的一个问题是,选择 exemplars 主要缺点是数据之间的分布可能不同,GANs 可以学习数据分布。

首先使用所有旧 n 类的样本来训练一个生成器 xg = G(z),这里的 z 是随机噪声,然后标签就直接用分离器的结果来生成图像。
文中使用了最普通的 GAN 而不是 conditional GANs[16],因为[16] 不适用于类别数很多或者每个类别样本量很少的情况[19]。 文中使用的是 DCGANs[21]作为生成器,并使用 WGANs[2] 中的 earth-mover distance 来作为损失函数。

最后设定了一个阈值,然后在阈值 θ 内选在 n 类上满足最大似然估计的样本

5. 实验

后面文中给出了大量的实验,这里就不具体说明了,感兴趣的同学可以移步至原论文深入阅读。