Matching Networks for One Shot Learning论文解读

这篇文章在元学习领域笔记重要,之前一直想读,这次正好有机会就把它给刷了。

本篇论文属于小样本学习领域,但是本篇论文中的Matching Networks常被用于与Meta-learning任务中的方法进行比较。这篇论文出自Google DeepMind团队,发表于2016年。

1 Motivation

人类可以可以通过非常少量的样本学习到一个新的概念,比如一个小孩子看完一张长颈鹿的照片之后就认识了长颈鹿这个动物。但是最好的深度学习模型依然需要成百上千的例子来学习到一个新的概念。因此本文就考虑如何通过一个样本就让深度学习模型学会一个新概念。

传统上训练出一个模型需要使用很多样本进行很多次的参数更新,因此作者认为可以使用一个无参数的模型。参考KNN这种度量式的做法,作者将有参数的模型和无参数的模型进行了结合。

2 Contribution

  1. 在模型层面上,作者提出了一个Matching Networks, 将注意力机制和记忆机制引入快速学习任务中。
  2. 在训练流程上,作者训练模型时遵循了一个很简单的规则,即测试和训练条件必须匹配。作者在训练时仅用每个类别中很少的样本进行训练,因为在测试时也使用的是很少的样本。(即训练条件和测试条件匹配)

3 Method

3.1 Model Architecture

Matching Networks for One Shot Learning论文解读
gθg_{\theta}fθf_\theta分别是对训练数据和测试数据的编码函数。Matching Networks可以简洁表示为计算一个无标签样本的标签为y^\hat{y}的概率,这个计算方法跟KNN很像,相当于是加权后的KNN:
P(y^x^,S)=i=1ka(x^,xi)yi P(\hat{y}|\hat{x},S) = \sum^{k}_{i=1}a(\hat{x},x_i)y_i
其中xi,yix_i,y_i是输入的支撑集(support set)中的样本S={(xi,yi)}i=1kS = \{(x_i,y_i)\}^k_{i=1}aa类似于注意力机制中的核函数,用来度量x^xi\hat{x},x_i的匹配度。
a(x^,xi)=ec(f(x^),g(xi))j=1kec(f(x^),g(xj)) a(\hat{x},x_i) = \frac{e^{c(f(\hat{x}),g(x_i))}}{\sum^k_{j=1}e^{c(f(\hat{x}),g(x_j))}}
在这里公式ff定义了对测试样本的编码方式,对于Figure 1 中的gθg_{\theta};公式gg定义了对训练样本的编码方式,对应于Figure 1 中的fθf_\theta。这个公式先对f(x^),g(xi)f(\hat{x}),g(x_i)计算了一个余弦距离,然后在做一个softmax归一化。

3.2 Training Function g

gg是一个BiLSTM,它的输入是xix_i和支撑集SS
g(xi,S)=hi+hi+g(xi) g(x_i, S) = \overrightarrow{h_i} + \overleftarrow{h_i}+ g'(x_i)

Matching Networks for One Shot Learning论文解读
其中$g'(x_i)$是一个神经网络,比如VGG或者Inception。

3.3 Test Function f

ff是一个迭代了K步的 LSTM,它的输出是LSTM最后输出的隐状态hh。即f(x^,S)=hkf(\hat{x},S)=h_k,其中hkh_k由(3)式决定:

Matching Networks for One Shot Learning论文解读
其中,ff'是一个embedding函数,比如一个CNN。

3.4 Training procedure

给定一个有k个样本的支撑集S={(xi,yi)}i=1kS = \{(x_i,y_i)\}^k_{i=1},对测试样本 x^\hat{x}分类为 CS(x^)C_S(\hat{x})。定义SCS(x^)S \rightarrow C_S(\hat{x}) 这一映射为P(y^x^,S)=i=1ka(x^,xi)yiP(\hat{y}|\hat{x},S) = \sum^{k}_{i=1}a(\hat{x},x_i)y_i

在测试过程中,给定一个新的支撑集SS',我们可以用之前学到的模型对每个测试样本x^\hat{x}得到他们可能的label y^\hat{y}

4 Experimental results

4.1 Omniglot dataset

Matching Networks for One Shot Learning论文解读
Omniglot 数据集包含来自 50个不同国家的字母表的 1623 个不同手写字符。每一个字符都是由 20个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。

Matching Networks for One Shot Learning论文解读

4.2 ImageNet dataset

Matching Networks for One Shot Learning论文解读
作者一共在ImageNet数据集上做了三组实验:

  • In the rand setup:在训练集中随机去除了118个label的样本,并将这118个标签的样本用于之后的测试。
  • For the dogs setup:移除了所有属于狗这一大类的样本(一共118个子类),之后用这118个狗的子类样本做测试.
  • 作者还新定义了一个数据集 miniImageNet —— 一共有100个类别,每个类有600个样本。其中80个类用于训练20个类用于测试。

实验结果为:
Matching Networks for One Shot Learning论文解读

4.3 Penn Treebank dataset

这个Pennn Treebank数据集来自华尔街日报。作者利用数据集做了一个one-shot Language Model的实验,利用上下文来预测中间词。作者通过query集中与support集中两个句子的比较来确定中间词。如下图所示。
Matching Networks for One Shot Learning论文解读
但是实验结果并不理想。

The LSTM language model oracle achieved an upper bound of 72.8% accuracy on the test set. Matching Networks with a simple encoding model achieve 32.4%, 36.1%, 38.2% accuracy on the task with k = 1, 2, 3 examples in the set, respectively. Future work should explore combining parametric models such as an LSTM-LM with non-parametric components such as the Matching Networks explored here.

作者只是在这里提供了一个Matching Network用于语言模型的思路。

5 Conclusions

作者在本文中引入了Matching Networks并在小样本学习任务上取得了很不错的效果。作者还在ImageNert数据集上定义了一个one-shot任务,此后ImageNet数据集成为了Meta-Learning的标准数据集。同时作者启发性地将one-shot任务应用于语言模型,为后续研究提供了一个很好的思路。

References

[1] 知乎文章
[2] github 文章
[3] 论文原文