CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation

题目:用于无监督域自适应的可迁移原型网络
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
作者:潘滢炜个人主页 京东AI研究院
出处:CVPR2019 Oral arxiv1904.11227v1
源码:暂无

动机:

域自适应的关键是如何减小source domain和target domain的特征的差异。现有的方法直接对齐source domain和target domain的整体分布,忽略了细粒度的domain gap。本文提出在类别和样本的粒度上减小域差异。

核心点/摘要:

本文提出了基于原型网络的无监督域自适应算法——用于自适应的可转移原型网络(TPN),使得源域和目标域中每个类的原型在嵌入空间上接近,并且每个类的原型在源域和目标域的数据上预测的得分分布相似。
技术上,TPN先将每个目标域的样本与源域中最近的原型进行匹配,并为样本分配一个伪标签。然后分别在仅源数据、仅目标数据和源域+目标域数据上计算每个类的原型。TPN是通过联合最小化原型在三种数据类型上的距离以及每对原型输出的分数分布的KL散度来进行端到端训练的。
在跨MNIST、USPS和SVHN数据集的迁移上进行了广泛的实验,当与最先进的方法比较时,显示了优越的结果。更值得注意的是,我们在VisDA 2017数据集上获得了单模型的80.4%的准确率。

数据集:

MNIST、USPS【手写数字图像】、SVHN【在谷歌街景的真实世界的房子号码图像】、VisDA 2017【合成到真实对象分类数据集】

主要贡献:

把原型网络应用到无监督域自适应的场景——通过减少多粒度(即类级和样本级)域差异来对齐源域和目标域的分布,使原型网络能够跨域迁移。

相关工作:

1、UDA

  • 对齐源域和目标域之间的数据分布,或者通过相关距离或最大平均误差等度量方式来最小化domain shifts来构建域之间的不变性
    参考文献:[31]、[15]、[17]、[16]、[27]、[34]
  • 学习域鉴别器
    域鉴别器被设计用来预测每个输入样本的域(源/目标),并以一种对抗的方式进行训练,类似于GANs,以学习域不变表示。
    参考文献:[4,14,29,30,35]、Coupled GANs[13] 、[32]

2、原型网络
假设存在一个嵌入空间,在这个空间中,每个类中的样本的投影围绕一个原型(或质心)聚类。然后通过计算每个类在嵌入空间中相对于原型表示的距离来进行分类。
参考文献:[26]

模型框架:

具体步骤:
目标是学习内嵌函数 f ( x i ; θ ) : x i → R m f(x_i;θ):x_i→R^m f(xi;θ):xiRm(把输入样本转换到嵌入空间),它形式化地减少了共享特征空间中的域移位,并使学习依赖于源域和目标域数据的可转移表示和分类器成为可能。

  1. 在源域数据上训练分类器(原型 μ c s μ_c^s μcs),并直接为目标域数据预测一个伪标签(得分>0.6)
  2. 构建在target-only和source-target数据上的分类器(原型 μ c t 、 μ c s t μ_c^t 、μ_c^{st} μctμcst
  3. 最小化来自不同域的同一类原型( μ c s 、 μ c t 、 μ c s t μ_c^s 、μ_c^t 、μ_c^{st} μcsμctμcst)之间的距离 (class-level) Loss1
  4. 样本通过 μ c s 、 μ c t 、 μ c s t μ_c^s 、μ_c^t 、μ_c^{st} μcsμctμcst 分别得到分数分布: P i s 、 P i t 、 P i s t P_i^s 、P_i^t 、P_i^{st} PisPitPist,最小化所有样本预测分数分布的KL散度。(sample-level) Loss2
  5. 最小化整体损失:Loss1 + Loss2 + 源域数据的分类损失
  6. 交替更新2-5步。
  7. 测试时,可以采用 μ c s 、 μ c t 、 μ c s t μ_c^s 、μ_c^t 、μ_c^{st} μcsμctμcst任一种原型对目标域数据进行分类

(1)源域图像分类:【步骤1】

每个类的原型定义:
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation S c S_c Sc:来自类c的样本集合

query sample x i x_i xi属于类别c的概率: P i ∈ R c P_i∈R^c PiRc
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptationd(·)表示样本到原型的距离函数(如Euclidean距离)

训练目标:最小化把样本 x i x_i xi分配正确的类标签c的负对数似然概率
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation

本文是在原型网络的框架下探索general-purpose adaptation和task-specific domain adaptation:

(2)general-purpose adaptation:【步骤2、3】 class-level

  • 计算分类器:
    source-only data ( S s ) (S^s) (Ss)、 target-only data ( S t ) (S^t) (St) 、source-target data ( S s t ) (S^st) (Sst)的分类器:(原型 μ c s 、 μ c t 、 μ c s t μ_c^s 、μ_c^t 、μ_c^{st} μcsμctμcst
    CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation

  • 度量域间的类级域差异:
    计算在来自不同域的同一类原型之间的reproducing kernel Hilbert space (RKHS) 距离。【基本思想是,如果源域和目标域的数据分布相同,那么在不同域上实现的相同类的原型是相同的。】
    Class-level discrepancy loss:
    CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
    通过最小化这一项,每个领域中计算的每个类的原型将在嵌入空间中紧密接近,从而导致在一般情况下在各个领域之间的不变表示分布。

  • 和MMD的关系
    MMD是一个核双样本检验,它通过将源数据映射到一个再生核希尔伯特空间来测量源数据和目标数据之间的分布差异。
    MMD的计算如下:
    CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
    MMD式子可以解释为RKHS中各个领域的整体原型;类级域差异被计算为来自不同域的每个类的原型之间的RKHS距离。

(3)task-specific domain adaptation:【步骤4】 sample-level

当目标域和源域的分布很好地对齐时,模型应该对目标域数据进行正确的分类。在原型网络中,相当于在不同域中适应原型产生的分数分布。

query sample x i x_i xi 通过三个分类器分别得到三个分数分布: P i s 、 P i t 、 P i s t P_i^s 、P_i^t 、P_i^{st} PisPitPist

为了衡量样本水平的区域差异,我们利用KL散度来评估不同域得分分布之间的两两距离。

在源域和目标域样本上的sample-level discrepancy loss:
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
(4)整体目标
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
(5)理论分析

实验结果:

Digits Image Transfer:M-> U, U-> M and S-> M

Synthetic-to-Real Image Transfer:源域:训练数据(synthetic images)、目标域: validation data (cropped COCO images)

度量指标:所有类别的平均准确率

CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation
S->M的难度较大

实验分析:

t-SNE可视化、Confusion Matrix可视化
CVPR2019 Transferrable Prototypical Networks for Unsupervised Domain AdaptationCVPR2019 Transferrable Prototypical Networks for Unsupervised Domain Adaptation

个人总结:

从实验上来看减小 T P N g e n TPN_{gen} TPNgen方法更有效。

每次迭代更新时都需要重新预测出所有目标域数据的伪标签,计算代价大。

这篇文章的参考文献都可以读一读。