Deep Cocktail Network

Deep Cocktail Network

Deep Cocktail Network

1.motivation

domain adaptation是由于获得大量的标注是一件耗时的工作,希望能通过利用已经有标注的source数据集来提升网络在没有标注的target数据集上的表现.本文的出发点是希望使用多个source数据集来进行domain adaptation。multi-source domain adaptation作者认为主要存在两个问题**1)domain shift:**包括源域和目标域,以及不同目标域之间;2) category-shift: 目标域的标签不一定完全一样。受目标域的概率分布可以由多个源域概率分布加权来表示,作者提出了deep cocktail network(DCTN)来解决上述问题。

Deep Cocktail Network

2.method

Deep Cocktail Network

(1) overview

网络结构主要包含三个部分 Feature Extractor(共享参数)将图像空间映射到特征空间;Domain Discrinator,一共有N个,N为源域的个数,即每一个源域都要和穆标宇训练一个分类器;Category Classifier,也是N个,每一个源域有一个category分类器。

以及最后的target classification operator,利用前面N个category 分类器来的分类loss进行weighted combination.

(2)Feature extractor

F为特征提取,将所有的图片映射到特征空间F(x), x表示输入图片

(3)Domain Discriminator

{Dsj}j=1N\{D_{s_j}\}_{j=1}^N表示N个discriminator,DsjD_{s_j}用来区分F(x)是来自源域还是目标域

Scf(xt;F,Dsj)=log(1Dsj(F(xt)))+αsjS_{cf}(x^t;F,D_{s_j}) =-log(1-D_{s_j}(F(x^t))) + \alpha_{s_j}

其中αsj\alpha_{s_j}是source-specific concentration constant表示第j个discriminator在源域XsjX_{sj}的平均值。Scf(xt;F,Dsj)S_{cf}(x^t;F, D_{s_j})是target-source perplexity scores 就是分类的loss,用作后续的加权平均,α\alpha 作用是相当于进行了normalize.

(4)category classifier

(5)Target classification operator
Confidence(cxt):=cCsjScf(xt;F,Dsj)cCskScf(xt;F,Dsk)Csj(cF(xt)) Confidence(c|x^t):=\sum_{c\in C_{s_j}} \frac{S_{cf}(x^t;F,D_{s_j})}{\sum_{c\in C_{s_k}}S_{cf}(x^t;F, D_{s_k})}C_{s_j}(c|F(x^t))

3. training

minFmaxDV(F,D;Cˉ)=Ladv(F,D)+Lcls(F,Cˉ) min_{F}max_{D} V(F, D;\bar{C}) = L_{adv}(F, D) + L_{cls}(F,\bar{C})

其中
Ladv(F,D)=1NjNExXsj[logDsj(F(x))]+ExtXt[log(1Dsj(F(xt))] L_{adv}(F, D) = \frac{1}{N}\sum_j^N E_{x\sim X_{s_j}}[log D_{s_j}(F(x))]+E_{x^t\sim X_t}[log(1-D_{s_j}(F(x^t))]
让网络学习分类误差最小,discriminator误差最大这样让source domain和target domain在feature space上混淆了,这样domain shift就变小了

**Online hard domain batch mining **

Deep Cocktail Network

一共有N个源域,每个源域sample M个,一共N*M个样本。对N个discriminator中最大的进行更新

Target Discriminative Adaptation

作者认为荣国multi-way adversary,DCTN已经能学习到domain-invariant的特征,但是在target domain的分类能力不行。

为了逼近理想的target分类器,给target domain中的每一个样本打上pseudo labels(就是用之前的网络进行inference),然后在将target domain的数据和source domain的数据联合训练

Deep Cocktail Network

因为没有针对target训练一个分类器,所以将target classification error反传到multi source的category classifier. 具体来说,对target的样本(xt,y^)(x^t, \hat{y}) 对源域中含有y^\hat{y} 类别的计算对应的分类loss,求和。
Deep Cocktail Network

4 experiment

Deep Cocktail Network

  • single best:对每对源域-目标域训练 选择最好的结果
  • source combine:将多个源域合并成单个domain

Deep Cocktail Network

为了对比对category shift的效果设计了两种不同的category模式overlap(source类别之间是有交集的) disjoint(类别之间完全是没有交集的)

Deep Cocktail Network
可以看到DTCN保证了domain的相似性以及类别之间的区分性