Deep Cocktail Network
Deep Cocktail Network
1.motivation
domain adaptation是由于获得大量的标注是一件耗时的工作,希望能通过利用已经有标注的source数据集来提升网络在没有标注的target数据集上的表现.本文的出发点是希望使用多个source数据集来进行domain adaptation。multi-source domain adaptation作者认为主要存在两个问题**1)domain shift:**包括源域和目标域,以及不同目标域之间;2) category-shift: 目标域的标签不一定完全一样。受目标域的概率分布可以由多个源域概率分布加权来表示,作者提出了deep cocktail network(DCTN)来解决上述问题。
2.method
(1) overview
网络结构主要包含三个部分 Feature Extractor(共享参数)将图像空间映射到特征空间;Domain Discrinator,一共有N个,N为源域的个数,即每一个源域都要和穆标宇训练一个分类器;Category Classifier,也是N个,每一个源域有一个category分类器。
以及最后的target classification operator,利用前面N个category 分类器来的分类loss进行weighted combination.
(2)Feature extractor
F为特征提取,将所有的图片映射到特征空间F(x), x表示输入图片
(3)Domain Discriminator
表示N个discriminator,用来区分F(x)是来自源域还是目标域
其中是source-specific concentration constant表示第j个discriminator在源域的平均值。是target-source perplexity scores 就是分类的loss,用作后续的加权平均, 作用是相当于进行了normalize.
(4)category classifier
(5)Target classification operator
3. training
其中
让网络学习分类误差最小,discriminator误差最大这样让source domain和target domain在feature space上混淆了,这样domain shift就变小了
**Online hard domain batch mining **
一共有N个源域,每个源域sample M个,一共N*M个样本。对N个discriminator中最大的进行更新
Target Discriminative Adaptation
作者认为荣国multi-way adversary,DCTN已经能学习到domain-invariant的特征,但是在target domain的分类能力不行。
为了逼近理想的target分类器,给target domain中的每一个样本打上pseudo labels(就是用之前的网络进行inference),然后在将target domain的数据和source domain的数据联合训练
因为没有针对target训练一个分类器,所以将target classification error反传到multi source的category classifier. 具体来说,对target的样本 对源域中含有 类别的计算对应的分类loss,求和。
4 experiment
- single best:对每对源域-目标域训练 选择最好的结果
- source combine:将多个源域合并成单个domain
为了对比对category shift的效果设计了两种不同的category模式overlap(source类别之间是有交集的) disjoint(类别之间完全是没有交集的)
可以看到DTCN保证了domain的相似性以及类别之间的区分性