论文阅读笔记《TransMatch: A Transfer-Learning Scheme for Semi-Supervised Few-Shot Learning》
核心思想
本文提出一种基于迁移学习的半监督小样本学习算法(TransMatch)。整个算法并不复杂,首先利用带有标签的基础数据集训练特征提取网络,然后用该特征提取网络为新的数据集初始化分类器权重,最后用半监督学习的方式进一步更新整个网络。整个流程如下图所示
第一阶段:预训练阶段。这一阶段没有什么值得介绍的,就是用带有标签的基础数据集对特征提取器进行训练。
第二阶段:分类器权重“生成”阶段(Classifier Weight Imprinting)。在这一阶段,使用已经预训练好的特征提取网络,对新的带有标签的数据集进行特征提取,并生成对应的分类器权重。本文采用一种叫做Weight Imprinting的方法来生成分类器的权重,方法如下
式中表示类别对应的分类器权重,表示特征提取网络,表示类别中第个样本。通过上式得到每个类别对应的权重后,再通过计算余弦距离的方式进行分类
式中对于样本,分别计算其与个类别权重之间的余弦相似度,并选择余弦相似度最高的哪一类作为预测结果。
第三阶段:半监督微调训练阶段。在这一阶段采用新的带有标签的数据集和与类别相同但不带有标签的数据集,共同对网络进行微调训练。本文采用MixMatch的方式进行半监督训练,定义表示个带有标签的样本,表示个不带有标签的样本。首先对每个无标签的样本进行数据扩充(应该采用的是常规的翻转,放缩等形式)得到个合成样本,然后用第二阶段训练得到的分类器对每个无标签样本进行预测,并取个合成样本的平均值作为预测结果
锐化操作(sharpen operation)用于进一步增强预测结果
其中,这样就得到了无标签样本对应的标签信息了。将数据集级联后,再将顺序打乱,得到新的混合数据集,然后将其分为以下两个集合
其中混合操作MixUP计算过程如下
式中,是从Beta分布中随机生成的。
实现过程
网络结构
特征提取网络采用宽阔的残差网络WRN-28-10。
损失函数
损失函数计算过程如下
其中
训练策略
本文的训练过程如下
创新点
- 采用基于迁移学习的半监督训练方法实现小样本学习任务
- 采用Weight Imprinting的方式进行分类器权重生成,采用MixUp方式进行半监督训练
算法评价
与之前研究较多的采用元学习的小样本学习方法不同,本文沿用了更为传统的迁移学习思想,并结合半监督学习方式,证明了迁移学习还是能够在小样本场景下取得较好的效果的。但本文核心创新点并不多,有一种拼凑的感觉。无论是Weight Imprinting分类器权重生成还是MixUp半监督训练方法都是借鉴了别人的方案。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。