Few-shot Learning with Meta Metric Learners

一、介绍

现有的基于元学习、度量学习的小样本学习方法在处理diverse domains和various classes上存在局限。元学习训练一个meta learner预测具有相同结构,但针对不同任务网络的权重。度量学习针对不同任务学习一个不随任务改变,适应所有任务的度量。当任务间差异较大时,度量学习将会失败,学不到这样的度量。作者提出了一个元度量学习的方法,利用度量学习的匹配网络作为base learner来处理不同任务类别数不同的问题,使用meta learner来给base learner选择好的参数以及梯度下降(相当于可学习的学习率)。

二、背景知识

小样本学习是指对仅有少量样本的新的类别学习一个分类器。两个关键的想法是数据聚合和知识共享(data aggregation and knowledge sharing)。虽然每个小样本学习任务缺乏足够的训练数据,但是所有任务的集合能够给模型训练提供足够的数据。因此,对于新的小样本学习任务,可以从之前学习过的任务中获利。

小样本学习的具体任务:k-shot, N-way场景:每个小样本任务有N类,每类有k个样本。作者指出,不同任务有相同数目的类别实际上是不太实际的。之前基于元学习的小样本方法都不能解决这个问题。而能够解决这个问题的度量学习方法试图学习对任务不变的度量,当任务差异较大时,将会失败。作者的方法很好的解决了这个问题。

2.1 匹配网络(类似knn)

匹配网络是最近提出的在cv领域处理小样本问题的模型。包含两个共享权重的神经网络f和g,以及一个记忆矩阵。f和g把输入映射为N维向量,然后和记忆矩阵各向量计算余弦距离,然后再进行一个softmax(对应类别的样本距离趋近1,其他类别样本距离趋近0),表示和记忆矩阵中样本属于同一类的概率。和记忆矩阵样本的标签(1,2,3,...,k)加权求和后作为预测标签。

训练:从一个正常数据集中取出N个id,然后每个id随机采样k个样本组成support set (经过g提取特征后作为记忆矩阵)。再采样一些样本组成训练集合B。然后优化下式的目标函数:

Few-shot Learning with Meta Metric Learners

2.2 Lstm based meta learning(meta learning, also known as learning to learn)

标准的优化算法:

Few-shot Learning with Meta Metric Learners

LSTM迭代公式:

Few-shot Learning with Meta Metric Learners

 可以看到,二者非常相似。把C_t作为网络权重参数,C_t+1~作为网络的梯度,可以用LSTM来知道网络进行更新(学习)。因此LSTM可以作为meta learner。

三、作者的方法

 

Few-shot Learning with Meta Metric Learners

如上图,作者的模型包括两个部分。meta learner是一个LSTM,在每个任务中给base learner发现好的参数和梯度下降。base leanrner是一个匹配网络,使用LSTM提供的参数进行参数设置。因此这个模型能够处理不同任务的不同数目的类别(匹配网络的作用)而且能够产生针对不同任务的度量(因为meta-learner在给不同的任务实例后进行相应的权重预测)。

Few-shot Learning with Meta Metric Learners

把数据集分成D_meta-train、D_meta-test、D_meta-val三部分,其中D_meta-val用来调整网络超参。利用D_meta-train训练meta learner,具体过程是:从 D_meta-train采样train、test集合用于训练matching网络,train作为support set,test从中采样batch计算loss,得到梯度和原始matching网络参数送入LSTM,使用LSTM网络输出作为matching网络新的参数。matching网络更新完毕后,使用train和test所有样本计算loss,更新LSTM网络。

meta-learner训练完成后,使用meta-learner指导base learner的训练。将D_meta-test样本选取一部分作为matching网络的support set(至少N个),剩余部分用来训练matching网络。作者指出,这就要求每类至少有两个样本,和one shot的设置相违背。针对这种情况,作者设计了一种通过D_aux训练,测试时使用D_meta-test作为matching网络的support set的方法。即使用不同任务的数据训练多个matching网络,选择matching网络在D_meta-test样本准确率最高的网络的数据组成D_aux。