论文阅读笔记《Incremental Few-Shot Object Detection》
核心思想
本文提出一种可以进行增量式学习的小样本目标检测算法(ONCE),相较于其他的小样本目标检测算法而言,本文的优势在于,当在基础数据集上训练完成后,可以直接使用新的小样本数据集进行推断,而且这个过程不会忘记基础数据集中的内容。本文的主体网络采用了CentreNet(伯克利的这篇《Objects as points》)的结构,使用该网络的原因有两点:1.该算法是一种高效的单阶段目标检测算法,在准确率与速度之间达到了较好的平衡;2.该算法是一种“类别专用化”(class-specific)的模型,能够很容易地采用插入式的方式引入新的类别。整个网络的结构如下图所示
首先CentreNet的思路是先对图像进行的特征提取,得到特征图,共包含个通道,对应目标检测任务中的个类别;然后进入目标定位器,输出热力图预测结果,预测结果共包含个通道,个通道对应个类别,剩余的四个通道分别对应边界框中心点偏移量的预测结果和边界框尺寸的预测结果。第个通道的热力图中每个像素点的值表示该点是第类目标物体中心点的概率,如果某个点的值高于与他相邻的8个点的值,那这个点就称为峰值点,可以看作是目标物体的中心,再结合其他四个通道预测的偏移量和边界框尺寸,就能得到对应的边界框了。
本文的算法基本沿用了CentreNet的思想,首先利用特征提取网络,得到特征图,然后利用目标定位器,得到热力图
式中表示卷积核的权重,表示卷积运算,表示类别序号,与CentreNet中通过端到端训练的方式获得权重参数的方式不同,本文利用一个编码生成器,输出该权重参数
而该权重参数是与类别相关的,也就是每个类别都有对应的用于定位器中的计算。最后再按照CentreNet相同的方式预测类别与边界框位置。
实现过程
网络结构
特征提取网络采用ResNet作为编码器,用反卷积网络作为解码器;编码生成器采用与特征提取网络相同的编码器结构,但不带有解码器,最后通过平均池化的方式输出权重。
训练策略
整个训练过程分为两个阶段:第一阶段是在基础数据集上,按照CentreNet的方式进行训练,获得特征提取网络的参数;第二阶段,在小样本数据集上,采用元训练的方式,获取编码生成器网络的参数。如图中所示,查询集图片中包含猫和狗两类物体,对应的支持集图像就应该包含猫和狗两类图片,利用支持集图像获得每个类别对应的权重参数,对查询集图像经特征提取得到特征图进行卷积,得到预测结果。因为经过元训练之后,编码生成器具备了根据少量样本生成对应权重的能力,而且在基础数据集上训练得到的目标定位器权重并没有丢失,因此本算法可以实现增量式的小样本学习。
算法推广
本文算法还可用于服装关键点检测(fashion landmark detection)任务
创新点
- 引入了CentreNet的结构和思想,并将其拆分为特征提取和目标定位两个部分
- 采用元学习的方式训练编码生成器,针对每个类别的图像输出对应的权重,并利用该权重完成测试图像的目标检测工作
算法评价
本文其实是在CentreNet的基础进行改进的结果,利用一个编码生成器生成目标定位器中的卷积核权重,取代原本通过迭代训练更新权重的方式。这是一种快权重算法,权重值不是利用SGD进行更新,而是由另一个网络直接输出。这在小样本学习中有着广泛的应用,作者将其与优秀的目标检测算法CentreNet相结合就得到了一种小样本目标检测算法。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。