[cvpr2017]Mind the Class Weight Bias: Weighted Maximum Mean Discrepancy for Unsupervised DA
introduce
- 本文研究的范围仅限于UDA(unsupervised domain adaptation)
- 作者认为使用MMD(maximum mean discrepancy)来衡量source domain和target domain之间的差异不够准确,这是因为没有考虑class prior distributions(类的先验分布,就是某个类在整个domain中所占的比重),为了解决这个问题,作者提出了一个叫做weighted MMD(WMMD)的模型。(Despite the great success
achieved, existing ones(MMD-based methods) generally ignore the changes of
class prior distributions, dubbed by class weight bias.) - 对基于MMD的域适应方法来说,对class weight bias(各个类中的样本数所占的比重应该就是class weight)的忽略可能导致性能的下降(For MMD-based methods, the ignorance of class weight
bias can deteriorate the domain adaptation performance) - 如下图:
- MMD的限制在于当source domain和taregt domian中的class weight不同(或者如图中所示,更严重地target domain缺少source domain中的类)时,使用MMD会导致错误的分类(MMD会使得target domain的class weight强行与source domain一致)。
- 然而问题是,target domain是没有label的,所以target domain的class weight是未知的
- 因此作者首先引入了class-specific auxiliary
weights(类特定辅助权重?)来对source domain进行reweight,使得source domain的class weight和target domain的完全一致。 - 通过最小化weighted MMD(WMMD)的目标函数来共同优化auxiliary
weights的估计量和模型参数学习。 - 作者使用一个叫做 classification EM (CEM)的方案来估计他。
- 在E步骤和C步骤中,计算类后验概率(target domain的class weight的后验概率),将伪标签(pesudo label)分配给target domain的样本,并估计auxiliary
weights。 - 在M步骤当中,通过最小化目标函数的损失来更新参数(普通的机器学习训练过程)。
maximum mean discrepancy
- (MMD基础理论部分,数学用语很多,不想翻译了,我就直接贴截图了)
Weighted Maximum Mean Discrepancy
-
ps(xs) 、pt(xt) :source domain和target domain的概率分布密度 - 以上二者都可以用类的条件分布的混合来表示:
- 其中
wsc=ps(ys=c) 和wsc=ps(yt=c) 就是前文所提到的class prior probility(class weight)。
- 其中
- MMD比较的是
ps(xs) 和pt(xt) ,也就是概率密度,但作者认为比较source domain和target domain的条件概率密度ps(xs|ys=c) 和pt(xt|yt=c) 更为有效(这个是判别式模型(discrinimative model)学习的目标:类的后验概率) - 作者建议利用reference source distribution
ps,α(xs) 来计算source domain和target domain之间的差异(discrepancy) 要求
ps,α(xs) 和target domain有一样的class weight(wsc=ps(ys=c) )并且保留source domain的条件概率密度(ps(xs|ys=c) ),所以:-
weight MMD的基本形式:(利用
ps,α(xs) 和pt(xt) 来计算):- weight MMD的线性复杂度近似(为了速度和SGD,详细的理论见上面MMD后半部分)
- weight MMD的线性复杂度近似(为了速度和SGD,详细的理论见上面MMD后半部分)
Weighted Domain Adaptation Network
- 作者认为WMMD正则化层需要加载CNN的高层,因为dataset bias会在高层增加:
- WDAN(Weighted Domain Adaptation Network)模型:
- 优化WDAN的过程:
- E-step:估计target domain的类
{xtj}Nj=1 的后验概率(class posterior probility) - C-step:基于E-step中计算出的最大的class posterior probility,将伪标签(pseudo-label)
{yˆNj=1} 赋给每个xtj ,并且估计辅助权重(auxiliary weights)α - M-step:在给定的
α 和{yˆNj=1} 下更新模型参数W :
- E-step:估计target domain的类